Skip to content

Commit aa7b1db

Browse files
committed
session: pass contexts through to all IDToGroupIndex methods
1 parent 0d6eefa commit aa7b1db

File tree

6 files changed

+25
-18
lines changed

6 files changed

+25
-18
lines changed

firewalldb/actions.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ func (db *DB) ListSessionActions(sessionID session.ID,
391391
// pass the filterFn requirements.
392392
//
393393
// TODO: update to allow for pagination.
394-
func (db *DB) ListGroupActions(_ context.Context, groupID session.ID,
394+
func (db *DB) ListGroupActions(ctx context.Context, groupID session.ID,
395395
filterFn ListActionsFilterFn) ([]*Action, error) {
396396

397397
if filterFn == nil {
@@ -400,7 +400,7 @@ func (db *DB) ListGroupActions(_ context.Context, groupID session.ID,
400400
}
401401
}
402402

403-
sessionIDs, err := db.sessionIDIndex.GetSessionIDs(groupID)
403+
sessionIDs, err := db.sessionIDIndex.GetSessionIDs(ctx, groupID)
404404
if err != nil {
405405
return nil, err
406406
}

firewalldb/mock.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ func (m *mockSessionDB) AddPair(sessionID, groupID session.ID) {
3434
}
3535

3636
// GetGroupID returns the group ID for the given session ID.
37-
func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
37+
func (m *mockSessionDB) GetGroupID(_ context.Context, sessionID session.ID) (
38+
session.ID, error) {
39+
3840
id, ok := m.sessionToGroupID[sessionID]
3941
if !ok {
4042
return session.ID{}, fmt.Errorf("no group ID found for " +
@@ -45,7 +47,9 @@ func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
4547
}
4648

4749
// GetSessionIDs returns the set of session IDs that are in the group
48-
func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error) {
50+
func (m *mockSessionDB) GetSessionIDs(_ context.Context, groupID session.ID) (
51+
[]session.ID, error) {
52+
4953
ids, ok := m.groupToSessionIDs[groupID]
5054
if !ok {
5155
return nil, fmt.Errorf("no session IDs found for group ID")

session/interface.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,11 @@ func WithMacaroonRecipe(caveats []macaroon.Caveat, perms []bakery.Op) Option {
261261
// IDToGroupIndex defines an interface for the session ID to group ID index.
262262
type IDToGroupIndex interface {
263263
// GetGroupID will return the group ID for the given session ID.
264-
GetGroupID(sessionID ID) (ID, error)
264+
GetGroupID(ctx context.Context, sessionID ID) (ID, error)
265265

266266
// GetSessionIDs will return the set of session IDs that are in the
267267
// group with the given ID.
268-
GetSessionIDs(groupID ID) ([]ID, error)
268+
GetSessionIDs(ctx context.Context, groupID ID) ([]ID, error)
269269
}
270270

271271
// Store is the interface a persistent storage must implement for storing and

session/kvdb_store.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey,
624624
// GetGroupID will return the group ID for the given session ID.
625625
//
626626
// NOTE: this is part of the IDToGroupIndex interface.
627-
func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
627+
func (db *BoltStore) GetGroupID(_ context.Context, sessionID ID) (ID, error) {
628628
var groupID ID
629629
err := db.View(func(tx *bbolt.Tx) error {
630630
sessionBkt, err := getBucket(tx, sessionBucketKey)
@@ -664,7 +664,9 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
664664
// group with the given ID.
665665
//
666666
// NOTE: this is part of the IDToGroupIndex interface.
667-
func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
667+
func (db *BoltStore) GetSessionIDs(_ context.Context, groupID ID) ([]ID,
668+
error) {
669+
668670
var (
669671
sessionIDs []ID
670672
err error

session/store_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,11 @@ func TestBasicSessionStore(t *testing.T) {
170170

171171
// Show that the group ID/session ID index has also been populated with
172172
// this session.
173-
groupID, err := db.GetGroupID(s4.ID)
173+
groupID, err := db.GetGroupID(ctx, s4.ID)
174174
require.NoError(t, err)
175175
require.Equal(t, s1.ID, groupID)
176176

177-
sessIDs, err := db.GetSessionIDs(s4.GroupID)
177+
sessIDs, err := db.GetSessionIDs(ctx, s4.GroupID)
178178
require.NoError(t, err)
179179
require.ElementsMatch(t, []ID{s4.ID, s1.ID}, sessIDs)
180180

@@ -186,11 +186,11 @@ func TestBasicSessionStore(t *testing.T) {
186186
require.NoError(t, err)
187187
require.Empty(t, sessions)
188188

189-
_, err = db.GetGroupID(s4.ID)
189+
_, err = db.GetGroupID(ctx, s4.ID)
190190
require.ErrorIs(t, err, ErrUnknownGroup)
191191

192192
// Only session 1 should remain in this group.
193-
sessIDs, err = db.GetSessionIDs(s4.GroupID)
193+
sessIDs, err = db.GetSessionIDs(ctx, s4.GroupID)
194194
require.NoError(t, err)
195195
require.ElementsMatch(t, []ID{s1.ID}, sessIDs)
196196
}
@@ -235,6 +235,7 @@ func TestLinkingSessions(t *testing.T) {
235235
// of the GetGroupID and GetSessionIDs methods.
236236
func TestLinkedSessions(t *testing.T) {
237237
t.Parallel()
238+
ctx := context.Background()
238239

239240
// Set up a new DB.
240241
clock := clock.NewTestClock(testTime)
@@ -254,14 +255,14 @@ func TestLinkedSessions(t *testing.T) {
254255

255256
// Assert that the session ID to group ID index works as expected.
256257
for _, s := range []*Session{s1, s2, s3} {
257-
groupID, err := db.GetGroupID(s.ID)
258+
groupID, err := db.GetGroupID(ctx, s.ID)
258259
require.NoError(t, err)
259260
require.Equal(t, s1.ID, groupID)
260261
require.Equal(t, s.GroupID, groupID)
261262
}
262263

263264
// Assert that the group ID to session ID index works as expected.
264-
sIDs, err := db.GetSessionIDs(s1.GroupID)
265+
sIDs, err := db.GetSessionIDs(ctx, s1.GroupID)
265266
require.NoError(t, err)
266267
require.EqualValues(t, []ID{s1.ID, s2.ID, s3.ID}, sIDs)
267268

@@ -274,14 +275,14 @@ func TestLinkedSessions(t *testing.T) {
274275

275276
// Assert that the session ID to group ID index works as expected.
276277
for _, s := range []*Session{s4, s5} {
277-
groupID, err := db.GetGroupID(s.ID)
278+
groupID, err := db.GetGroupID(ctx, s.ID)
278279
require.NoError(t, err)
279280
require.Equal(t, s4.ID, groupID)
280281
require.Equal(t, s.GroupID, groupID)
281282
}
282283

283284
// Assert that the group ID to session ID index works as expected.
284-
sIDs, err = db.GetSessionIDs(s5.GroupID)
285+
sIDs, err = db.GetSessionIDs(ctx, s5.GroupID)
285286
require.NoError(t, err)
286287
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
287288
}

session_rpcserver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context,
587587

588588
// PrivacyMapConversion can be used map real values to their pseudo counterpart
589589
// and vice versa.
590-
func (s *sessionRpcServer) PrivacyMapConversion(_ context.Context,
590+
func (s *sessionRpcServer) PrivacyMapConversion(ctx context.Context,
591591
req *litrpc.PrivacyMapConversionRequest) (
592592
*litrpc.PrivacyMapConversionResponse, error) {
593593

@@ -606,7 +606,7 @@ func (s *sessionRpcServer) PrivacyMapConversion(_ context.Context,
606606
return nil, err
607607
}
608608

609-
groupID, err = s.cfg.db.GetGroupID(sessionID)
609+
groupID, err = s.cfg.db.GetGroupID(ctx, sessionID)
610610
if err != nil {
611611
return nil, err
612612
}

0 commit comments

Comments
 (0)