From 65e4309f9ce2df576a1e64e170990c82b89c86ac Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 13 May 2025 14:00:35 +0200 Subject: [PATCH 1/4] db: add actions schemas and queries In this commit we define the schema for the `actions` table along with various queries we will need for interacting with the table. NOTE: we will also add some of our own queries manually in commits to follow. --- db/migrations.go | 2 +- db/sqlc/actions.sql.go | 78 ++++++++++++++++++++++ db/sqlc/migrations/000005_actions.down.sql | 5 ++ db/sqlc/migrations/000005_actions.up.sql | 52 +++++++++++++++ db/sqlc/models.go | 17 +++++ db/sqlc/querier.go | 2 + db/sqlc/queries/actions.sql | 15 +++++ 7 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 db/sqlc/actions.sql.go create mode 100644 db/sqlc/migrations/000005_actions.down.sql create mode 100644 db/sqlc/migrations/000005_actions.up.sql create mode 100644 db/sqlc/queries/actions.sql diff --git a/db/migrations.go b/db/migrations.go index f52ec942a..79d63587e 100644 --- a/db/migrations.go +++ b/db/migrations.go @@ -22,7 +22,7 @@ const ( // daemon. // // NOTE: This MUST be updated when a new migration is added. - LatestMigrationVersion = 4 + LatestMigrationVersion = 5 ) // MigrationTarget is a functional option that can be passed to applyMigrations diff --git a/db/sqlc/actions.sql.go b/db/sqlc/actions.sql.go new file mode 100644 index 000000000..4a3e8c891 --- /dev/null +++ b/db/sqlc/actions.sql.go @@ -0,0 +1,78 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: actions.sql + +package sqlc + +import ( + "context" + "database/sql" + "time" +) + +const insertAction = `-- name: InsertAction :one +INSERT INTO actions ( + session_id, account_id, macaroon_identifier, actor_name, feature_name, action_trigger, + intent, structured_json_data, rpc_method, rpc_params_json, created_at, + action_state, error_reason +) VALUES ( + $1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, $12, $13 +) RETURNING id +` + +type InsertActionParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + +func (q *Queries) InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertAction, + arg.SessionID, + arg.AccountID, + arg.MacaroonIdentifier, + arg.ActorName, + arg.FeatureName, + arg.ActionTrigger, + arg.Intent, + arg.StructuredJsonData, + arg.RpcMethod, + arg.RpcParamsJson, + arg.CreatedAt, + arg.ActionState, + arg.ErrorReason, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const setActionState = `-- name: SetActionState :exec +UPDATE actions +SET action_state = $1, + error_reason = $2 +WHERE id = $3 +` + +type SetActionStateParams struct { + ActionState int16 + ErrorReason sql.NullString + ID int64 +} + +func (q *Queries) SetActionState(ctx context.Context, arg SetActionStateParams) error { + _, err := q.db.ExecContext(ctx, setActionState, arg.ActionState, arg.ErrorReason, arg.ID) + return err +} diff --git a/db/sqlc/migrations/000005_actions.down.sql b/db/sqlc/migrations/000005_actions.down.sql new file mode 100644 index 000000000..28bdfb6b5 --- /dev/null +++ b/db/sqlc/migrations/000005_actions.down.sql @@ -0,0 +1,5 @@ +DROP INDEX IF NOT EXISTS actions_state_idx; +DROP INDEX IF NOT EXISTS actions_session_id_idx; +DROP INDEX IF NOT EXISTS actions_feature_name_idx; +DROP INDEX IF NOT EXISTS actions_created_at_idx; +DROP TABLE IF EXISTS actions; \ No newline at end of file diff --git a/db/sqlc/migrations/000005_actions.up.sql b/db/sqlc/migrations/000005_actions.up.sql new file mode 100644 index 000000000..44596af3c --- /dev/null +++ b/db/sqlc/migrations/000005_actions.up.sql @@ -0,0 +1,52 @@ +CREATE TABLE IF NOT EXISTS actions( + -- The ID of the action. + id INTEGER PRIMARY KEY, + + -- The session ID of the session that this action is associated with. + -- This may be null for actions that are not coupled to a session. + session_id BIGINT REFERENCES sessions(id) ON DELETE CASCADE, + + -- The account ID of the account that this action is associated with. + -- This may be null for actions that are not coupled to an account. + account_id BIGINT REFERENCES accounts(id) ON DELETE CASCADE, + + -- An ID derived from the macaroon used to perform the action. + macaroon_identifier BLOB, + + -- The name of the entity who performed the action. + actor_name TEXT, + + -- The name of the feature that the action is being performed by. + feature_name TEXT, + + -- Meta info detailing what caused this action to be executed. + action_trigger TEXT, + + -- Meta info detailing what the intended outcome of this action will be. + intent TEXT, + + -- Extra structured JSON data that an actor can send along with the + -- action as json. + structured_json_data BLOB, + + -- The method URI that was called. + rpc_method TEXT NOT NULL, + + -- The method parameters of the request in JSON form. + rpc_params_json BLOB, + + -- The time at which this action was created. + created_at TIMESTAMP NOT NULL, + + -- The current state of the action. + action_state SMALLINT NOT NULL, + + -- Human-readable reason for why the action failed. + -- It will only be set if state is ActionStateError (3). + error_reason TEXT +); + +CREATE INDEX IF NOT EXISTS actions_state_idx ON actions(action_state); +CREATE INDEX IF NOT EXISTS actions_session_id_idx ON actions(session_id); +CREATE INDEX IF NOT EXISTS actions_feature_name_idx ON actions(feature_name); +CREATE INDEX IF NOT EXISTS actions_created_at_idx ON actions(created_at); \ No newline at end of file diff --git a/db/sqlc/models.go b/db/sqlc/models.go index ea9242bed..357360c9e 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -37,6 +37,23 @@ type AccountPayment struct { FullAmountMsat int64 } +type Action struct { + ID int64 + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + type Feature struct { ID int64 Name string diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index a0a9d122d..df89d0898 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -46,6 +46,7 @@ type Querier interface { GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]Session, error) InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) + InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreRecordParams) error InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error InsertSession(ctx context.Context, arg InsertSessionParams) (int64, error) @@ -60,6 +61,7 @@ type Querier interface { ListSessionsByState(ctx context.Context, state int16) ([]Session, error) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) SetAccountIndex(ctx context.Context, arg SetAccountIndexParams) error + SetActionState(ctx context.Context, arg SetActionStateParams) error SetSessionGroupID(ctx context.Context, arg SetSessionGroupIDParams) error SetSessionRemotePublicKey(ctx context.Context, arg SetSessionRemotePublicKeyParams) error SetSessionRevokedAt(ctx context.Context, arg SetSessionRevokedAtParams) error diff --git a/db/sqlc/queries/actions.sql b/db/sqlc/queries/actions.sql new file mode 100644 index 000000000..2a966022d --- /dev/null +++ b/db/sqlc/queries/actions.sql @@ -0,0 +1,15 @@ +-- name: InsertAction :one +INSERT INTO actions ( + session_id, account_id, macaroon_identifier, actor_name, feature_name, action_trigger, + intent, structured_json_data, rpc_method, rpc_params_json, created_at, + action_state, error_reason +) VALUES ( + $1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, $12, $13 +) RETURNING id; + +-- name: SetActionState :exec +UPDATE actions +SET action_state = $1, + error_reason = $2 +WHERE id = $3; From 1674490ab52dc25608d901bd13dbeb6f0e5e6057 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 27 May 2025 15:16:24 +0200 Subject: [PATCH 2/4] db: define manual action SQL queries Here, we manually define some queries for the actions store. We do this so that we can manually build the "SELECT" and only add "WHERE" clauses that are actually needed for the query and hence ensure that available indexes are used. --- db/interfaces.go | 7 +- db/sqlc/actions_custom.go | 210 ++++++++++++++++++++++++++++++++++++++ db/sqlc/db_custom.go | 20 ++++ 3 files changed, 234 insertions(+), 3 deletions(-) create mode 100644 db/sqlc/actions_custom.go diff --git a/db/interfaces.go b/db/interfaces.go index 3a0378f55..ba64520b4 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -87,12 +87,13 @@ type BatchedQuerier interface { // create a batched version of the normal methods they need. sqlc.Querier + // CustomQueries is the set of custom queries that we have manually + // defined in addition to the ones generated by sqlc. + sqlc.CustomQueries + // BeginTx creates a new database transaction given the set of // transaction options. BeginTx(ctx context.Context, options TxOptions) (*sql.Tx, error) - - // Backend returns the type of the database backend used. - Backend() sqlc.BackendType } // txExecutorOptions is a struct that holds the options for the transaction diff --git a/db/sqlc/actions_custom.go b/db/sqlc/actions_custom.go new file mode 100644 index 000000000..693fc6428 --- /dev/null +++ b/db/sqlc/actions_custom.go @@ -0,0 +1,210 @@ +package sqlc + +import ( + "context" + "database/sql" + "strconv" + "strings" +) + +// ActionQueryParams defines the parameters for querying actions. +type ActionQueryParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + FeatureName sql.NullString + ActorName sql.NullString + RpcMethod sql.NullString + State sql.NullInt16 + EndTime sql.NullTime + StartTime sql.NullTime + GroupID sql.NullInt64 +} + +// ListActionsParams defines the parameters for listing actions, including +// the ActionQueryParams for filtering and a Pagination struct for +// pagination. The Reversed field indicates whether the results should be +// returned in reverse order based on the created_at timestamp. +type ListActionsParams struct { + ActionQueryParams + Reversed bool + *Pagination +} + +// Pagination defines the pagination parameters for listing actions. +type Pagination struct { + NumOffset int32 + NumLimit int32 +} + +// ListActions retrieves a list of actions based on the provided +// ListActionsParams. +func (q *Queries) ListActions(ctx context.Context, + arg ListActionsParams) ([]Action, error) { + + query, args := buildListActionsQuery(arg) + rows, err := q.db.QueryContext(ctx, fillPlaceHolders(query), args...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Action + for rows.Next() { + var i Action + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.AccountID, + &i.MacaroonIdentifier, + &i.ActorName, + &i.FeatureName, + &i.ActionTrigger, + &i.Intent, + &i.StructuredJsonData, + &i.RpcMethod, + &i.RpcParamsJson, + &i.CreatedAt, + &i.ActionState, + &i.ErrorReason, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +// CountActions returns the number of actions that match the provided +// ActionQueryParams. +func (q *Queries) CountActions(ctx context.Context, + arg ActionQueryParams) (int64, error) { + + query, args := buildActionsQuery(arg, true) + row := q.db.QueryRowContext(ctx, query, args...) + + var count int64 + err := row.Scan(&count) + + return count, err +} + +// buildActionsQuery constructs a SQL query to retrieve actions based on the +// provided parameters. We do this manually so that if, for example, we have +// a sessionID we are filtering by, then this appears in the query as: +// `WHERE a.session_id = ?` which will properly make use of the underlying +// index. If we were instead to use a single SQLC query, it would include many +// WHERE clauses like: +// "WHERE a.session_id = COALESCE(sqlc.narg('session_id'), a.session_id)". +// This would use the index if run against postres but not when run against +// sqlite. +// +// The 'count' param indicates whether the query should return a count of +// actions that match the criteria or the actions themselves. +func buildActionsQuery(params ActionQueryParams, count bool) (string, []any) { + var ( + conditions []string + args []any + ) + + if params.SessionID.Valid { + conditions = append(conditions, "a.session_id = ?") + args = append(args, params.SessionID.Int64) + } + if params.AccountID.Valid { + conditions = append(conditions, "a.account_id = ?") + args = append(args, params.AccountID.Int64) + } + if params.FeatureName.Valid { + conditions = append(conditions, "a.feature_name = ?") + args = append(args, params.FeatureName.String) + } + if params.ActorName.Valid { + conditions = append(conditions, "a.actor_name = ?") + args = append(args, params.ActorName.String) + } + if params.RpcMethod.Valid { + conditions = append(conditions, "a.rpc_method = ?") + args = append(args, params.RpcMethod.String) + } + if params.State.Valid { + conditions = append(conditions, "a.action_state = ?") + args = append(args, params.State.Int16) + } + if params.EndTime.Valid { + conditions = append(conditions, "a.created_at <= ?") + args = append(args, params.EndTime.Time) + } + if params.StartTime.Valid { + conditions = append(conditions, "a.created_at >= ?") + args = append(args, params.StartTime.Time) + } + if params.GroupID.Valid { + conditions = append(conditions, ` + EXISTS ( + SELECT 1 + FROM sessions s + WHERE s.id = a.session_id AND s.group_id = ? + )`) + args = append(args, params.GroupID.Int64) + } + + query := "SELECT a.* FROM actions a" + if count { + query = "SELECT COUNT(*) FROM actions a" + } + if len(conditions) > 0 { + query += " WHERE " + strings.Join(conditions, " AND ") + } + + return query, args +} + +// buildListActionsQuery constructs a SQL query to retrieve a list of actions +// based on the provided parameters. It builds upon the `buildActionsQuery` +// function, adding pagination and ordering based on the reversed parameter. +func buildListActionsQuery(params ListActionsParams) (string, []interface{}) { + query, args := buildActionsQuery(params.ActionQueryParams, false) + + // Determine order direction. + order := "ASC" + if params.Reversed { + order = "DESC" + } + query += " ORDER BY a.created_at " + order + + // Maybe paginate. + if params.Pagination != nil { + query += " LIMIT ? OFFSET ?" + args = append(args, params.NumLimit, params.NumOffset) + } + + return query, args +} + +// fillPlaceHolders replaces all '?' placeholders in the SQL query with +// positional placeholders like $1, $2, etc. This is necessary for +// compatibility with Postgres. +func fillPlaceHolders(query string) string { + var ( + sb strings.Builder + argNum = 1 + ) + + for i := range len(query) { + if query[i] != '?' { + sb.WriteByte(query[i]) + continue + } + + sb.WriteString("$") + sb.WriteString(strconv.Itoa(argNum)) + argNum++ + } + + return sb.String() +} diff --git a/db/sqlc/db_custom.go b/db/sqlc/db_custom.go index f9e70033b..f4bf7f611 100644 --- a/db/sqlc/db_custom.go +++ b/db/sqlc/db_custom.go @@ -1,5 +1,9 @@ package sqlc +import ( + "context" +) + // BackendType is an enum that represents the type of database backend we're // using. type BackendType uint8 @@ -44,3 +48,19 @@ func NewSqlite(db DBTX) *Queries { func NewPostgres(db DBTX) *Queries { return &Queries{db: &wrappedTX{db, BackendTypePostgres}} } + +// CustomQueries defines a set of custom queries that we define in addition +// to the ones generated by sqlc. +type CustomQueries interface { + // CountActions returns the number of actions that match the provided + // ActionQueryParams. + CountActions(ctx context.Context, arg ActionQueryParams) (int64, error) + + // ListActions retrieves a list of actions based on the provided + // ListActionsParams. + ListActions(ctx context.Context, + arg ListActionsParams) ([]Action, error) + + // Backend returns the type of the database backend used. + Backend() BackendType +} From 3f7ae5b6c87cdc456d1e552723277e7cd92b1601 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 13 May 2025 14:29:24 +0200 Subject: [PATCH 3/4] firewalldb: add SQL impl of Actions store In this commit we let the firewalldb.SQLDB implement the ActionsDB. We also ensure that all the action unit tests now run against the SQL impl. --- firewalldb/actions_sql.go | 418 +++++++++++++++++++++++++++++++++++++ firewalldb/actions_test.go | 33 +-- firewalldb/kvstores_sql.go | 6 - firewalldb/sql_store.go | 12 ++ firewalldb/test_kvdb.go | 20 ++ firewalldb/test_sql.go | 32 +++ 6 files changed, 487 insertions(+), 34 deletions(-) create mode 100644 firewalldb/actions_sql.go diff --git a/firewalldb/actions_sql.go b/firewalldb/actions_sql.go new file mode 100644 index 000000000..75c9d0a6d --- /dev/null +++ b/firewalldb/actions_sql.go @@ -0,0 +1,418 @@ +package firewalldb + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math" + + "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb" +) + +// SQLAccountQueries is a subset of the sqlc.Queries interface that can be used +// to interact with the accounts table. +type SQLAccountQueries interface { + GetAccount(ctx context.Context, id int64) (sqlc.Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) +} + +// SQLActionQueries is a subset of the sqlc.Queries interface that can be used +// to interact with action related tables. +// +//nolint:lll +type SQLActionQueries interface { + SQLSessionQueries + SQLAccountQueries + + InsertAction(ctx context.Context, arg sqlc.InsertActionParams) (int64, error) + SetActionState(ctx context.Context, arg sqlc.SetActionStateParams) error + ListActions(ctx context.Context, arg sqlc.ListActionsParams) ([]sqlc.Action, error) + CountActions(ctx context.Context, arg sqlc.ActionQueryParams) (int64, error) +} + +// sqlActionLocator helps us find an action in the SQL DB. +type sqlActionLocator struct { + // id is the DB level ID of the action. + id int64 +} + +func (s *sqlActionLocator) isActionLocator() {} + +// A compile-time check to ensure sqlActionLocator implements the ActionLocator +// interface. +var _ ActionLocator = (*sqlActionLocator)(nil) + +// GetActionsReadDB is a method on DB that constructs an ActionsReadDB. +// +// NOTE: This is part of the ActionDB interface. +func (s *SQLDB) GetActionsReadDB(groupID session.ID, + featureName string) ActionsReadDB { + + return &allActionsReadDB{ + db: s, + groupID: groupID, + featureName: featureName, + } +} + +// AddAction persists the given action to the database. +// +// NOTE: This is a part of the ActionDB interface. +func (s *SQLDB) AddAction(ctx context.Context, + req *AddActionReq) (ActionLocator, error) { + + var ( + writeTxOpts db.QueriesTxOptions + locator sqlActionLocator + + actor = sql.NullString{ + String: req.ActorName, + Valid: req.ActorName != "", + } + feature = sql.NullString{ + String: req.FeatureName, + Valid: req.FeatureName != "", + } + trigger = sql.NullString{ + String: req.Trigger, + Valid: req.Trigger != "", + } + intent = sql.NullString{ + String: req.Intent, + Valid: req.Intent != "", + } + ) + + err := s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + // Do best effort to see if this action is linked to a session, + // and/or an action or none. + var ( + sessionID sql.NullInt64 + accountID sql.NullInt64 + ) + + // First check session DB. + var sessErr error + req.SessionID.WhenSome(func(alias session.ID) { + sessID, err := db.GetSessionIDByAlias(ctx, alias[:]) + if errors.Is(err, sql.ErrNoRows) { + sessErr = session.ErrSessionNotFound + return + } else if err != nil { + sessErr = err + return + } + + sessionID = sqldb.SQLInt64(sessID) + }) + if sessErr != nil { + return sessErr + } + + // If an account ID was provided, then it must exist in our DB. + var getAcctErr error + req.AccountID.WhenSome(func(alias accounts.AccountID) { + aliasInt, err := alias.ToInt64() + if err != nil { + getAcctErr = err + return + } + + acctID, err := db.GetAccountIDByAlias(ctx, aliasInt) + if errors.Is(err, sql.ErrNoRows) { + getAcctErr = accounts.ErrAccNotFound + return + } else if err != nil { + getAcctErr = err + return + } + + accountID = sqldb.SQLInt64(acctID) + }) + if getAcctErr != nil { + return getAcctErr + } + + var macID []byte + req.MacaroonIdentifier.WhenSome(func(id [4]byte) { + macID = id[:] + }) + + id, err := db.InsertAction(ctx, sqlc.InsertActionParams{ + SessionID: sessionID, + AccountID: accountID, + ActorName: actor, + MacaroonIdentifier: macID, + FeatureName: feature, + ActionTrigger: trigger, + Intent: intent, + StructuredJsonData: []byte(req.StructuredJsonData), + RpcMethod: req.RPCMethod, + RpcParamsJson: req.RPCParamsJson, + CreatedAt: s.clock.Now().UTC(), + ActionState: int16(ActionStateInit), + }) + if err != nil { + return err + } + + locator = sqlActionLocator{ + id: id, + } + + return nil + }) + if err != nil { + return nil, err + } + + return &locator, nil +} + +// SetActionState finds the action specified by the ActionLocator and sets its +// state to the given state. +// +// NOTE: This is a part of the ActionDB interface. +func (s *SQLDB) SetActionState(ctx context.Context, al ActionLocator, + state ActionState, errReason string) error { + + if errReason != "" && state != ActionStateError { + return fmt.Errorf("error reason should only be set for " + + "ActionStateError") + } + + locator, ok := al.(*sqlActionLocator) + if !ok { + return fmt.Errorf("expected sqlActionLocator, got %T", al) + } + + var writeTxOpts db.QueriesTxOptions + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + return db.SetActionState(ctx, sqlc.SetActionStateParams{ + ID: locator.id, + ActionState: int16(state), + ErrorReason: sql.NullString{ + String: errReason, + Valid: errReason != "", + }, + }) + }) +} + +// ListActions returns a list of Actions. The query IndexOffset and MaxNum +// params can be used to control the number of actions returned. +// ListActionOptions may be used to filter on specific Action values. The return +// values are the list of actions, the last index and the total count (iff +// query.CountTotal is set). +// +// NOTE: This is part of the ActionDB interface. +func (s *SQLDB) ListActions(ctx context.Context, + query *ListActionsQuery, options ...ListActionOption) ([]*Action, + uint64, uint64, error) { + + opts := newListActionOptions() + for _, o := range options { + o(opts) + } + + var ( + readTxOpts = db.NewQueryReadTx() + actions []*Action + lastIndex uint64 + totalCount int64 + ) + err := s.db.ExecTx(ctx, &readTxOpts, func(db SQLQueries) error { + var ( + actorName = sql.NullString{ + String: opts.actorName, + Valid: opts.actorName != "", + } + feature = sql.NullString{ + String: opts.featureName, + Valid: opts.featureName != "", + } + rpcMethod = sql.NullString{ + String: opts.methodName, + Valid: opts.methodName != "", + } + actionState = sql.NullInt16{ + Int16: int16(opts.state), + Valid: opts.state != 0, + } + startTime = sql.NullTime{ + Time: opts.startTime, + Valid: !opts.startTime.IsZero(), + } + endTime = sql.NullTime{ + Time: opts.endTime, + Valid: !opts.endTime.IsZero(), + } + ) + + var sessionID sql.NullInt64 + if opts.sessionID != session.EmptyID { + sID, err := db.GetSessionIDByAlias( + ctx, opts.sessionID[:], + ) + if errors.Is(err, sql.ErrNoRows) { + return session.ErrSessionNotFound + } else if err != nil { + return fmt.Errorf("unable to get DB ID for "+ + "legacy session ID %x: %w", + opts.sessionID, err) + } + + sessionID = sqldb.SQLInt64(sID) + } + + var groupID sql.NullInt64 + if opts.groupID != session.EmptyID { + gID, err := db.GetSessionIDByAlias(ctx, opts.groupID[:]) + if errors.Is(err, sql.ErrNoRows) { + return session.ErrUnknownGroup + } else if err != nil { + return fmt.Errorf("unable to get DB ID for "+ + "legacy group ID %x: %w", opts.groupID, + err) + } + + groupID = sqldb.SQLInt64(gID) + } + + var ( + dbActions []sqlc.Action + err error + ) + actionQueryParams := sqlc.ActionQueryParams{ + SessionID: sessionID, + GroupID: groupID, + FeatureName: feature, + ActorName: actorName, + RpcMethod: rpcMethod, + State: actionState, + EndTime: endTime, + StartTime: startTime, + } + queryParams := sqlc.ListActionsParams{ + ActionQueryParams: actionQueryParams, + Reversed: false, + } + if query != nil { + queryParams.Reversed = query.Reversed + queryParams.Pagination = &sqlc.Pagination{ + NumLimit: func() int32 { + if query.MaxNum == 0 { + return int32(math.MaxInt32) + } + + return int32(query.MaxNum) + }(), + NumOffset: int32(query.IndexOffset), + } + } + + dbActions, err = db.ListActions(ctx, queryParams) + if err != nil { + return fmt.Errorf("unable to list actions: %w", err) + } + + // If pagination was used, then the number of results returned + // won't necessarily match the total number of actions that + // match the query. So, if pagination was used and the CountAll + // flag is set, then we need to count the total number of + // actions that match the query. + if query != nil && query.CountAll { + totalCount, err = db.CountActions( + ctx, actionQueryParams, + ) + if err != nil { + return fmt.Errorf("unable to count actions: %w", + err) + } + } + + actions = make([]*Action, len(dbActions)) + for i, dbAction := range dbActions { + action, err := unmarshalAction(ctx, db, dbAction) + if err != nil { + return fmt.Errorf("unable to unmarshal "+ + "action: %w", err) + } + + actions[i] = action + lastIndex = uint64(dbAction.ID) + } + + return nil + }) + + return actions, lastIndex, uint64(totalCount), err +} + +func unmarshalAction(ctx context.Context, db SQLActionQueries, + dbAction sqlc.Action) (*Action, error) { + + var legacySessID fn.Option[session.ID] + if dbAction.SessionID.Valid { + legacySessIDB, err := db.GetAliasBySessionID( + ctx, dbAction.SessionID.Int64, + ) + if err != nil { + return nil, fmt.Errorf("unable to get legacy "+ + "session ID for session ID %d: %w", + dbAction.SessionID.Int64, err) + } + + sessID, err := session.IDFromBytes(legacySessIDB) + if err != nil { + return nil, err + } + + legacySessID = fn.Some(sessID) + } + + var legacyAcctID fn.Option[accounts.AccountID] + if dbAction.AccountID.Valid { + acct, err := db.GetAccount(ctx, dbAction.AccountID.Int64) + if err != nil { + return nil, err + } + + acctID, err := accounts.AccountIDFromInt64(acct.Alias) + if err != nil { + return nil, fmt.Errorf("unable to get account ID: %w", + err) + } + + legacyAcctID = fn.Some(acctID) + } + + var macID fn.Option[[4]byte] + if len(dbAction.MacaroonIdentifier) > 0 { + macID = fn.Some([4]byte(dbAction.MacaroonIdentifier)) + } + + return &Action{ + AddActionReq: AddActionReq{ + MacaroonIdentifier: macID, + AccountID: legacyAcctID, + SessionID: legacySessID, + ActorName: dbAction.ActorName.String, + FeatureName: dbAction.FeatureName.String, + Trigger: dbAction.ActionTrigger.String, + Intent: dbAction.Intent.String, + StructuredJsonData: string(dbAction.StructuredJsonData), + RPCMethod: dbAction.RpcMethod, + RPCParamsJson: dbAction.RpcParamsJson, + }, + AttemptedAt: dbAction.CreatedAt, + State: ActionState(dbAction.ActionState), + ErrorReason: dbAction.ErrorReason.String, + }, nil +} diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index a355bede0..c27e53e96 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -24,18 +24,17 @@ func TestActionStorage(t *testing.T) { ctx := context.Background() clock := clock.NewTestClock(testTime1) - sessDB := session.NewTestDB(t, clock) accountsDB := accounts.NewTestDB(t, clock) + sessDB := session.NewTestDBWithAccounts(t, clock, accountsDB) - db, err := NewBoltDB(t.TempDir(), "test.db", sessDB, accountsDB, clock) - require.NoError(t, err) + db := NewTestDBWithSessionsAndAccounts(t, sessDB, accountsDB, clock) t.Cleanup(func() { _ = db.Close() }) // Assert that attempting to add an action for a session that does not // exist returns an error. - _, err = db.AddAction(ctx, &AddActionReq{ + _, err := db.AddAction(ctx, &AddActionReq{ SessionID: fn.Some(session.ID{1, 2, 3, 4}), }) require.ErrorIs(t, err, session.ErrSessionNotFound) @@ -194,13 +193,11 @@ func TestActionStorage(t *testing.T) { func TestListActions(t *testing.T) { t.Parallel() - tmpDir := t.TempDir() ctx := context.Background() clock := clock.NewDefaultClock() sessDB := session.NewTestDB(t, clock) - db, err := NewBoltDB(tmpDir, "test.db", sessDB, nil, clock) - require.NoError(t, err) + db := NewTestDBWithSessions(t, sessDB, clock) t.Cleanup(func() { _ = db.Close() }) @@ -468,8 +465,7 @@ func TestListGroupActions(t *testing.T) { State: ActionStateInit, } - db, err := NewBoltDB(t.TempDir(), "test.db", sessDB, nil, clock) - require.NoError(t, err) + db := NewTestDBWithSessions(t, sessDB, clock) t.Cleanup(func() { _ = db.Close() }) @@ -514,22 +510,3 @@ func TestListGroupActions(t *testing.T) { assertEqualActions(t, action2, al[0]) assertEqualActions(t, action1, al[1]) } - -func assertEqualActions(t *testing.T, expected, got *Action) { - // Accounts are not explicitly linked in our bbolt DB implementation. - got.AccountID = expected.AccountID - - expectedAttemptedAt := expected.AttemptedAt - actualAttemptedAt := got.AttemptedAt - - expected.AttemptedAt = time.Time{} - got.AttemptedAt = time.Time{} - - require.Equal(t, expected, got) - require.Equal(t, expectedAttemptedAt.Unix(), actualAttemptedAt.Unix()) - - expected.AttemptedAt = expectedAttemptedAt - got.AttemptedAt = actualAttemptedAt - - got.AccountID = fn.None[accounts.AccountID]() -} diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index e7e1e7da5..0c3df2ddb 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -13,12 +13,6 @@ import ( "github.com/lightningnetwork/lnd/fn" ) -// SQLSessionQueries is a subset of the sqlc.Queries interface that can be used -// to interact with the session table. -type SQLSessionQueries interface { - GetSessionIDByAlias(ctx context.Context, legacyID []byte) (int64, error) -} - // SQLKVStoreQueries is a subset of the sqlc.Queries interface that can be // used to interact with the kvstore tables. // diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index 369920d63..f17010f2c 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -8,11 +8,19 @@ import ( "github.com/lightningnetwork/lnd/clock" ) +// SQLSessionQueries is a subset of the sqlc.Queries interface that can be used +// to interact with the session table. +type SQLSessionQueries interface { + GetSessionIDByAlias(ctx context.Context, legacyID []byte) (int64, error) + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) +} + // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with various firewalldb tables. type SQLQueries interface { SQLKVStoreQueries SQLPrivacyPairQueries + SQLActionQueries } // BatchedSQLQueries is a version of the SQLQueries that's capable of batched @@ -39,6 +47,10 @@ type SQLDB struct { // interface. var _ RulesDB = (*SQLDB)(nil) +// A compile-time assertion to ensure that SQLDB implements the ActionsDB +// interface. +var _ ActionDB = (*SQLDB)(nil) + // NewSQLDB creates a new SQLStore instance given an open SQLQueries // storage backend. func NewSQLDB(sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go index 659292702..6f7a49aa3 100644 --- a/firewalldb/test_kvdb.go +++ b/firewalldb/test_kvdb.go @@ -5,7 +5,9 @@ package firewalldb import ( "testing" + "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/require" ) @@ -28,6 +30,16 @@ func NewTestDBWithSessions(t *testing.T, sessStore SessionDB, return newDBFromPathWithSessions(t, t.TempDir(), sessStore, nil, clock) } +// NewTestDBWithSessionsAndAccounts creates a new test BoltDB Store with access +// to an existing sessions DB and accounts DB. +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, + acctStore AccountsDB, clock clock.Clock) *BoltDB { + + return newDBFromPathWithSessions( + t, t.TempDir(), sessStore, acctStore, clock, + ) +} + func newDBFromPathWithSessions(t *testing.T, dbPath string, sessStore SessionDB, acctStore AccountsDB, clock clock.Clock) *BoltDB { @@ -40,3 +52,11 @@ func newDBFromPathWithSessions(t *testing.T, dbPath string, return store } + +func assertEqualActions(t *testing.T, expected, got *Action) { + // Accounts are not explicitly linked in our bbolt DB implementation. + got.AccountID = expected.AccountID + require.Equal(t, expected, got) + + got.AccountID = fn.None[accounts.AccountID]() +} diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index 947ff1491..03dcfbebf 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -4,7 +4,9 @@ package firewalldb import ( "testing" + "time" + "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/stretchr/testify/require" @@ -20,3 +22,33 @@ func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, return NewSQLDB(sessions.BaseDB, clock) } + +// NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access +// to an existing sessions DB and accounts DB. +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, + acctStore AccountsDB, clock clock.Clock) *SQLDB { + + sessions, ok := sessionStore.(*session.SQLStore) + require.True(t, ok) + + accounts, ok := acctStore.(*accounts.SQLStore) + require.True(t, ok) + + require.Equal(t, accounts.BaseDB, sessions.BaseDB) + + return NewSQLDB(sessions.BaseDB, clock) +} + +func assertEqualActions(t *testing.T, expected, got *Action) { + expectedAttemptedAt := expected.AttemptedAt + actualAttemptedAt := got.AttemptedAt + + expected.AttemptedAt = time.Time{} + got.AttemptedAt = time.Time{} + + require.Equal(t, expected, got) + require.Equal(t, expectedAttemptedAt.Unix(), actualAttemptedAt.Unix()) + + expected.AttemptedAt = expectedAttemptedAt + got.AttemptedAt = actualAttemptedAt +} From 89b807c3bbfaa0147f6f070005e5a27c2ab4bffb Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 13 May 2025 14:33:37 +0200 Subject: [PATCH 4/4] lit: plug the SQL actions DB into the `dev` build --- config_dev.go | 22 +++++++++------------- config_prod.go | 1 - firewalldb/db.go | 1 + terminal.go | 9 ++++----- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/config_dev.go b/config_dev.go index 4ab17bd77..90b8b290f 100644 --- a/config_dev.go +++ b/config_dev.go @@ -151,23 +151,19 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { stores.sessions = sessionStore stores.closeFns["bbolt-sessions"] = sessionStore.Close - } - firewallBoltDB, err := firewalldb.NewBoltDB( - networkDir, firewalldb.DBFilename, stores.sessions, - stores.accounts, clock, - ) - if err != nil { - return stores, fmt.Errorf("error creating firewall BoltDB: %v", - err) - } + firewallBoltDB, err := firewalldb.NewBoltDB( + networkDir, firewalldb.DBFilename, stores.sessions, + stores.accounts, clock, + ) + if err != nil { + return stores, fmt.Errorf("error creating firewall "+ + "BoltDB: %v", err) + } - if stores.firewall == nil { stores.firewall = firewalldb.NewDB(firewallBoltDB) + stores.closeFns["bbolt-firewalldb"] = firewallBoltDB.Close } - stores.firewallBolt = firewallBoltDB - stores.closeFns["bbolt-firewalldb"] = firewallBoltDB.Close - return stores, nil } diff --git a/config_prod.go b/config_prod.go index c13d66960..ac6e6d996 100644 --- a/config_prod.go +++ b/config_prod.go @@ -62,7 +62,6 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { if err != nil { return stores, fmt.Errorf("error creating firewall DB: %v", err) } - stores.firewallBolt = firewallDB stores.firewall = firewalldb.NewDB(firewallDB) stores.closeFns["firewall"] = firewallDB.Close diff --git a/firewalldb/db.go b/firewalldb/db.go index 8b913b69d..b8d9ed06f 100644 --- a/firewalldb/db.go +++ b/firewalldb/db.go @@ -19,6 +19,7 @@ var ( type firewallDBs interface { RulesDB PrivacyMapper + ActionDB } // DB manages the firewall rules database. diff --git a/terminal.go b/terminal.go index e08e9a25f..7e4d552c7 100644 --- a/terminal.go +++ b/terminal.go @@ -239,8 +239,7 @@ type stores struct { accounts accounts.Store sessions session.Store - firewall *firewalldb.DB - firewallBolt *firewalldb.BoltDB + firewall *firewalldb.DB // closeFns holds various callbacks that can be used to close any open // stores in the stores struct. @@ -531,7 +530,7 @@ func (g *LightningTerminal) start(ctx context.Context) error { superMacBaker: superMacBaker, firstConnectionDeadline: g.cfg.FirstLNCConnDeadline, permMgr: g.permsMgr, - actionsDB: g.stores.firewallBolt, + actionsDB: g.stores.firewall, autopilot: g.autopilotClient, ruleMgrs: g.ruleMgrs, privMap: g.stores.firewall, @@ -1093,7 +1092,7 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, } requestLogger, err := firewall.NewRequestLogger( - g.cfg.Firewall.RequestLogger, g.stores.firewallBolt, + g.cfg.Firewall.RequestLogger, g.stores.firewall, ) if err != nil { return fmt.Errorf("error creating new request logger") @@ -1112,7 +1111,7 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, if !g.cfg.Autopilot.Disable { ruleEnforcer := firewall.NewRuleEnforcer( - g.stores.firewall, g.stores.firewallBolt, + g.stores.firewall, g.stores.firewall, g.stores.sessions, g.autopilotClient.ListFeaturePerms, g.permsMgr, g.lndClient.NodePubkey,