@@ -23,11 +23,17 @@ func TestBasicSessionStore(t *testing.T) {
2323 _ = db .Close ()
2424 })
2525
26- // Create a few sessions.
27- s1 := newSession (t , db , clock , "session 1" , nil )
28- s2 := newSession (t , db , clock , "session 2" , nil )
29- s3 := newSession (t , db , clock , "session 3" , nil )
30- s4 := newSession (t , db , clock , "session 4" , nil )
26+ // Create a few sessions. We increment the time by one second between
27+ // each session to ensure that the created at time is unique and hence
28+ // that the ListSessions method returns the sessions in a deterministic
29+ // order.
30+ s1 := newSession (t , db , clock , "session 1" )
31+ clock .SetTime (testTime .Add (time .Second ))
32+ s2 := newSession (t , db , clock , "session 2" )
33+ clock .SetTime (testTime .Add (2 * time .Second ))
34+ s3 := newSession (t , db , clock , "session 3" , withType (TypeAutopilot ))
35+ clock .SetTime (testTime .Add (3 * time .Second ))
36+ s4 := newSession (t , db , clock , "session 4" )
3137
3238 // Persist session 1. This should now succeed.
3339 require .NoError (t , db .CreateSession (s1 ))
@@ -50,6 +56,22 @@ func TestBasicSessionStore(t *testing.T) {
5056 require .NoError (t , db .CreateSession (s2 ))
5157 require .NoError (t , db .CreateSession (s3 ))
5258
59+ // Test the ListSessionsByType method.
60+ sessions , err := db .ListSessionsByType (TypeMacaroonAdmin )
61+ require .NoError (t , err )
62+ require .Equal (t , 2 , len (sessions ))
63+ assertEqualSessions (t , s1 , sessions [0 ])
64+ assertEqualSessions (t , s2 , sessions [1 ])
65+
66+ sessions , err = db .ListSessionsByType (TypeAutopilot )
67+ require .NoError (t , err )
68+ require .Equal (t , 1 , len (sessions ))
69+ assertEqualSessions (t , s3 , sessions [0 ])
70+
71+ sessions , err = db .ListSessionsByType (TypeMacaroonReadonly )
72+ require .NoError (t , err )
73+ require .Empty (t , sessions )
74+
5375 // Ensure that we can retrieve each session by both its local pub key
5476 // and by its ID.
5577 for _ , s := range []* Session {s1 , s2 , s3 } {
@@ -85,9 +107,44 @@ func TestBasicSessionStore(t *testing.T) {
85107
86108 // Now revoke the session and assert that the state is revoked.
87109 require .NoError (t , db .RevokeSession (s1 .LocalPublicKey ))
88- session1 , err = db .GetSession (s1 .LocalPublicKey )
110+ s1 , err = db .GetSession (s1 .LocalPublicKey )
111+ require .NoError (t , err )
112+ require .Equal (t , s1 .State , StateRevoked )
113+
114+ // Test that ListAllSessions works.
115+ sessions , err = db .ListAllSessions ()
116+ require .NoError (t , err )
117+ require .Equal (t , 3 , len (sessions ))
118+ assertEqualSessions (t , s1 , sessions [0 ])
119+ assertEqualSessions (t , s2 , sessions [1 ])
120+ assertEqualSessions (t , s3 , sessions [2 ])
121+
122+ // Test that ListSessionsByState works.
123+ sessions , err = db .ListSessionsByState (StateRevoked )
124+ require .NoError (t , err )
125+ require .Equal (t , 1 , len (sessions ))
126+ assertEqualSessions (t , s1 , sessions [0 ])
127+
128+ sessions , err = db .ListSessionsByState (StateCreated )
129+ require .NoError (t , err )
130+ require .Equal (t , 2 , len (sessions ))
131+ assertEqualSessions (t , s2 , sessions [0 ])
132+ assertEqualSessions (t , s3 , sessions [1 ])
133+
134+ sessions , err = db .ListSessionsByState (StateCreated , StateRevoked )
135+ require .NoError (t , err )
136+ require .Equal (t , 3 , len (sessions ))
137+ assertEqualSessions (t , s1 , sessions [0 ])
138+ assertEqualSessions (t , s2 , sessions [1 ])
139+ assertEqualSessions (t , s3 , sessions [2 ])
140+
141+ sessions , err = db .ListSessionsByState ()
89142 require .NoError (t , err )
90- require .Equal (t , session1 .State , StateRevoked )
143+ require .Empty (t , sessions )
144+
145+ sessions , err = db .ListSessionsByState (StateInUse )
146+ require .NoError (t , err )
147+ require .Empty (t , sessions )
91148}
92149
93150// TestLinkingSessions tests that session linking works as expected.
@@ -101,10 +158,10 @@ func TestLinkingSessions(t *testing.T) {
101158 })
102159
103160 // Create a new session with no previous link.
104- s1 := newSession (t , db , clock , "session 1" , nil )
161+ s1 := newSession (t , db , clock , "session 1" )
105162
106163 // Create another session and link it to the first.
107- s2 := newSession (t , db , clock , "session 2" , & s1 .GroupID )
164+ s2 := newSession (t , db , clock , "session 2" , withLinkedGroupID ( & s1 .GroupID ) )
108165
109166 // Try to persist the second session and assert that it fails due to the
110167 // linked session not existing in the DB yet.
@@ -141,9 +198,9 @@ func TestLinkedSessions(t *testing.T) {
141198 // after are all linked to the prior one. All these sessions belong to
142199 // the same group. The group ID is equivalent to the session ID of the
143200 // first session.
144- s1 := newSession (t , db , clock , "session 1" , nil )
145- s2 := newSession (t , db , clock , "session 2" , & s1 .GroupID )
146- s3 := newSession (t , db , clock , "session 3" , & s2 .GroupID )
201+ s1 := newSession (t , db , clock , "session 1" )
202+ s2 := newSession (t , db , clock , "session 2" , withLinkedGroupID ( & s1 .GroupID ) )
203+ s3 := newSession (t , db , clock , "session 3" , withLinkedGroupID ( & s2 .GroupID ) )
147204
148205 // Persist the sessions.
149206 require .NoError (t , db .CreateSession (s1 ))
@@ -169,8 +226,8 @@ func TestLinkedSessions(t *testing.T) {
169226
170227 // To ensure that different groups don't interfere with each other,
171228 // let's add another set of linked sessions not linked to the first.
172- s4 := newSession (t , db , clock , "session 4" , nil )
173- s5 := newSession (t , db , clock , "session 5" , & s4 .GroupID )
229+ s4 := newSession (t , db , clock , "session 4" )
230+ s5 := newSession (t , db , clock , "session 5" , withLinkedGroupID ( & s4 .GroupID ) )
174231
175232 require .NotEqual (t , s4 .GroupID , s1 .GroupID )
176233
@@ -209,7 +266,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
209266 // function is checked correctly.
210267
211268 // Add a new session to the DB.
212- s1 := newSession (t , db , clock , "label 1" , nil )
269+ s1 := newSession (t , db , clock , "label 1" )
213270 require .NoError (t , db .CreateSession (s1 ))
214271
215272 // Check that the group passes against an appropriate predicate.
@@ -234,7 +291,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
234291 require .NoError (t , db .RevokeSession (s1 .LocalPublicKey ))
235292
236293 // Add a new session to the same group as the first one.
237- s2 := newSession (t , db , clock , "label 2" , & s1 .GroupID )
294+ s2 := newSession (t , db , clock , "label 2" , withLinkedGroupID ( & s1 .GroupID ) )
238295 require .NoError (t , db .CreateSession (s2 ))
239296
240297 // Check that the group passes against an appropriate predicate.
@@ -256,7 +313,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
256313 require .False (t , ok )
257314
258315 // Add a new session that is not linked to the first one.
259- s3 := newSession (t , db , clock , "completely different" , nil )
316+ s3 := newSession (t , db , clock , "completely different" )
260317 require .NoError (t , db .CreateSession (s3 ))
261318
262319 // Ensure that the first group is unaffected.
@@ -286,8 +343,24 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
286343 require .True (t , ok )
287344}
288345
346+ // testSessionModifier is a functional option that can be used to modify the
347+ // default test session created by newSession.
348+ type testSessionModifier func (* Session )
349+
350+ func withLinkedGroupID (groupID * ID ) testSessionModifier {
351+ return func (s * Session ) {
352+ s .GroupID = * groupID
353+ }
354+ }
355+
356+ func withType (t Type ) testSessionModifier {
357+ return func (s * Session ) {
358+ s .Type = t
359+ }
360+ }
361+
289362func newSession (t * testing.T , db Store , clock clock.Clock , label string ,
290- linkedGroupID * ID ) * Session {
363+ mods ... testSessionModifier ) * Session {
291364
292365 id , priv , err := db .GetUnusedIDAndKeyPair ()
293366 require .NoError (t , err )
@@ -296,11 +369,15 @@ func newSession(t *testing.T, db Store, clock clock.Clock, label string,
296369 id , priv , label , TypeMacaroonAdmin ,
297370 clock .Now (),
298371 time .Date (99999 , 1 , 1 , 0 , 0 , 0 , 0 , time .UTC ),
299- "foo.bar.baz:1234" , true , nil , nil , nil , true , linkedGroupID ,
372+ "foo.bar.baz:1234" , true , nil , nil , nil , true , nil ,
300373 []PrivacyFlag {ClearPubkeys },
301374 )
302375 require .NoError (t , err )
303376
377+ for _ , mod := range mods {
378+ mod (session )
379+ }
380+
304381 return session
305382}
306383
0 commit comments