Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions config_dev.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
acctStore := accounts.NewSQLStore(
sqlStore.BaseDB, queries, clock,
)
sessStore := session.NewSQLStore(legacySqlStore.BaseDB, clock)
sessStore := session.NewSQLStore(
sqlStore.BaseDB, queries, clock,
)
firewallStore := firewalldb.NewSQLDB(
legacySqlStore.BaseDB, clock,
)
Expand Down Expand Up @@ -181,7 +183,9 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
acctStore := accounts.NewSQLStore(
sqlStore.BaseDB, queries, clock,
)
sessStore := session.NewSQLStore(legacySqlStore.BaseDB, clock)
sessStore := session.NewSQLStore(
sqlStore.BaseDB, queries, clock,
)
firewallStore := firewalldb.NewSQLDB(
legacySqlStore.BaseDB, clock,
)
Expand Down
8 changes: 6 additions & 2 deletions firewalldb/sql_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,8 @@ func createPrivacyPairs(t *testing.T, ctx context.Context,
sessSQLStore, ok := sessionStore.(*session.SQLStore)
require.True(t, ok)

queries := sqlc.NewForType(sessSQLStore, sessSQLStore.BackendType)

for i := range numSessions {
sess, err := sessionStore.NewSession(
ctx, fmt.Sprintf("session-%d", i),
Expand All @@ -806,7 +808,7 @@ func createPrivacyPairs(t *testing.T, ctx context.Context,
require.NoError(t, err)

groupID := sess.GroupID
sqlGroupID, err := sessSQLStore.GetSessionIDByAlias(
sqlGroupID, err := queries.GetSessionIDByAlias(
ctx, groupID[:],
)
require.NoError(t, err)
Expand Down Expand Up @@ -850,6 +852,8 @@ func randomPrivacyPairs(t *testing.T, ctx context.Context,
sessSQLStore, ok := sessionStore.(*session.SQLStore)
require.True(t, ok)

queries := sqlc.NewForType(sessSQLStore, sessSQLStore.BackendType)

for i := range numSessions {
sess, err := sessionStore.NewSession(
ctx, fmt.Sprintf("session-%d", i),
Expand All @@ -859,7 +863,7 @@ func randomPrivacyPairs(t *testing.T, ctx context.Context,
require.NoError(t, err)

groupID := sess.GroupID
sqlGroupID, err := sessSQLStore.GetSessionIDByAlias(
sqlGroupID, err := queries.GetSessionIDByAlias(
ctx, groupID[:],
)
require.NoError(t, err)
Expand Down
17 changes: 6 additions & 11 deletions session/sql_migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ package session

import (
"context"
"database/sql"
"fmt"
"testing"
"time"

"github.com/lightninglabs/lightning-terminal/accounts"
"github.com/lightninglabs/lightning-terminal/db"
"github.com/lightninglabs/lightning-terminal/db/sqlc"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/macaroons"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/v2"
"github.com/stretchr/testify/require"
"go.etcd.io/bbolt"
"golang.org/x/exp/rand"
Expand All @@ -38,7 +37,7 @@ func TestSessionsStoreMigration(t *testing.T) {
}

makeSQLDB := func(t *testing.T, acctStore accounts.Store) (*SQLStore,
*db.TransactionExecutor[SQLQueries]) {
*SQLQueriesExecutor[SQLQueries]) {

// Create a sql store with a linked account store.
testDBStore := NewTestDBWithAccounts(t, clock, acctStore)
Expand All @@ -48,13 +47,9 @@ func TestSessionsStoreMigration(t *testing.T) {

baseDB := store.BaseDB

genericExecutor := db.NewTransactionExecutor(
baseDB, func(tx *sql.Tx) SQLQueries {
return baseDB.WithTx(tx)
},
)
queries := sqlc.NewForType(baseDB, baseDB.BackendType)

return store, genericExecutor
return store, NewSQLQueriesExecutor(baseDB, queries)
}

// assertMigrationResults asserts that the sql store contains the
Expand Down Expand Up @@ -375,7 +370,7 @@ func TestSessionsStoreMigration(t *testing.T) {
return MigrateSessionStoreToSQL(
ctx, kvStore.DB, tx,
)
},
}, sqldb.NoOpReset,
)
require.NoError(t, err)

Expand Down
64 changes: 43 additions & 21 deletions session/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ import (
"github.com/lightninglabs/lightning-terminal/db/sqlc"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/sqldb/v2"
"gopkg.in/macaroon-bakery.v2/bakery"
"gopkg.in/macaroon.v2"
)

// SQLQueries is a subset of the sqlc.Queries interface that can be used to
// interact with session related tables.
type SQLQueries interface {
sqldb.BaseQuerier

GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error)
GetSessionByID(ctx context.Context, id int64) (sqlc.Session, error)
GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]sqlc.Session, error)
Expand Down Expand Up @@ -51,12 +54,13 @@ type SQLQueries interface {

var _ Store = (*SQLStore)(nil)

// BatchedSQLQueries is a version of the SQLQueries that's capable of batched
// database operations.
// BatchedSQLQueries combines the SQLQueries interface with the BatchedTx
// interface, allowing for multiple queries to be executed in single SQL
// transaction.
type BatchedSQLQueries interface {
SQLQueries

db.BatchedTx[SQLQueries]
sqldb.BatchedTx[SQLQueries]
}

// SQLStore represents a storage backend.
Expand All @@ -66,19 +70,37 @@ type SQLStore struct {
db BatchedSQLQueries

// BaseDB represents the underlying database connection.
*db.BaseDB
*sqldb.BaseDB

clock clock.Clock
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(sqlDB *db.BaseDB, clock clock.Clock) *SQLStore {
executor := db.NewTransactionExecutor(
sqlDB, func(tx *sql.Tx) SQLQueries {
return sqlDB.WithTx(tx)
type SQLQueriesExecutor[T sqldb.BaseQuerier] struct {
*sqldb.TransactionExecutor[T]

SQLQueries
}

func NewSQLQueriesExecutor(baseDB *sqldb.BaseDB,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like we don't need to export those?

queries *sqlc.Queries) *SQLQueriesExecutor[SQLQueries] {

executor := sqldb.NewTransactionExecutor(
baseDB, func(tx *sql.Tx) SQLQueries {
return queries.WithTx(tx)
},
)
return &SQLQueriesExecutor[SQLQueries]{
TransactionExecutor: executor,
SQLQueries: queries,
}
}

// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(sqlDB *sqldb.BaseDB, queries *sqlc.Queries,
clock clock.Clock) *SQLStore {

executor := NewSQLQueriesExecutor(sqlDB, queries)

return &SQLStore{
db: executor,
Expand Down Expand Up @@ -281,7 +303,7 @@ func (s *SQLStore) NewSession(ctx context.Context, label string, typ Type,
}

return nil
})
}, sqldb.NoOpReset)
if err != nil {
mappedSQLErr := db.MapSQLError(err)
var uniqueConstraintErr *db.ErrSqlUniqueConstraintViolation
Expand Down Expand Up @@ -325,7 +347,7 @@ func (s *SQLStore) ListSessionsByType(ctx context.Context, t Type) ([]*Session,
}

return nil
})
}, sqldb.NoOpReset)

return sessions, err
}
Expand Down Expand Up @@ -358,7 +380,7 @@ func (s *SQLStore) ListSessionsByState(ctx context.Context, state State) (
}

return nil
})
}, sqldb.NoOpReset)

return sessions, err
}
Expand Down Expand Up @@ -417,7 +439,7 @@ func (s *SQLStore) ShiftState(ctx context.Context, alias ID, dest State) error {
State: int16(dest),
},
)
})
}, sqldb.NoOpReset)
}

