Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions accounts/store_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ type SQLStore struct {
// in order to implement all its CRUD logic.
db BatchedSQLQueries

// DB represents the underlying database connection.
*sql.DB
// BaseDB represents the underlying database connection.
*db.BaseDB

clock clock.Clock
}
Expand All @@ -83,9 +83,9 @@ func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore {
)

return &SQLStore{
db: executor,
DB: sqlDB.DB,
clock: clock,
db: executor,
BaseDB: sqlDB,
clock: clock,
}
}

Expand Down
4 changes: 2 additions & 2 deletions session/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ type Store interface {
ListSessionsByType(ctx context.Context, t Type) ([]*Session, error)

// ListSessionsByState returns all sessions currently known to the store
// that are in the given states.
ListSessionsByState(ctx context.Context, state ...State) ([]*Session,
// that are in the given state.
ListSessionsByState(ctx context.Context, state State) ([]*Session,
error)

// UpdateSessionRemotePubKey can be used to add the given remote pub key
Expand Down
12 changes: 3 additions & 9 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,20 +370,14 @@ func (db *BoltStore) ListSessionsByType(_ context.Context, t Type) ([]*Session,
}

// ListSessionsByState returns all sessions currently known to the store that
// are in the given states.
// are in the given state.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessionsByState(_ context.Context, states ...State) (
func (db *BoltStore) ListSessionsByState(_ context.Context, state State) (
[]*Session, error) {

return db.listSessions(func(s *Session) bool {
for _, state := range states {
if s.State == state {
return true
}
}

return false
return s.State == state
})
}

Expand Down
13 changes: 1 addition & 12 deletions session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,6 @@ func TestBasicSessionStore(t *testing.T) {
assertEqualSessions(t, s2, sessions[0])
assertEqualSessions(t, s3, sessions[1])

sessions, err = db.ListSessionsByState(ctx, StateCreated, StateRevoked)
require.NoError(t, err)
require.Equal(t, 3, len(sessions))
assertEqualSessions(t, s1, sessions[0])
assertEqualSessions(t, s2, sessions[1])
assertEqualSessions(t, s3, sessions[2])

sessions, err = db.ListSessionsByState(ctx)
require.NoError(t, err)
require.Empty(t, sessions)

sessions, err = db.ListSessionsByState(ctx, StateReserved)
require.NoError(t, err)
require.Empty(t, sessions)
Expand Down Expand Up @@ -373,7 +362,7 @@ func reserveSession(db Store, label string,

return db.NewSession(
context.Background(), label, opts.sessType,
time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC),
"foo.bar.baz:1234",
WithDevServer(),
WithPrivacy(PrivacyFlags{ClearPubkeys}),
Expand Down
14 changes: 12 additions & 2 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,23 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
}

// Start up all previously created sessions.
sessions, err := s.cfg.db.ListSessionsByState(
ctx, session.StateCreated, session.StateInUse,
sessions, err := s.cfg.db.ListSessionsByState(ctx, session.StateCreated)
if err != nil {
return fmt.Errorf("error listing sessions: %v", err)
}

// For backwards compatibility, we will also resume sessions that are in
// the InUse state even though we no longer put sessions into this
// state.
inUseSessions, err := s.cfg.db.ListSessionsByState(
ctx, session.StateInUse,
)
if err != nil {
return fmt.Errorf("error listing sessions: %v", err)
}

sessions = append(sessions, inUseSessions...)

for _, sess := range sessions {
key := sess.LocalPublicKey.SerializeCompressed()

Expand Down
Loading