@@ -85,15 +85,54 @@ func (db *DB) CreateSession(session *Session) error {
8585
8686 // If this is a linked session (meaning the group ID is
8787 // different from the ID) the make sure that the Group ID of
88- // this session is an ID known by the store. We can do this by
89- // checking that an entry for this ID exists in the id-to-key
90- // index .
88+ // this session is an ID known by the store. We also need to
89+ // check that all older sessions in this group have been
90+ // revoked .
9191 if session .ID != session .GroupID {
9292 _ , err = getKeyForID (sessionBucket , session .GroupID )
9393 if err != nil {
9494 return fmt .Errorf ("unknown linked session " +
9595 "%x: %w" , session .GroupID , err )
9696 }
97+
98+ // Fetch all the session IDs for this group. This will
99+ // through an error if this group does not exist.
100+ sessionIDs , err := getSessionIDs (
101+ sessionBucket , session .GroupID ,
102+ )
103+ if err != nil {
104+ return err
105+ }
106+
107+ for _ , id := range sessionIDs {
108+ keyBytes , err := getKeyForID (
109+ sessionBucket , id ,
110+ )
111+ if err != nil {
112+ return err
113+ }
114+
115+ v := sessionBucket .Get (keyBytes )
116+ if len (v ) == 0 {
117+ return ErrSessionNotFound
118+ }
119+
120+ sess , err := DeserializeSession (
121+ bytes .NewReader (v ),
122+ )
123+ if err != nil {
124+ return err
125+ }
126+
127+ // Ensure that the session is no longer active.
128+ if sess .State == StateCreated ||
129+ sess .State == StateInUse {
130+
131+ return fmt .Errorf ("session (id=%x) " +
132+ "in group %x is still active" ,
133+ sess .ID , sess .GroupID )
134+ }
135+ }
97136 }
98137
99138 // Add the mapping from session ID to session key to the ID
@@ -390,7 +429,12 @@ func (db *DB) GetSessionIDs(groupID ID) ([]ID, error) {
390429 err error
391430 )
392431 err = db .View (func (tx * bbolt.Tx ) error {
393- sessionIDs , err = getSessionIDs (tx , groupID )
432+ sessionBkt , err := getBucket (tx , sessionBucketKey )
433+ if err != nil {
434+ return err
435+ }
436+
437+ sessionIDs , err = getSessionIDs (sessionBkt , groupID )
394438
395439 return err
396440 })
@@ -419,7 +463,7 @@ func (db *DB) CheckSessionGroupPredicate(groupID ID,
419463 return err
420464 }
421465
422- sessionIDs , err := getSessionIDs (tx , groupID )
466+ sessionIDs , err := getSessionIDs (sessionBkt , groupID )
423467 if err != nil {
424468 return err
425469 }
@@ -461,14 +505,9 @@ func (db *DB) CheckSessionGroupPredicate(groupID ID,
461505}
462506
463507// getSessionIDs returns all the session IDs associated with the given group ID.
464- func getSessionIDs (tx * bbolt.Tx , groupID ID ) ([]ID , error ) {
508+ func getSessionIDs (sessionBkt * bbolt.Bucket , groupID ID ) ([]ID , error ) {
465509 var sessionIDs []ID
466510
467- sessionBkt , err := getBucket (tx , sessionBucketKey )
468- if err != nil {
469- return nil , err
470- }
471-
472511 groupIndexBkt := sessionBkt .Bucket (groupIDIndexKey )
473512 if groupIndexBkt == nil {
474513 return nil , ErrDBInitErr
@@ -486,7 +525,7 @@ func getSessionIDs(tx *bbolt.Tx, groupID ID) ([]ID, error) {
486525 groupID )
487526 }
488527
489- err = sessionIDsBkt .ForEach (func (_ ,
528+ err : = sessionIDsBkt .ForEach (func (_ ,
490529 sessionIDBytes []byte ) error {
491530
492531 var sessionID ID
0 commit comments