Skip to content

Commit 56d7b45

Browse files
committed
session: replace RevokeSession with ShiftState
1 parent 85d8ffd commit 56d7b45

File tree

4 files changed

+19
-46
lines changed

4 files changed

+19
-46
lines changed

session/interface.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,6 @@ type Store interface {
215215
// that are in the given states.
216216
ListSessionsByState(...State) ([]*Session, error)
217217

218-
// RevokeSession updates the state of the session with the given local
219-
// public key to be revoked.
220-
RevokeSession(*btcec.PublicKey) error
221-
222218
// UpdateSessionRemotePubKey can be used to add the given remote pub key
223219
// to the session with the given local pub key.
224220
UpdateSessionRemotePubKey(localPubKey,

session/kvdb_store.go

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -562,35 +562,6 @@ func (db *BoltStore) ShiftState(key *btcec.PublicKey, dest State) error {
562562
})
563563
}
564564

565-
// RevokeSession updates the state of the session with the given local
566-
// public key to be revoked.
567-
//
568-
// NOTE: this is part of the Store interface.
569-
func (db *BoltStore) RevokeSession(key *btcec.PublicKey) error {
570-
var session *Session
571-
return db.Update(func(tx *bbolt.Tx) error {
572-
sessionBucket, err := getBucket(tx, sessionBucketKey)
573-
if err != nil {
574-
return err
575-
}
576-
577-
sessionBytes := sessionBucket.Get(key.SerializeCompressed())
578-
if len(sessionBytes) == 0 {
579-
return ErrSessionNotFound
580-
}
581-
582-
session, err = DeserializeSession(bytes.NewReader(sessionBytes))
583-
if err != nil {
584-
return err
585-
}
586-
587-
session.State = StateRevoked
588-
session.RevokedAt = db.clock.Now().UTC()
589-
590-
return putSession(sessionBucket, session)
591-
})
592-
}
593-
594565
// GetSessionByID fetches the session with the given ID.
595566
//
596567
// NOTE: this is part of the Store interface.

session/store_test.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func TestBasicSessionStore(t *testing.T) {
106106
require.Equal(t, session1.State, StateCreated)
107107

108108
// Now revoke the session and assert that the state is revoked.
109-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
109+
require.NoError(t, db.ShiftState(s1.LocalPublicKey, StateRevoked))
110110
s1, err = db.GetSession(s1.LocalPublicKey)
111111
require.NoError(t, err)
112112
require.Equal(t, s1.State, StateRevoked)
@@ -225,7 +225,7 @@ func TestLinkingSessions(t *testing.T) {
225225
require.ErrorContains(t, db.CreateSession(s2), "is still active")
226226

227227
// Revoke the first session.
228-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
228+
require.NoError(t, db.ShiftState(s1.LocalPublicKey, StateRevoked))
229229

230230
// Persisting the second linked session should now work.
231231
require.NoError(t, db.CreateSession(s2))
@@ -248,16 +248,20 @@ func TestLinkedSessions(t *testing.T) {
248248
// the same group. The group ID is equivalent to the session ID of the
249249
// first session.
250250
s1 := newSession(t, db, clock, "session 1")
251-
s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID))
252-
s3 := newSession(t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID))
251+
s2 := newSession(
252+
t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID),
253+
)
254+
s3 := newSession(
255+
t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID),
256+
)
253257

254258
// Persist the sessions.
255259
require.NoError(t, db.CreateSession(s1))
256260

257-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
261+
require.NoError(t, db.ShiftState(s1.LocalPublicKey, StateRevoked))
258262
require.NoError(t, db.CreateSession(s2))
259263

260-
require.NoError(t, db.RevokeSession(s2.LocalPublicKey))
264+
require.NoError(t, db.ShiftState(s2.LocalPublicKey, StateRevoked))
261265
require.NoError(t, db.CreateSession(s3))
262266

263267
// Assert that the session ID to group ID index works as expected.
@@ -282,7 +286,7 @@ func TestLinkedSessions(t *testing.T) {
282286

283287
// Persist the sessions.
284288
require.NoError(t, db.CreateSession(s4))
285-
require.NoError(t, db.RevokeSession(s4.LocalPublicKey))
289+
require.NoError(t, db.ShiftState(s4.LocalPublicKey, StateRevoked))
286290

287291
require.NoError(t, db.CreateSession(s5))
288292

@@ -337,7 +341,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
337341
require.False(t, ok)
338342

339343
// Revoke the first session.
340-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
344+
require.NoError(t, db.ShiftState(s1.LocalPublicKey, StateRevoked))
341345

342346
// Add a new session to the same group as the first one.
343347
s2 := newSession(t, db, clock, "label 2", withLinkedGroupID(&s1.GroupID))

session_rpcserver.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,9 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
154154
err)
155155

156156
if perm {
157-
err := s.cfg.db.RevokeSession(
157+
err := s.cfg.db.ShiftState(
158158
sess.LocalPublicKey,
159+
session.StateRevoked,
159160
)
160161
if err != nil {
161162
log.Errorf("error revoking "+
@@ -360,7 +361,8 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
360361
log.Debugf("Not resuming session %x with expiry %s",
361362
pubKeyBytes, sess.Expiry)
362363

363-
if err := s.cfg.db.RevokeSession(pubKey); err != nil {
364+
err := s.cfg.db.ShiftState(pubKey, session.StateRevoked)
365+
if err != nil {
364366
return fmt.Errorf("error revoking session: %v", err)
365367
}
366368

@@ -436,7 +438,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
436438
log.Debugf("Deadline for session %x has already "+
437439
"passed. Revoking session", pubKeyBytes)
438440

439-
return s.cfg.db.RevokeSession(pubKey)
441+
return s.cfg.db.ShiftState(pubKey, session.StateRevoked)
440442
}
441443

442444
// Start the deadline timer.
@@ -515,7 +517,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
515517
log.Debugf("Error stopping session: %v", err)
516518
}
517519

518-
err = s.cfg.db.RevokeSession(pubKey)
520+
err = s.cfg.db.ShiftState(pubKey, session.StateRevoked)
519521
if err != nil {
520522
log.Debugf("error revoking session: %v", err)
521523
}
@@ -557,7 +559,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context,
557559
return nil, fmt.Errorf("error parsing public key: %v", err)
558560
}
559561

560-
if err := s.cfg.db.RevokeSession(pubKey); err != nil {
562+
if err := s.cfg.db.ShiftState(pubKey, session.StateRevoked); err != nil {
561563
return nil, fmt.Errorf("error revoking session: %v", err)
562564
}
563565

0 commit comments

Comments
 (0)