diff --git a/accounts/store_sql.go b/accounts/store_sql.go index 4b34e3f7d..830f16587 100644 --- a/accounts/store_sql.go +++ b/accounts/store_sql.go @@ -67,8 +67,8 @@ type SQLStore struct { // in order to implement all its CRUD logic. db BatchedSQLQueries - // DB represents the underlying database connection. - *sql.DB + // BaseDB represents the underlying database connection. + *db.BaseDB clock clock.Clock } @@ -83,9 +83,9 @@ func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore { ) return &SQLStore{ - db: executor, - DB: sqlDB.DB, - clock: clock, + db: executor, + BaseDB: sqlDB, + clock: clock, } } diff --git a/session/interface.go b/session/interface.go index a861f7e34..1ce3854f3 100644 --- a/session/interface.go +++ b/session/interface.go @@ -288,8 +288,8 @@ type Store interface { ListSessionsByType(ctx context.Context, t Type) ([]*Session, error) // ListSessionsByState returns all sessions currently known to the store - // that are in the given states. - ListSessionsByState(ctx context.Context, state ...State) ([]*Session, + // that are in the given state. + ListSessionsByState(ctx context.Context, state State) ([]*Session, error) // UpdateSessionRemotePubKey can be used to add the given remote pub key diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 69b2eac87..00524e3d8 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -370,20 +370,14 @@ func (db *BoltStore) ListSessionsByType(_ context.Context, t Type) ([]*Session, } // ListSessionsByState returns all sessions currently known to the store that -// are in the given states. +// are in the given state. // // NOTE: this is part of the Store interface. -func (db *BoltStore) ListSessionsByState(_ context.Context, states ...State) ( +func (db *BoltStore) ListSessionsByState(_ context.Context, state State) ( []*Session, error) { return db.listSessions(func(s *Session) bool { - for _, state := range states { - if s.State == state { - return true - } - } - - return false + return s.State == state }) } diff --git a/session/store_test.go b/session/store_test.go index a3c6c4289..a853a6133 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -134,17 +134,6 @@ func TestBasicSessionStore(t *testing.T) { assertEqualSessions(t, s2, sessions[0]) assertEqualSessions(t, s3, sessions[1]) - sessions, err = db.ListSessionsByState(ctx, StateCreated, StateRevoked) - require.NoError(t, err) - require.Equal(t, 3, len(sessions)) - assertEqualSessions(t, s1, sessions[0]) - assertEqualSessions(t, s2, sessions[1]) - assertEqualSessions(t, s3, sessions[2]) - - sessions, err = db.ListSessionsByState(ctx) - require.NoError(t, err) - require.Empty(t, sessions) - sessions, err = db.ListSessionsByState(ctx, StateReserved) require.NoError(t, err) require.Empty(t, sessions) @@ -373,7 +362,7 @@ func reserveSession(db Store, label string, return db.NewSession( context.Background(), label, opts.sessType, - time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC), "foo.bar.baz:1234", WithDevServer(), WithPrivacy(PrivacyFlags{ClearPubkeys}), diff --git a/session_rpcserver.go b/session_rpcserver.go index 7362d8c7f..c185a3a94 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -102,13 +102,23 @@ func (s *sessionRpcServer) start(ctx context.Context) error { } // Start up all previously created sessions. - sessions, err := s.cfg.db.ListSessionsByState( - ctx, session.StateCreated, session.StateInUse, + sessions, err := s.cfg.db.ListSessionsByState(ctx, session.StateCreated) + if err != nil { + return fmt.Errorf("error listing sessions: %v", err) + } + + // For backwards compatibility, we will also resume sessions that are in + // the InUse state even though we no longer put sessions into this + // state. + inUseSessions, err := s.cfg.db.ListSessionsByState( + ctx, session.StateInUse, ) if err != nil { return fmt.Errorf("error listing sessions: %v", err) } + sessions = append(sessions, inUseSessions...) + for _, sess := range sessions { key := sess.LocalPublicKey.SerializeCompressed()