@@ -24,10 +24,10 @@ func TestBasicSessionStore(t *testing.T) {
2424 })
2525
2626 // Create a few sessions.
27- s1 := newSession (t , db , clock , "session 1" , nil )
28- s2 := newSession (t , db , clock , "session 2" , nil )
29- s3 := newSession (t , db , clock , "session 3" , nil )
30- s4 := newSession (t , db , clock , "session 4" , nil )
27+ s1 := newSession (t , db , clock , "session 1" )
28+ s2 := newSession (t , db , clock , "session 2" )
29+ s3 := newSession (t , db , clock , "session 3" )
30+ s4 := newSession (t , db , clock , "session 4" )
3131
3232 // Persist session 1. This should now succeed.
3333 require .NoError (t , db .CreateSession (s1 ))
@@ -101,10 +101,10 @@ func TestLinkingSessions(t *testing.T) {
101101 })
102102
103103 // Create a new session with no previous link.
104- s1 := newSession (t , db , clock , "session 1" , nil )
104+ s1 := newSession (t , db , clock , "session 1" )
105105
106106 // Create another session and link it to the first.
107- s2 := newSession (t , db , clock , "session 2" , & s1 .GroupID )
107+ s2 := newSession (t , db , clock , "session 2" , withLinkedGroupID ( & s1 .GroupID ) )
108108
109109 // Try to persist the second session and assert that it fails due to the
110110 // linked session not existing in the DB yet.
@@ -141,9 +141,9 @@ func TestLinkedSessions(t *testing.T) {
141141 // after are all linked to the prior one. All these sessions belong to
142142 // the same group. The group ID is equivalent to the session ID of the
143143 // first session.
144- s1 := newSession (t , db , clock , "session 1" , nil )
145- s2 := newSession (t , db , clock , "session 2" , & s1 .GroupID )
146- s3 := newSession (t , db , clock , "session 3" , & s2 .GroupID )
144+ s1 := newSession (t , db , clock , "session 1" )
145+ s2 := newSession (t , db , clock , "session 2" , withLinkedGroupID ( & s1 .GroupID ) )
146+ s3 := newSession (t , db , clock , "session 3" , withLinkedGroupID ( & s2 .GroupID ) )
147147
148148 // Persist the sessions.
149149 require .NoError (t , db .CreateSession (s1 ))
@@ -169,8 +169,8 @@ func TestLinkedSessions(t *testing.T) {
169169
170170 // To ensure that different groups don't interfere with each other,
171171 // let's add another set of linked sessions not linked to the first.
172- s4 := newSession (t , db , clock , "session 4" , nil )
173- s5 := newSession (t , db , clock , "session 5" , & s4 .GroupID )
172+ s4 := newSession (t , db , clock , "session 4" )
173+ s5 := newSession (t , db , clock , "session 5" , withLinkedGroupID ( & s4 .GroupID ) )
174174
175175 require .NotEqual (t , s4 .GroupID , s1 .GroupID )
176176
@@ -209,7 +209,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
209209 // function is checked correctly.
210210
211211 // Add a new session to the DB.
212- s1 := newSession (t , db , clock , "label 1" , nil )
212+ s1 := newSession (t , db , clock , "label 1" )
213213 require .NoError (t , db .CreateSession (s1 ))
214214
215215 // Check that the group passes against an appropriate predicate.
@@ -234,7 +234,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
234234 require .NoError (t , db .RevokeSession (s1 .LocalPublicKey ))
235235
236236 // Add a new session to the same group as the first one.
237- s2 := newSession (t , db , clock , "label 2" , & s1 .GroupID )
237+ s2 := newSession (t , db , clock , "label 2" , withLinkedGroupID ( & s1 .GroupID ) )
238238 require .NoError (t , db .CreateSession (s2 ))
239239
240240 // Check that the group passes against an appropriate predicate.
@@ -256,7 +256,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
256256 require .False (t , ok )
257257
258258 // Add a new session that is not linked to the first one.
259- s3 := newSession (t , db , clock , "completely different" , nil )
259+ s3 := newSession (t , db , clock , "completely different" )
260260 require .NoError (t , db .CreateSession (s3 ))
261261
262262 // Ensure that the first group is unaffected.
@@ -286,8 +286,18 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
286286 require .True (t , ok )
287287}
288288
289+ // testSessionModifier is a functional option that can be used to modify the
290+ // default test session created by newSession.
291+ type testSessionModifier func (* Session )
292+
293+ func withLinkedGroupID (groupID * ID ) testSessionModifier {
294+ return func (s * Session ) {
295+ s .GroupID = * groupID
296+ }
297+ }
298+
289299func newSession (t * testing.T , db Store , clock clock.Clock , label string ,
290- linkedGroupID * ID ) * Session {
300+ mods ... testSessionModifier ) * Session {
291301
292302 id , priv , err := db .GetUnusedIDAndKeyPair ()
293303 require .NoError (t , err )
@@ -296,11 +306,15 @@ func newSession(t *testing.T, db Store, clock clock.Clock, label string,
296306 id , priv , label , TypeMacaroonAdmin ,
297307 clock .Now (),
298308 time .Date (99999 , 1 , 1 , 0 , 0 , 0 , 0 , time .UTC ),
299- "foo.bar.baz:1234" , true , nil , nil , nil , true , linkedGroupID ,
309+ "foo.bar.baz:1234" , true , nil , nil , nil , true , nil ,
300310 []PrivacyFlag {ClearPubkeys },
301311 )
302312 require .NoError (t , err )
303313
314+ for _ , mod := range mods {
315+ mod (session )
316+ }
317+
304318 return session
305319}
306320
0 commit comments