Skip to content

Commit d397186

Browse files
committed
session: add context to GetSession
1 parent 52dc4bd commit d397186

File tree

4 files changed

+18
-10
lines changed

4 files changed

+18
-10
lines changed

session/interface.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ type Store interface {
198198
flags PrivacyFlags) (*Session, error)
199199

200200
// GetSession fetches the session with the given key.
201-
GetSession(key *btcec.PublicKey) (*Session, error)
201+
GetSession(ctx context.Context, key *btcec.PublicKey) (*Session, error)
202202

203203
// ListAllSessions returns all sessions currently known to the store.
204204
ListAllSessions() ([]*Session, error)

session/kvdb_store.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,9 @@ func (db *BoltStore) UpdateSessionRemotePubKey(localPubKey,
320320
// GetSession fetches the session with the given key.
321321
//
322322
// NOTE: this is part of the Store interface.
323-
func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
323+
func (db *BoltStore) GetSession(_ context.Context, key *btcec.PublicKey) (
324+
*Session, error) {
325+
324326
var session *Session
325327
err := db.View(func(tx *bbolt.Tx) error {
326328
sessionBucket, err := getBucket(tx, sessionBucketKey)

session/store_test.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ var testTime = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
1515
// TestBasicSessionStore tests the basic getters and setters of the session
1616
// store.
1717
func TestBasicSessionStore(t *testing.T) {
18+
t.Parallel()
19+
ctx := context.Background()
20+
1821
// Set up a new DB.
1922
clock := clock.NewTestClock(testTime)
2023
db, err := NewDB(t.TempDir(), "test.db", clock)
@@ -73,7 +76,7 @@ func TestBasicSessionStore(t *testing.T) {
7376
// Ensure that we can retrieve each session by both its local pub key
7477
// and by its ID.
7578
for _, s := range []*Session{s1, s2, s3} {
76-
session, err := db.GetSession(s.LocalPublicKey)
79+
session, err := db.GetSession(ctx, s.LocalPublicKey)
7780
require.NoError(t, err)
7881
assertEqualSessions(t, s, session)
7982

@@ -83,7 +86,7 @@ func TestBasicSessionStore(t *testing.T) {
8386
}
8487

8588
// Fetch session 1 and assert that it currently has no remote pub key.
86-
session1, err := db.GetSession(s1.LocalPublicKey)
89+
session1, err := db.GetSession(ctx, s1.LocalPublicKey)
8790
require.NoError(t, err)
8891
require.Nil(t, session1.RemotePublicKey)
8992

@@ -96,7 +99,7 @@ func TestBasicSessionStore(t *testing.T) {
9699
require.NoError(t, err)
97100

98101
// Assert that the session now does have the remote pub key.
99-
session1, err = db.GetSession(s1.LocalPublicKey)
102+
session1, err = db.GetSession(ctx, s1.LocalPublicKey)
100103
require.NoError(t, err)
101104
require.True(t, remotePub.IsEqual(session1.RemotePublicKey))
102105

@@ -105,7 +108,7 @@ func TestBasicSessionStore(t *testing.T) {
105108

106109
// Now revoke the session and assert that the state is revoked.
107110
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
108-
s1, err = db.GetSession(s1.LocalPublicKey)
111+
s1, err = db.GetSession(ctx, s1.LocalPublicKey)
109112
require.NoError(t, err)
110113
require.Equal(t, s1.State, StateRevoked)
111114

@@ -293,6 +296,9 @@ func TestLinkedSessions(t *testing.T) {
293296

294297
// TestStateShift tests that the ShiftState method works as expected.
295298
func TestStateShift(t *testing.T) {
299+
t.Parallel()
300+
ctx := context.Background()
301+
296302
// Set up a new DB.
297303
clock := clock.NewTestClock(testTime)
298304
db, err := NewDB(t.TempDir(), "test.db", clock)
@@ -306,7 +312,7 @@ func TestStateShift(t *testing.T) {
306312

307313
// Check that the session is in the StateCreated state. Also check that
308314
// the "RevokedAt" time has not yet been set.
309-
s1, err = db.GetSession(s1.LocalPublicKey)
315+
s1, err = db.GetSession(ctx, s1.LocalPublicKey)
310316
require.NoError(t, err)
311317
require.Equal(t, StateCreated, s1.State)
312318
require.Equal(t, time.Time{}, s1.RevokedAt)
@@ -317,7 +323,7 @@ func TestStateShift(t *testing.T) {
317323

318324
// This should have worked. Since it is now in a terminal state, the
319325
// "RevokedAt" time should be set.
320-
s1, err = db.GetSession(s1.LocalPublicKey)
326+
s1, err = db.GetSession(ctx, s1.LocalPublicKey)
321327
require.NoError(t, err)
322328
require.Equal(t, StateRevoked, s1.State)
323329
require.True(t, clock.Now().Equal(s1.RevokedAt))

session_rpcserver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context,
556556
return nil, fmt.Errorf("error parsing public key: %v", err)
557557
}
558558

559-
sess, err := s.cfg.db.GetSession(pubKey)
559+
sess, err := s.cfg.db.GetSession(ctx, pubKey)
560560
if err != nil {
561561
return nil, fmt.Errorf("error fetching session: %v", err)
562562
}
@@ -1276,7 +1276,7 @@ func (s *sessionRpcServer) RevokeAutopilotSession(ctx context.Context,
12761276
return nil, fmt.Errorf("error parsing public key: %v", err)
12771277
}
12781278

1279-
sess, err := s.cfg.db.GetSession(pubKey)
1279+
sess, err := s.cfg.db.GetSession(ctx, pubKey)
12801280
if err != nil {
12811281
return nil, err
12821282
}

0 commit comments

Comments
 (0)