@@ -106,7 +106,7 @@ func TestBasicSessionStore(t *testing.T) {
106106 require .Equal (t , session1 .State , StateCreated )
107107
108108 // Now revoke the session and assert that the state is revoked.
109- require .NoError (t , db .RevokeSession (s1 .LocalPublicKey ))
109+ require .NoError (t , db .ShiftState (s1 .ID , StateRevoked ))
110110 s1 , err = db .GetSession (s1 .LocalPublicKey )
111111 require .NoError (t , err )
112112 require .Equal (t , s1 .State , StateRevoked )
@@ -225,7 +225,7 @@ func TestLinkingSessions(t *testing.T) {
225225 require .ErrorContains (t , db .CreateSession (s2 ), "is still active" )
226226
227227 // Revoke the first session.
228- require .NoError (t , db .RevokeSession (s1 .LocalPublicKey ))
228+ require .NoError (t , db .ShiftState (s1 .ID , StateRevoked ))
229229
230230 // Persisting the second linked session should now work.
231231 require .NoError (t , db .CreateSession (s2 ))
@@ -248,16 +248,20 @@ func TestLinkedSessions(t *testing.T) {
248248 // the same group. The group ID is equivalent to the session ID of the
249249 // first session.
250250 s1 := newSession (t , db , clock , "session 1" )
251- s2 := newSession (t , db , clock , "session 2" , withLinkedGroupID (& s1 .GroupID ))
252- s3 := newSession (t , db , clock , "session 3" , withLinkedGroupID (& s2 .GroupID ))
251+ s2 := newSession (
252+ t , db , clock , "session 2" , withLinkedGroupID (& s1 .GroupID ),
253+ )
254+ s3 := newSession (
255+ t , db , clock , "session 3" , withLinkedGroupID (& s2 .GroupID ),
256+ )
253257
254258 // Persist the sessions.
255259 require .NoError (t , db .CreateSession (s1 ))
256260
257- require .NoError (t , db .RevokeSession (s1 .LocalPublicKey ))
261+ require .NoError (t , db .ShiftState (s1 .ID , StateRevoked ))
258262 require .NoError (t , db .CreateSession (s2 ))
259263
260- require .NoError (t , db .RevokeSession (s2 .LocalPublicKey ))
264+ require .NoError (t , db .ShiftState (s2 .ID , StateRevoked ))
261265 require .NoError (t , db .CreateSession (s3 ))
262266
263267 // Assert that the session ID to group ID index works as expected.
@@ -282,7 +286,7 @@ func TestLinkedSessions(t *testing.T) {
282286
283287 // Persist the sessions.
284288 require .NoError (t , db .CreateSession (s4 ))
285- require .NoError (t , db .RevokeSession (s4 .LocalPublicKey ))
289+ require .NoError (t , db .ShiftState (s4 .ID , StateRevoked ))
286290
287291 require .NoError (t , db .CreateSession (s5 ))
288292
@@ -337,7 +341,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
337341 require .False (t , ok )
338342
339343 // Revoke the first session.
340- require .NoError (t , db .RevokeSession (s1 .LocalPublicKey ))
344+ require .NoError (t , db .ShiftState (s1 .ID , StateRevoked ))
341345
342346 // Add a new session to the same group as the first one.
343347 s2 := newSession (t , db , clock , "label 2" , withLinkedGroupID (& s1 .GroupID ))
@@ -392,6 +396,53 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
392396 require .True (t , ok )
393397}
394398
399+ // TestStateShift tests that the ShiftState method works as expected.
400+ func TestStateShift (t * testing.T ) {
401+ // Set up a new DB.
402+ clock := clock .NewTestClock (testTime )
403+ db , err := NewDB (t .TempDir (), "test.db" , clock )
404+ require .NoError (t , err )
405+ t .Cleanup (func () {
406+ _ = db .Close ()
407+ })
408+
409+ // Add a new session to the DB.
410+ s1 := newSession (t , db , clock , "label 1" )
411+ require .NoError (t , db .CreateSession (s1 ))
412+
413+ // Check that the session is in the StateCreated state. Also check that
414+ // the "RevokedAt" time has not yet been set.
415+ s1 , err = db .GetSession (s1 .LocalPublicKey )
416+ require .NoError (t , err )
417+ require .Equal (t , StateCreated , s1 .State )
418+ require .Equal (t , time.Time {}, s1 .RevokedAt )
419+
420+ // Shift the state of the session to StateRevoked.
421+ err = db .ShiftState (s1 .ID , StateRevoked )
422+ require .NoError (t , err )
423+
424+ // This should have worked. Since it is now in a terminal state, the
425+ // "RevokedAt" time should be set.
426+ s1 , err = db .GetSession (s1 .LocalPublicKey )
427+ require .NoError (t , err )
428+ require .Equal (t , StateRevoked , s1 .State )
429+ require .True (t , clock .Now ().Equal (s1 .RevokedAt ))
430+
431+ // Trying to do the same state shift again should succeed since the
432+ // session is already in the expected "dest" state. The revoked-at time
433+ // should not have changed though.
434+ prevTime := clock .Now ()
435+ clock .SetTime (prevTime .Add (time .Second ))
436+ err = db .ShiftState (s1 .ID , StateRevoked )
437+ require .NoError (t , err )
438+ require .True (t , prevTime .Equal (s1 .RevokedAt ))
439+
440+ // Trying to shift the state from a terminal state back to StateCreated
441+ // should also fail since this is not a legal state transition.
442+ err = db .ShiftState (s1 .ID , StateCreated )
443+ require .ErrorContains (t , err , "illegal session state transition" )
444+ }
445+
395446// testSessionModifier is a functional option that can be used to modify the
396447// default test session created by newSession.
397448type testSessionModifier func (* Session )
0 commit comments