Skip to content

Commit 9642ce1

Browse files
committed
session+firewall: pass context to GetSessionByID
1 parent aa5674c commit 9642ce1

File tree

8 files changed

+24
-17
lines changed

8 files changed

+24
-17
lines changed

firewall/privacy_mapper.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
190190
uri string, req proto.Message, sessionID session.ID) (proto.Message,
191191
error) {
192192

193-
session, err := p.sessionDB.GetSessionByID(sessionID)
193+
session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
194194
if err != nil {
195195
return nil, err
196196
}
@@ -220,7 +220,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
220220
func (p *PrivacyMapper) replaceOutgoingResponse(ctx context.Context, uri string,
221221
resp proto.Message, sessionID session.ID) (proto.Message, error) {
222222

223-
session, err := p.sessionDB.GetSessionByID(sessionID)
223+
session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
224224
if err != nil {
225225
return nil, err
226226
}

firewall/rule_enforcer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string,
386386
return nil, err
387387
}
388388

389-
session, err := r.sessionDB.GetSessionByID(sessionID)
389+
session, err := r.sessionDB.GetSessionByID(ctx, sessionID)
390390
if err != nil {
391391
return nil, err
392392
}

firewalldb/interface.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
package firewalldb
22

3-
import "github.com/lightninglabs/lightning-terminal/session"
3+
import (
4+
"context"
5+
6+
"github.com/lightninglabs/lightning-terminal/session"
7+
)
48

59
// SessionDB is an interface that abstracts the database operations needed for
610
// the privacy mapper to function.
711
type SessionDB interface {
812
session.IDToGroupIndex
913

1014
// GetSessionByID returns the session for a specific id.
11-
GetSessionByID(session.ID) (*session.Session, error)
15+
GetSessionByID(context.Context, session.ID) (*session.Session, error)
1216
}

firewalldb/mock.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package firewalldb
22

33
import (
4+
"context"
45
"fmt"
56

67
"github.com/lightninglabs/lightning-terminal/session"
@@ -54,8 +55,8 @@ func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error)
5455
}
5556

5657
// GetSessionByID returns the session for a specific id.
57-
func (m *mockSessionDB) GetSessionByID(sessionID session.ID) (*session.Session,
58-
error) {
58+
func (m *mockSessionDB) GetSessionByID(_ context.Context,
59+
sessionID session.ID) (*session.Session, error) {
5960

6061
s, ok := m.sessionToGroupID[sessionID]
6162
if !ok {

session/interface.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ type Store interface {
298298
remotePubKey *btcec.PublicKey) error
299299

300300
// GetSessionByID fetches the session with the given ID.
301-
GetSessionByID(id ID) (*Session, error)
301+
GetSessionByID(ctx context.Context, id ID) (*Session, error)
302302

303303
// DeleteReservedSessions deletes all sessions that are in the
304304
// StateReserved state.

session/kvdb_store.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,9 @@ func (db *BoltStore) ShiftState(id ID, dest State) error {
569569
// GetSessionByID fetches the session with the given ID.
570570
//
571571
// NOTE: this is part of the Store interface.
572-
func (db *BoltStore) GetSessionByID(id ID) (*Session, error) {
572+
func (db *BoltStore) GetSessionByID(_ context.Context, id ID) (*Session,
573+
error) {
574+
573575
var session *Session
574576
err := db.View(func(tx *bbolt.Tx) error {
575577
sessionBucket, err := getBucket(tx, sessionBucketKey)

session/store_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ func TestBasicSessionStore(t *testing.T) {
2323
db := NewTestDB(t, clock)
2424

2525
// Try fetch a session that doesn't exist yet.
26-
_, err := db.GetSessionByID(ID{1, 3, 4, 4})
26+
_, err := db.GetSessionByID(ctx, ID{1, 3, 4, 4})
2727
require.ErrorIs(t, err, ErrSessionNotFound)
2828

2929
// Reserve a session. This should succeed.
3030
s1, err := reserveSession(db, "session 1")
3131
require.NoError(t, err)
3232

3333
// Show that the session starts in the reserved state.
34-
s1, err = db.GetSessionByID(s1.ID)
34+
s1, err = db.GetSessionByID(ctx, s1.ID)
3535
require.NoError(t, err)
3636
require.Equal(t, StateReserved, s1.State)
3737

@@ -40,7 +40,7 @@ func TestBasicSessionStore(t *testing.T) {
4040
require.NoError(t, err)
4141

4242
// Show that the session is now in the created state.
43-
s1, err = db.GetSessionByID(s1.ID)
43+
s1, err = db.GetSessionByID(ctx, s1.ID)
4444
require.NoError(t, err)
4545
require.Equal(t, StateCreated, s1.State)
4646

@@ -80,7 +80,7 @@ func TestBasicSessionStore(t *testing.T) {
8080
require.NoError(t, err)
8181
assertEqualSessions(t, s, session)
8282

83-
session, err = db.GetSessionByID(s.ID)
83+
session, err = db.GetSessionByID(ctx, s.ID)
8484
require.NoError(t, err)
8585
assertEqualSessions(t, s, session)
8686
}
@@ -386,7 +386,7 @@ func createSession(t *testing.T, db Store, label string,
386386
err = db.ShiftState(s.ID, StateCreated)
387387
require.NoError(t, err)
388388

389-
s, err = db.GetSessionByID(s.ID)
389+
s, err = db.GetSessionByID(context.Background(), s.ID)
390390
require.NoError(t, err)
391391

392392
return s

session_rpcserver.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ func (s *sessionRpcServer) AddSession(ctx context.Context,
335335

336336
// Re-fetch the session to get the latest state of it before marshaling
337337
// it.
338-
sess, err = s.cfg.db.GetSessionByID(sess.ID)
338+
sess, err = s.cfg.db.GetSessionByID(ctx, sess.ID)
339339
if err != nil {
340340
return nil, fmt.Errorf("error fetching session: %v", err)
341341
}
@@ -867,7 +867,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
867867
copy(groupID[:], req.LinkedGroupId)
868868

869869
// Check that the group actually does exist.
870-
groupSess, err := s.cfg.db.GetSessionByID(groupID)
870+
groupSess, err := s.cfg.db.GetSessionByID(ctx, groupID)
871871
if err != nil {
872872
return nil, err
873873
}
@@ -1252,7 +1252,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
12521252

12531253
// Re-fetch the session to get the latest state of it before marshaling
12541254
// it.
1255-
sess, err = s.cfg.db.GetSessionByID(sess.ID)
1255+
sess, err = s.cfg.db.GetSessionByID(ctx, sess.ID)
12561256
if err != nil {
12571257
return nil, fmt.Errorf("error fetching session: %v", err)
12581258
}

0 commit comments

Comments
 (0)