// DeleteReservedSessions deletes all sessions that are in the StateReserved
Expand All @@ -428,7 +450,7 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error {
var writeTxOpts db.QueriesTxOptions
return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error {
return db.DeleteSessionsWithState(ctx, int16(StateReserved))
})
}, sqldb.NoOpReset)
}

// GetSessionByLocalPub fetches the session with the given local pub key.
Expand Down Expand Up @@ -458,7 +480,7 @@ func (s *SQLStore) GetSessionByLocalPub(ctx context.Context,
}

return nil
})
}, sqldb.NoOpReset)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -491,7 +513,7 @@ func (s *SQLStore) ListAllSessions(ctx context.Context) ([]*Session, error) {
}

return nil
})
}, sqldb.NoOpReset)

return sessions, err
}
Expand Down Expand Up @@ -521,7 +543,7 @@ func (s *SQLStore) UpdateSessionRemotePubKey(ctx context.Context, alias ID,
RemotePublicKey: remoteKey,
},
)
})
}, sqldb.NoOpReset)
}

// getSqlUnusedAliasAndKeyPair can be used to generate a new, unused, local
Expand Down Expand Up @@ -576,7 +598,7 @@ func (s *SQLStore) GetSession(ctx context.Context, alias ID) (*Session, error) {
}

