Skip to content

Commit f93093e

Browse files
committed
session: add context to ShiftState
1 parent aa7b1db commit f93093e

File tree

4 files changed

+21
-20
lines changed

4 files changed

+21
-20
lines changed

session/interface.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ type Store interface {
306306

307307
// ShiftState updates the state of the session with the given ID to the
308308
// "dest" state.
309-
ShiftState(id ID, dest State) error
309+
ShiftState(ctx context.Context, id ID, dest State) error
310310

311311
IDToGroupIndex
312312
}

session/kvdb_store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
529529
// state.
530530
//
531531
// NOTE: this is part of the Store interface.
532-
func (db *BoltStore) ShiftState(id ID, dest State) error {
532+
func (db *BoltStore) ShiftState(_ context.Context, id ID, dest State) error {
533533
return db.Update(func(tx *bbolt.Tx) error {
534534
sessionBucket, err := getBucket(tx, sessionBucketKey)
535535
if err != nil {

session/store_test.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestBasicSessionStore(t *testing.T) {
3636
require.Equal(t, StateReserved, s1.State)
3737

3838
// Move session 1 to the created state. This should succeed.
39-
err = db.ShiftState(s1.ID, StateCreated)
39+
err = db.ShiftState(ctx, s1.ID, StateCreated)
4040
require.NoError(t, err)
4141

4242
// Show that the session is now in the created state.
@@ -46,7 +46,7 @@ func TestBasicSessionStore(t *testing.T) {
4646

4747
// Trying to move session 1 again should have no effect since it is
4848
// already in the created state.
49-
require.NoError(t, db.ShiftState(s1.ID, StateCreated))
49+
require.NoError(t, db.ShiftState(ctx, s1.ID, StateCreated))
5050

5151
// Reserve and create a few more sessions. We increment the time by one
5252
// second between each session to ensure that the created at time is
@@ -107,7 +107,7 @@ func TestBasicSessionStore(t *testing.T) {
107107
require.Equal(t, session1.State, StateCreated)
108108

109109
// Now revoke the session and assert that the state is revoked.
110-
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
110+
require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked))
111111
s1, err = db.GetSession(ctx, s1.LocalPublicKey)
112112
require.NoError(t, err)
113113
require.Equal(t, s1.State, StateRevoked)
@@ -198,6 +198,7 @@ func TestBasicSessionStore(t *testing.T) {
198198
// TestLinkingSessions tests that session linking works as expected.
199199
func TestLinkingSessions(t *testing.T) {
200200
t.Parallel()
201+
ctx := context.Background()
201202

202203
// Set up a new DB.
203204
clock := clock.NewTestClock(testTime)
@@ -223,7 +224,7 @@ func TestLinkingSessions(t *testing.T) {
223224
require.ErrorIs(t, err, ErrSessionsInGroupStillActive)
224225

225226
// Revoke the first session.
226-
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
227+
require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked))
227228

228229
// Persisting the second linked session should now work.
229230
_, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID))
@@ -247,10 +248,10 @@ func TestLinkedSessions(t *testing.T) {
247248
// first session.
248249
s1 := createSession(t, db, "session 1")
249250

250-
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
251+
require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked))
251252
s2 := createSession(t, db, "session 2", withLinkedGroupID(&s1.GroupID))
252253

253-
require.NoError(t, db.ShiftState(s2.ID, StateRevoked))
254+
require.NoError(t, db.ShiftState(ctx, s2.ID, StateRevoked))
254255
s3 := createSession(t, db, "session 3", withLinkedGroupID(&s2.GroupID))
255256

256257
// Assert that the session ID to group ID index works as expected.
@@ -269,7 +270,7 @@ func TestLinkedSessions(t *testing.T) {
269270
// To ensure that different groups don't interfere with each other,
270271
// let's add another set of linked sessions not linked to the first.
271272
s4 := createSession(t, db, "session 4")
272-
require.NoError(t, db.ShiftState(s4.ID, StateRevoked))
273+
require.NoError(t, db.ShiftState(ctx, s4.ID, StateRevoked))
273274
s5 := createSession(t, db, "session 5", withLinkedGroupID(&s4.GroupID))
274275
require.NotEqual(t, s4.GroupID, s1.GroupID)
275276

@@ -307,7 +308,7 @@ func TestStateShift(t *testing.T) {
307308
require.Equal(t, time.Time{}, s1.RevokedAt)
308309

309310
// Shift the state of the session to StateRevoked.
310-
err = db.ShiftState(s1.ID, StateRevoked)
311+
err = db.ShiftState(ctx, s1.ID, StateRevoked)
311312
require.NoError(t, err)
312313

313314
// This should have worked. Since it is now in a terminal state, the
@@ -322,13 +323,13 @@ func TestStateShift(t *testing.T) {
322323
// should not have changed though.
323324
prevTime := clock.Now()
324325
clock.SetTime(prevTime.Add(time.Second))
325-
err = db.ShiftState(s1.ID, StateRevoked)
326+
err = db.ShiftState(ctx, s1.ID, StateRevoked)
326327
require.NoError(t, err)
327328
require.True(t, prevTime.Equal(s1.RevokedAt))
328329

329330
// Trying to shift the state from a terminal state back to StateCreated
330331
// should also fail since this is not a legal state transition.
331-
err = db.ShiftState(s1.ID, StateCreated)
332+
err = db.ShiftState(ctx, s1.ID, StateCreated)
332333
require.ErrorContains(t, err, "illegal session state transition")
333334
}
334335

@@ -384,7 +385,7 @@ func createSession(t *testing.T, db Store, label string,
384385
s, err := reserveSession(db, label, mods...)
385386
require.NoError(t, err)
386387

387-
err = db.ShiftState(s.ID, StateCreated)
388+
err = db.ShiftState(context.Background(), s.ID, StateCreated)
388389
require.NoError(t, err)
389390

390391
s, err = db.GetSessionByID(context.Background(), s.ID)

session_rpcserver.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
149149

150150
if perm {
151151
err := s.cfg.db.ShiftState(
152-
sess.ID, session.StateRevoked,
152+
ctx, sess.ID, session.StateRevoked,
153153
)
154154
if err != nil {
155155
log.Errorf("error revoking "+
@@ -323,7 +323,7 @@ func (s *sessionRpcServer) AddSession(ctx context.Context,
323323
return nil, fmt.Errorf("error creating new session: %v", err)
324324
}
325325

326-
err = s.cfg.db.ShiftState(sess.ID, session.StateCreated)
326+
err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateCreated)
327327
if err != nil {
328328
return nil, fmt.Errorf("error shifting session state to "+
329329
"Created: %v", err)
@@ -362,7 +362,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
362362
log.Debugf("Not resuming session %x with expiry %s",
363363
pubKeyBytes, sess.Expiry)
364364

365-
err := s.cfg.db.ShiftState(sess.ID, session.StateExpired)
365+
err := s.cfg.db.ShiftState(ctx, sess.ID, session.StateExpired)
366366
if err != nil {
367367
return fmt.Errorf("error revoking session: %v", err)
368368
}
@@ -440,7 +440,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
440440
"passed. Revoking session", pubKeyBytes)
441441

442442
return s.cfg.db.ShiftState(
443-
sess.ID, session.StateRevoked,
443+
ctx, sess.ID, session.StateRevoked,
444444
)
445445
}
446446

@@ -520,7 +520,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
520520
log.Debugf("Error stopping session: %v", err)
521521
}
522522

523-
err = s.cfg.db.ShiftState(sess.ID, session.StateRevoked)
523+
err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateRevoked)
524524
if err != nil {
525525
log.Debugf("error revoking session: %v", err)
526526
}
@@ -567,7 +567,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context,
567567
return nil, fmt.Errorf("error fetching session: %v", err)
568568
}
569569

570-
err = s.cfg.db.ShiftState(sess.ID, session.StateRevoked)
570+
err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateRevoked)
571571
if err != nil {
572572
return nil, fmt.Errorf("error revoking session: %v", err)
573573
}
@@ -1240,7 +1240,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
12401240

12411241
// We only activate the session if the Autopilot server registration
12421242
// was successful.
1243-
err = s.cfg.db.ShiftState(sess.ID, session.StateCreated)
1243+
err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateCreated)
12441244
if err != nil {
12451245
return nil, fmt.Errorf("error shifting session state to "+
12461246
"Created: %v", err)

0 commit comments

Comments
 (0)