Skip to content

Commit d55f9a1

Browse files
committed
session: remove Session Group Predicate method
This was used to check that all linked sessions are no longer active before attempting to register an autopilot session. But this is no longer needed since this is done within NewSession.
1 parent c098e54 commit d55f9a1

File tree

4 files changed

+2
-175
lines changed

4 files changed

+2
-175
lines changed

session/interface.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,6 @@ type Store interface {
205205
// GetSessionByID fetches the session with the given ID.
206206
GetSessionByID(id ID) (*Session, error)
207207

208-
// CheckSessionGroupPredicate iterates over all the sessions in a group
209-
// and checks if each one passes the given predicate function. True is
210-
// returned if each session passes.
211-
CheckSessionGroupPredicate(groupID ID,
212-
fn func(s *Session) bool) (bool, error)
213-
214208
// DeleteReservedSessions deletes all sessions that are in the
215209
// StateReserved state.
216210
DeleteReservedSessions() error

session/store.go

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ func (db *DB) NewSession(label string, typ Type, expiry time.Time,
129129

130130
// Ensure that the session is no longer active.
131131
if sess.State == StateCreated ||
132-
sess.State == StateInUse {
132+
sess.State == StateInUse ||
133+
sess.State == StateReserved {
133134

134135
return fmt.Errorf("session (id=%x) "+
135136
"in group %x is still active",
@@ -574,65 +575,6 @@ func (db *DB) GetSessionIDs(groupID ID) ([]ID, error) {
574575
return sessionIDs, nil
575576
}
576577

577-
// CheckSessionGroupPredicate iterates over all the sessions in a group and
578-
// checks if each one passes the given predicate function. True is returned if
579-
// each session passes.
580-
//
581-
// NOTE: this is part of the Store interface.
582-
func (db *DB) CheckSessionGroupPredicate(groupID ID,
583-
fn func(s *Session) bool) (bool, error) {
584-
585-
var (
586-
pass bool
587-
errFailedPred = errors.New("session failed predicate")
588-
)
589-
err := db.View(func(tx *bbolt.Tx) error {
590-
sessionBkt, err := getBucket(tx, sessionBucketKey)
591-
if err != nil {
592-
return err
593-
}
594-
595-
sessionIDs, err := getSessionIDs(sessionBkt, groupID)
596-
if err != nil {
597-
return err
598-
}
599-
600-
// Iterate over all the sessions.
601-
for _, id := range sessionIDs {
602-
key, err := getKeyForID(sessionBkt, id)
603-
if err != nil {
604-
return err
605-
}
606-
607-
v := sessionBkt.Get(key)
608-
if len(v) == 0 {
609-
return ErrSessionNotFound
610-
}
611-
612-
session, err := DeserializeSession(bytes.NewReader(v))
613-
if err != nil {
614-
return err
615-
}
616-
617-
if !fn(session) {
618-
return errFailedPred
619-
}
620-
}
621-
622-
pass = true
623-
624-
return nil
625-
})
626-
if errors.Is(err, errFailedPred) {
627-
return pass, nil
628-
}
629-
if err != nil {
630-
return pass, err
631-
}
632-
633-
return pass, nil
634-
}
635-
636578
// getSessionIDs returns all the session IDs associated with the given group ID.
637579
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
638580
var sessionIDs []ID

session/store_test.go

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package session
22

33
import (
4-
"strings"
54
"testing"
65
"time"
76

@@ -283,97 +282,6 @@ func TestLinkedSessions(t *testing.T) {
283282
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
284283
}
285284

286-
// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate
287-
// method correctly checks if each session in a group passes a predicate.
288-
func TestCheckSessionGroupPredicate(t *testing.T) {
289-
t.Parallel()
290-
291-
// Set up a new DB.
292-
clock := clock.NewTestClock(testTime)
293-
db, err := NewDB(t.TempDir(), "test.db", clock)
294-
require.NoError(t, err)
295-
t.Cleanup(func() {
296-
_ = db.Close()
297-
})
298-
299-
// We will use the Label of the Session to test that the predicate
300-
// function is checked correctly.
301-
302-
// Add a new session to the DB.
303-
s1 := createSession(t, db, "label 1")
304-
305-
// Check that the group passes against an appropriate predicate.
306-
ok, err := db.CheckSessionGroupPredicate(
307-
s1.GroupID, func(s *Session) bool {
308-
return strings.Contains(s.Label, "label 1")
309-
},
310-
)
311-
require.NoError(t, err)
312-
require.True(t, ok)
313-
314-
// Check that the group fails against an appropriate predicate.
315-
ok, err = db.CheckSessionGroupPredicate(
316-
s1.GroupID, func(s *Session) bool {
317-
return strings.Contains(s.Label, "label 2")
318-
},
319-
)
320-
require.NoError(t, err)
321-
require.False(t, ok)
322-
323-
// Revoke the first session.
324-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
325-
326-
// Add a new session to the same group as the first one.
327-
_ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID))
328-
329-
// Check that the group passes against an appropriate predicate.
330-
ok, err = db.CheckSessionGroupPredicate(
331-
s1.GroupID, func(s *Session) bool {
332-
return strings.Contains(s.Label, "label")
333-
},
334-
)
335-
require.NoError(t, err)
336-
require.True(t, ok)
337-
338-
// Check that the group fails against an appropriate predicate.
339-
ok, err = db.CheckSessionGroupPredicate(
340-
s1.GroupID, func(s *Session) bool {
341-
return strings.Contains(s.Label, "label 1")
342-
},
343-
)
344-
require.NoError(t, err)
345-
require.False(t, ok)
346-
347-
// Add a new session that is not linked to the first one.
348-
s3 := createSession(t, db, "completely different")
349-
350-
// Ensure that the first group is unaffected.
351-
ok, err = db.CheckSessionGroupPredicate(
352-
s1.GroupID, func(s *Session) bool {
353-
return strings.Contains(s.Label, "label")
354-
},
355-
)
356-
require.NoError(t, err)
357-
require.True(t, ok)
358-
359-
// And that the new session is evaluated separately.
360-
ok, err = db.CheckSessionGroupPredicate(
361-
s3.GroupID, func(s *Session) bool {
362-
return strings.Contains(s.Label, "label")
363-
},
364-
)
365-
require.NoError(t, err)
366-
require.False(t, ok)
367-
368-
ok, err = db.CheckSessionGroupPredicate(
369-
s3.GroupID, func(s *Session) bool {
370-
return strings.Contains(s.Label, "different")
371-
},
372-
)
373-
require.NoError(t, err)
374-
require.True(t, ok)
375-
}
376-
377285
type testSessionOpts struct {
378286
groupID *ID
379287
sessType Type

session_rpcserver.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -865,23 +865,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
865865
"group %x", groupSess.ID, groupSess.GroupID)
866866
}
867867

868-
// Now we need to check that all the sessions in the group are
869-
// no longer active.
870-
ok, err := s.cfg.db.CheckSessionGroupPredicate(
871-
groupID, func(s *session.Session) bool {
872-
return s.State == session.StateRevoked ||
873-
s.State == session.StateExpired
874-
},
875-
)
876-
if err != nil {
877-
return nil, err
878-
}
879-
880-
if !ok {
881-
return nil, fmt.Errorf("a linked session in group "+
882-
"%x is still active", groupID)
883-
}
884-
885868
linkedGroupID = &groupID
886869
linkedGroupSession = groupSess
887870

0 commit comments

Comments
 (0)