return nil
})
}, sqldb.NoOpReset)

return sess, err
}
Expand Down Expand Up @@ -617,7 +639,7 @@ func (s *SQLStore) GetGroupID(ctx context.Context, sessionID ID) (ID, error) {
legacyGroupID, err = IDFromBytes(legacyGroupIDB)

return err
})
}, sqldb.NoOpReset)
if err != nil {
return ID{}, err
}
Expand Down Expand Up @@ -666,7 +688,7 @@ func (s *SQLStore) GetSessionIDs(ctx context.Context, legacyGroupID ID) ([]ID,
}

return nil
})
}, sqldb.NoOpReset)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions session/test_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ var ErrDBClosed = errors.New("database is closed")

// NewTestDB is a helper function that creates an SQLStore database for testing.
func NewTestDB(t *testing.T, clock clock.Clock) Store {
return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock)
return createStore(t, db.NewTestPostgresV2DB(t).BaseDB, clock)
}

// NewTestDBFromPath is a helper function that creates a new SQLStore with a
// connection to an existing postgres database for testing.
func NewTestDBFromPath(t *testing.T, dbPath string,
clock clock.Clock) Store {

return createStore(t, db.NewTestPostgresDB(t).BaseDB, clock)
return createStore(t, db.NewTestPostgresV2DB(t).BaseDB, clock)
}
11 changes: 8 additions & 3 deletions session/test_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import (
"testing"

"github.com/lightninglabs/lightning-terminal/accounts"
"github.com/lightninglabs/lightning-terminal/db"
"github.com/lightninglabs/lightning-terminal/db/sqlc"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/sqldb/v2"
"github.com/stretchr/testify/require"
)

Expand All @@ -22,8 +23,12 @@ func NewTestDBWithAccounts(t *testing.T, clock clock.Clock,

// createStore is a helper function that creates a new SQLStore and ensure that
// it is closed when during the test cleanup.
func createStore(t *testing.T, sqlDB *db.BaseDB, clock clock.Clock) *SQLStore {
store := NewSQLStore(sqlDB, clock)
func createStore(t *testing.T, sqlDB *sqldb.BaseDB,
clock clock.Clock) *SQLStore {

queries := sqlc.NewForType(sqlDB, sqlDB.BackendType)

store := NewSQLStore(sqlDB, queries, clock)
t.Cleanup(func() {
require.NoError(t, store.Close())
})
Expand Down
12 changes: 8 additions & 4 deletions session/test_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/lightninglabs/lightning-terminal/db"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/sqldb/v2"
)

// ErrDBClosed is an error that is returned when a database operation is
Expand All @@ -16,15 +17,18 @@ var ErrDBClosed = errors.New("database is closed")

// NewTestDB is a helper function that creates an SQLStore database for testing.
func NewTestDB(t *testing.T, clock clock.Clock) Store {
return createStore(t, db.NewTestSqliteDB(t).BaseDB, clock)
return createStore(
t, sqldb.NewTestSqliteDB(t, db.LitdMigrationStreams).BaseDB,
clock,
)
}

// NewTestDBFromPath is a helper function that creates a new SQLStore with a
// connection to an existing sqlite database for testing.
func NewTestDBFromPath(t *testing.T, dbPath string,
clock clock.Clock) Store {

return createStore(
t, db.NewTestSqliteDbHandleFromPath(t, dbPath).BaseDB, clock,
)
tDb := sqldb.NewTestSqliteDBFromPath(t, dbPath, db.LitdMigrationStreams)

return createStore(t, tDb.BaseDB, clock)
}