Skip to content

Commit 6f49f90

Browse files
committed
session: add context to various session List methods
1 parent d397186 commit 6f49f90

File tree

4 files changed

+28
-24
lines changed

4 files changed

+28
-24
lines changed

session/interface.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,15 @@ type Store interface {
201201
GetSession(ctx context.Context, key *btcec.PublicKey) (*Session, error)
202202

203203
// ListAllSessions returns all sessions currently known to the store.
204-
ListAllSessions() ([]*Session, error)
204+
ListAllSessions(ctx context.Context) ([]*Session, error)
205205

206206
// ListSessionsByType returns all sessions of the given type.
207-
ListSessionsByType(t Type) ([]*Session, error)
207+
ListSessionsByType(ctx context.Context, t Type) ([]*Session, error)
208208

209209
// ListSessionsByState returns all sessions currently known to the store
210210
// that are in the given states.
211-
ListSessionsByState(...State) ([]*Session, error)
211+
ListSessionsByState(ctx context.Context, state ...State) ([]*Session,
212+
error)
212213

213214
// UpdateSessionRemotePubKey can be used to add the given remote pub key
214215
// to the session with the given local pub key.

session/kvdb_store.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ func (db *BoltStore) GetSession(_ context.Context, key *btcec.PublicKey) (
352352
// ListAllSessions returns all sessions currently known to the store.
353353
//
354354
// NOTE: this is part of the Store interface.
355-
func (db *BoltStore) ListAllSessions() ([]*Session, error) {
355+
func (db *BoltStore) ListAllSessions(_ context.Context) ([]*Session, error) {
356356
return db.listSessions(func(s *Session) bool {
357357
return true
358358
})
@@ -362,7 +362,9 @@ func (db *BoltStore) ListAllSessions() ([]*Session, error) {
362362
// have the given type.
363363
//
364364
// NOTE: this is part of the Store interface.
365-
func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
365+
func (db *BoltStore) ListSessionsByType(_ context.Context, t Type) ([]*Session,
366+
error) {
367+
366368
return db.listSessions(func(s *Session) bool {
367369
return s.Type == t
368370
})
@@ -372,7 +374,9 @@ func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
372374
// are in the given states.
373375
//
374376
// NOTE: this is part of the Store interface.
375-
func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) {
377+
func (db *BoltStore) ListSessionsByState(_ context.Context, states ...State) (
378+
[]*Session, error) {
379+
376380
return db.listSessions(func(s *Session) bool {
377381
for _, state := range states {
378382
if s.State == state {

session/store_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,18 @@ func TestBasicSessionStore(t *testing.T) {
5858
s3 := createSession(t, db, "session 3", withType(TypeAutopilot))
5959

6060
// Test the ListSessionsByType method.
61-
sessions, err := db.ListSessionsByType(TypeMacaroonAdmin)
61+
sessions, err := db.ListSessionsByType(ctx, TypeMacaroonAdmin)
6262
require.NoError(t, err)
6363
require.Equal(t, 2, len(sessions))
6464
assertEqualSessions(t, s1, sessions[0])
6565
assertEqualSessions(t, s2, sessions[1])
6666

67-
sessions, err = db.ListSessionsByType(TypeAutopilot)
67+
sessions, err = db.ListSessionsByType(ctx, TypeAutopilot)
6868
require.NoError(t, err)
6969
require.Equal(t, 1, len(sessions))
7070
assertEqualSessions(t, s3, sessions[0])
7171

72-
sessions, err = db.ListSessionsByType(TypeMacaroonReadonly)
72+
sessions, err = db.ListSessionsByType(ctx, TypeMacaroonReadonly)
7373
require.NoError(t, err)
7474
require.Empty(t, sessions)
7575

@@ -113,37 +113,37 @@ func TestBasicSessionStore(t *testing.T) {
113113
require.Equal(t, s1.State, StateRevoked)
114114

115115
// Test that ListAllSessions works.
116-
sessions, err = db.ListAllSessions()
116+
sessions, err = db.ListAllSessions(ctx)
117117
require.NoError(t, err)
118118
require.Equal(t, 3, len(sessions))
119119
assertEqualSessions(t, s1, sessions[0])
120120
assertEqualSessions(t, s2, sessions[1])
121121
assertEqualSessions(t, s3, sessions[2])
122122

123123
// Test that ListSessionsByState works.
124-
sessions, err = db.ListSessionsByState(StateRevoked)
124+
sessions, err = db.ListSessionsByState(ctx, StateRevoked)
125125
require.NoError(t, err)
126126
require.Equal(t, 1, len(sessions))
127127
assertEqualSessions(t, s1, sessions[0])
128128

129-
sessions, err = db.ListSessionsByState(StateCreated)
129+
sessions, err = db.ListSessionsByState(ctx, StateCreated)
130130
require.NoError(t, err)
131131
require.Equal(t, 2, len(sessions))
132132
assertEqualSessions(t, s2, sessions[0])
133133
assertEqualSessions(t, s3, sessions[1])
134134

135-
sessions, err = db.ListSessionsByState(StateCreated, StateRevoked)
135+
sessions, err = db.ListSessionsByState(ctx, StateCreated, StateRevoked)
136136
require.NoError(t, err)
137137
require.Equal(t, 3, len(sessions))
138138
assertEqualSessions(t, s1, sessions[0])
139139
assertEqualSessions(t, s2, sessions[1])
140140
assertEqualSessions(t, s3, sessions[2])
141141

142-
sessions, err = db.ListSessionsByState()
142+
sessions, err = db.ListSessionsByState(ctx)
143143
require.NoError(t, err)
144144
require.Empty(t, sessions)
145145

146-
sessions, err = db.ListSessionsByState(StateReserved)
146+
sessions, err = db.ListSessionsByState(ctx, StateReserved)
147147
require.NoError(t, err)
148148
require.Empty(t, sessions)
149149

@@ -153,7 +153,7 @@ func TestBasicSessionStore(t *testing.T) {
153153
// of the sessions are reserved.
154154
require.NoError(t, db.DeleteReservedSessions())
155155

156-
sessions, err = db.ListSessionsByState(StateReserved)
156+
sessions, err = db.ListSessionsByState(ctx, StateReserved)
157157
require.NoError(t, err)
158158
require.Empty(t, sessions)
159159

@@ -163,7 +163,7 @@ func TestBasicSessionStore(t *testing.T) {
163163
)
164164
require.NoError(t, err)
165165

166-
sessions, err = db.ListSessionsByState(StateReserved)
166+
sessions, err = db.ListSessionsByState(ctx, StateReserved)
167167
require.NoError(t, err)
168168
require.Equal(t, 1, len(sessions))
169169
assertEqualSessions(t, s4, sessions[0])
@@ -182,7 +182,7 @@ func TestBasicSessionStore(t *testing.T) {
182182
// database and no longer in the group ID/session ID index.
183183
require.NoError(t, db.DeleteReservedSessions())
184184

185-
sessions, err = db.ListSessionsByState(StateReserved)
185+
sessions, err = db.ListSessionsByState(ctx, StateReserved)
186186
require.NoError(t, err)
187187
require.Empty(t, sessions)
188188

session_rpcserver.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
103103

104104
// Start up all previously created sessions.
105105
sessions, err := s.cfg.db.ListSessionsByState(
106-
session.StateCreated,
107-
session.StateInUse,
106+
ctx, session.StateCreated, session.StateInUse,
108107
)
109108
if err != nil {
110109
return fmt.Errorf("error listing sessions: %v", err)
@@ -524,10 +523,10 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
524523
}
525524

526525
// ListSessions returns all sessions known to the session store.
527-
func (s *sessionRpcServer) ListSessions(_ context.Context,
526+
func (s *sessionRpcServer) ListSessions(ctx context.Context,
528527
_ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) {
529528

530-
sessions, err := s.cfg.db.ListAllSessions()
529+
sessions, err := s.cfg.db.ListAllSessions(ctx)
531530
if err != nil {
532531
return nil, fmt.Errorf("error fetching sessions: %v", err)
533532
}
@@ -1243,11 +1242,11 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
12431242

12441243
// ListAutopilotSessions fetches and returns all the sessions from the DB that
12451244
// are of type TypeAutopilot.
1246-
func (s *sessionRpcServer) ListAutopilotSessions(_ context.Context,
1245+
func (s *sessionRpcServer) ListAutopilotSessions(ctx context.Context,
12471246
_ *litrpc.ListAutopilotSessionsRequest) (
12481247
*litrpc.ListAutopilotSessionsResponse, error) {
12491248

1250-
sessions, err := s.cfg.db.ListSessionsByType(session.TypeAutopilot)
1249+
sessions, err := s.cfg.db.ListSessionsByType(ctx, session.TypeAutopilot)
12511250
if err != nil {
12521251
return nil, fmt.Errorf("error fetching sessions: %v", err)
12531252
}

0 commit comments

Comments
 (0)