Skip to content

Commit 64dba89

Browse files
authored
Merge pull request #992 from ellemouton/sql19Sessions11
[sql-19] sessions: last misc prep commits
2 parents 66b0f15 + a82cda6 commit 64dba89

File tree

5 files changed

+23
-30
lines changed

5 files changed

+23
-30
lines changed

accounts/store_sql.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ type SQLStore struct {
6767
// in order to implement all its CRUD logic.
6868
db BatchedSQLQueries
6969

70-
// DB represents the underlying database connection.
71-
*sql.DB
70+
// BaseDB represents the underlying database connection.
71+
*db.BaseDB
7272

7373
clock clock.Clock
7474
}
@@ -83,9 +83,9 @@ func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore {
8383
)
8484

8585
return &SQLStore{
86-
db: executor,
87-
DB: sqlDB.DB,
88-
clock: clock,
86+
db: executor,
87+
BaseDB: sqlDB,
88+
clock: clock,
8989
}
9090
}
9191

session/interface.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,8 @@ type Store interface {
288288
ListSessionsByType(ctx context.Context, t Type) ([]*Session, error)
289289

290290
// ListSessionsByState returns all sessions currently known to the store
291-
// that are in the given states.
292-
ListSessionsByState(ctx context.Context, state ...State) ([]*Session,
291+
// that are in the given state.
292+
ListSessionsByState(ctx context.Context, state State) ([]*Session,
293293
error)
294294

295295
// UpdateSessionRemotePubKey can be used to add the given remote pub key

session/kvdb_store.go

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -370,20 +370,14 @@ func (db *BoltStore) ListSessionsByType(_ context.Context, t Type) ([]*Session,
370370
}
371371

372372
// ListSessionsByState returns all sessions currently known to the store that
373-
// are in the given states.
373+
// are in the given state.
374374
//
375375
// NOTE: this is part of the Store interface.
376-
func (db *BoltStore) ListSessionsByState(_ context.Context, states ...State) (
376+
func (db *BoltStore) ListSessionsByState(_ context.Context, state State) (
377377
[]*Session, error) {
378378

379379
return db.listSessions(func(s *Session) bool {
380-
for _, state := range states {
381-
if s.State == state {
382-
return true
383-
}
384-
}
385-
386-
return false
380+
return s.State == state
387381
})
388382
}
389383

session/store_test.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,6 @@ func TestBasicSessionStore(t *testing.T) {
134134
assertEqualSessions(t, s2, sessions[0])
135135
assertEqualSessions(t, s3, sessions[1])
136136

137-
sessions, err = db.ListSessionsByState(ctx, StateCreated, StateRevoked)
138-
require.NoError(t, err)
139-
require.Equal(t, 3, len(sessions))
140-
assertEqualSessions(t, s1, sessions[0])
141-
assertEqualSessions(t, s2, sessions[1])
142-
assertEqualSessions(t, s3, sessions[2])
143-
144-
sessions, err = db.ListSessionsByState(ctx)
145-
require.NoError(t, err)
146-
require.Empty(t, sessions)
147-
148137
sessions, err = db.ListSessionsByState(ctx, StateReserved)
149138
require.NoError(t, err)
150139
require.Empty(t, sessions)
@@ -373,7 +362,7 @@ func reserveSession(db Store, label string,
373362

374363
return db.NewSession(
375364
context.Background(), label, opts.sessType,
376-
time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC),
365+
time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC),
377366
"foo.bar.baz:1234",
378367
WithDevServer(),
379368
WithPrivacy(PrivacyFlags{ClearPubkeys}),

session_rpcserver.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,23 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
102102
}
103103

104104
// Start up all previously created sessions.
105-
sessions, err := s.cfg.db.ListSessionsByState(
106-
ctx, session.StateCreated, session.StateInUse,
105+
sessions, err := s.cfg.db.ListSessionsByState(ctx, session.StateCreated)
106+
if err != nil {
107+
return fmt.Errorf("error listing sessions: %v", err)
108+
}
109+
110+
// For backwards compatibility, we will also resume sessions that are in
111+
// the InUse state even though we no longer put sessions into this
112+
// state.
113+
inUseSessions, err := s.cfg.db.ListSessionsByState(
114+
ctx, session.StateInUse,
107115
)
108116
if err != nil {
109117
return fmt.Errorf("error listing sessions: %v", err)
110118
}
111119

120+
sessions = append(sessions, inUseSessions...)
121+
112122
for _, sess := range sessions {
113123
key := sess.LocalPublicKey.SerializeCompressed()
114124

0 commit comments

Comments
 (0)