diff --git a/config_dev.go b/config_dev.go index 4ab17bd77..90b8b290f 100644 --- a/config_dev.go +++ b/config_dev.go @@ -151,23 +151,19 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { stores.sessions = sessionStore stores.closeFns["bbolt-sessions"] = sessionStore.Close - } - firewallBoltDB, err := firewalldb.NewBoltDB( - networkDir, firewalldb.DBFilename, stores.sessions, - stores.accounts, clock, - ) - if err != nil { - return stores, fmt.Errorf("error creating firewall BoltDB: %v", - err) - } + firewallBoltDB, err := firewalldb.NewBoltDB( + networkDir, firewalldb.DBFilename, stores.sessions, + stores.accounts, clock, + ) + if err != nil { + return stores, fmt.Errorf("error creating firewall "+ + "BoltDB: %v", err) + } - if stores.firewall == nil { stores.firewall = firewalldb.NewDB(firewallBoltDB) + stores.closeFns["bbolt-firewalldb"] = firewallBoltDB.Close } - stores.firewallBolt = firewallBoltDB - stores.closeFns["bbolt-firewalldb"] = firewallBoltDB.Close - return stores, nil } diff --git a/config_prod.go b/config_prod.go index c13d66960..ac6e6d996 100644 --- a/config_prod.go +++ b/config_prod.go @@ -62,7 +62,6 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) { if err != nil { return stores, fmt.Errorf("error creating firewall DB: %v", err) } - stores.firewallBolt = firewallDB stores.firewall = firewalldb.NewDB(firewallDB) stores.closeFns["firewall"] = firewallDB.Close diff --git a/db/interfaces.go b/db/interfaces.go index 3a0378f55..ba64520b4 100644 --- a/db/interfaces.go +++ b/db/interfaces.go @@ -87,12 +87,13 @@ type BatchedQuerier interface { // create a batched version of the normal methods they need. sqlc.Querier + // CustomQueries is the set of custom queries that we have manually + // defined in addition to the ones generated by sqlc. + sqlc.CustomQueries + // BeginTx creates a new database transaction given the set of // transaction options. BeginTx(ctx context.Context, options TxOptions) (*sql.Tx, error) - - // Backend returns the type of the database backend used. - Backend() sqlc.BackendType } // txExecutorOptions is a struct that holds the options for the transaction diff --git a/db/migrations.go b/db/migrations.go index f52ec942a..79d63587e 100644 --- a/db/migrations.go +++ b/db/migrations.go @@ -22,7 +22,7 @@ const ( // daemon. // // NOTE: This MUST be updated when a new migration is added. - LatestMigrationVersion = 4 + LatestMigrationVersion = 5 ) // MigrationTarget is a functional option that can be passed to applyMigrations diff --git a/db/sqlc/actions.sql.go b/db/sqlc/actions.sql.go new file mode 100644 index 000000000..4a3e8c891 --- /dev/null +++ b/db/sqlc/actions.sql.go @@ -0,0 +1,78 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: actions.sql + +package sqlc + +import ( + "context" + "database/sql" + "time" +) + +const insertAction = `-- name: InsertAction :one +INSERT INTO actions ( + session_id, account_id, macaroon_identifier, actor_name, feature_name, action_trigger, + intent, structured_json_data, rpc_method, rpc_params_json, created_at, + action_state, error_reason +) VALUES ( + $1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, $12, $13 +) RETURNING id +` + +type InsertActionParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + +func (q *Queries) InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertAction, + arg.SessionID, + arg.AccountID, + arg.MacaroonIdentifier, + arg.ActorName, + arg.FeatureName, + arg.ActionTrigger, + arg.Intent, + arg.StructuredJsonData, + arg.RpcMethod, + arg.RpcParamsJson, + arg.CreatedAt, + arg.ActionState, + arg.ErrorReason, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const setActionState = `-- name: SetActionState :exec +UPDATE actions +SET action_state = $1, + error_reason = $2 +WHERE id = $3 +` + +type SetActionStateParams struct { + ActionState int16 + ErrorReason sql.NullString + ID int64 +} + +func (q *Queries) SetActionState(ctx context.Context, arg SetActionStateParams) error { + _, err := q.db.ExecContext(ctx, setActionState, arg.ActionState, arg.ErrorReason, arg.ID) + return err +} diff --git a/db/sqlc/actions_custom.go b/db/sqlc/actions_custom.go new file mode 100644 index 000000000..693fc6428 --- /dev/null +++ b/db/sqlc/actions_custom.go @@ -0,0 +1,210 @@ +package sqlc + +import ( + "context" + "database/sql" + "strconv" + "strings" +) + +// ActionQueryParams defines the parameters for querying actions. +type ActionQueryParams struct { + SessionID sql.NullInt64 + AccountID sql.NullInt64 + FeatureName sql.NullString + ActorName sql.NullString + RpcMethod sql.NullString + State sql.NullInt16 + EndTime sql.NullTime + StartTime sql.NullTime + GroupID sql.NullInt64 +} + +// ListActionsParams defines the parameters for listing actions, including +// the ActionQueryParams for filtering and a Pagination struct for +// pagination. The Reversed field indicates whether the results should be +// returned in reverse order based on the created_at timestamp. +type ListActionsParams struct { + ActionQueryParams + Reversed bool + *Pagination +} + +// Pagination defines the pagination parameters for listing actions. +type Pagination struct { + NumOffset int32 + NumLimit int32 +} + +// ListActions retrieves a list of actions based on the provided +// ListActionsParams. +func (q *Queries) ListActions(ctx context.Context, + arg ListActionsParams) ([]Action, error) { + + query, args := buildListActionsQuery(arg) + rows, err := q.db.QueryContext(ctx, fillPlaceHolders(query), args...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Action + for rows.Next() { + var i Action + if err := rows.Scan( + &i.ID, + &i.SessionID, + &i.AccountID, + &i.MacaroonIdentifier, + &i.ActorName, + &i.FeatureName, + &i.ActionTrigger, + &i.Intent, + &i.StructuredJsonData, + &i.RpcMethod, + &i.RpcParamsJson, + &i.CreatedAt, + &i.ActionState, + &i.ErrorReason, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +// CountActions returns the number of actions that match the provided +// ActionQueryParams. +func (q *Queries) CountActions(ctx context.Context, + arg ActionQueryParams) (int64, error) { + + query, args := buildActionsQuery(arg, true) + row := q.db.QueryRowContext(ctx, query, args...) + + var count int64 + err := row.Scan(&count) + + return count, err +} + +// buildActionsQuery constructs a SQL query to retrieve actions based on the +// provided parameters. We do this manually so that if, for example, we have +// a sessionID we are filtering by, then this appears in the query as: +// `WHERE a.session_id = ?` which will properly make use of the underlying +// index. If we were instead to use a single SQLC query, it would include many +// WHERE clauses like: +// "WHERE a.session_id = COALESCE(sqlc.narg('session_id'), a.session_id)". +// This would use the index if run against postres but not when run against +// sqlite. +// +// The 'count' param indicates whether the query should return a count of +// actions that match the criteria or the actions themselves. +func buildActionsQuery(params ActionQueryParams, count bool) (string, []any) { + var ( + conditions []string + args []any + ) + + if params.SessionID.Valid { + conditions = append(conditions, "a.session_id = ?") + args = append(args, params.SessionID.Int64) + } + if params.AccountID.Valid { + conditions = append(conditions, "a.account_id = ?") + args = append(args, params.AccountID.Int64) + } + if params.FeatureName.Valid { + conditions = append(conditions, "a.feature_name = ?") + args = append(args, params.FeatureName.String) + } + if params.ActorName.Valid { + conditions = append(conditions, "a.actor_name = ?") + args = append(args, params.ActorName.String) + } + if params.RpcMethod.Valid { + conditions = append(conditions, "a.rpc_method = ?") + args = append(args, params.RpcMethod.String) + } + if params.State.Valid { + conditions = append(conditions, "a.action_state = ?") + args = append(args, params.State.Int16) + } + if params.EndTime.Valid { + conditions = append(conditions, "a.created_at <= ?") + args = append(args, params.EndTime.Time) + } + if params.StartTime.Valid { + conditions = append(conditions, "a.created_at >= ?") + args = append(args, params.StartTime.Time) + } + if params.GroupID.Valid { + conditions = append(conditions, ` + EXISTS ( + SELECT 1 + FROM sessions s + WHERE s.id = a.session_id AND s.group_id = ? + )`) + args = append(args, params.GroupID.Int64) + } + + query := "SELECT a.* FROM actions a" + if count { + query = "SELECT COUNT(*) FROM actions a" + } + if len(conditions) > 0 { + query += " WHERE " + strings.Join(conditions, " AND ") + } + + return query, args +} + +// buildListActionsQuery constructs a SQL query to retrieve a list of actions +// based on the provided parameters. It builds upon the `buildActionsQuery` +// function, adding pagination and ordering based on the reversed parameter. +func buildListActionsQuery(params ListActionsParams) (string, []interface{}) { + query, args := buildActionsQuery(params.ActionQueryParams, false) + + // Determine order direction. + order := "ASC" + if params.Reversed { + order = "DESC" + } + query += " ORDER BY a.created_at " + order + + // Maybe paginate. + if params.Pagination != nil { + query += " LIMIT ? OFFSET ?" + args = append(args, params.NumLimit, params.NumOffset) + } + + return query, args +} + +// fillPlaceHolders replaces all '?' placeholders in the SQL query with +// positional placeholders like $1, $2, etc. This is necessary for +// compatibility with Postgres. +func fillPlaceHolders(query string) string { + var ( + sb strings.Builder + argNum = 1 + ) + + for i := range len(query) { + if query[i] != '?' { + sb.WriteByte(query[i]) + continue + } + + sb.WriteString("$") + sb.WriteString(strconv.Itoa(argNum)) + argNum++ + } + + return sb.String() +} diff --git a/db/sqlc/db_custom.go b/db/sqlc/db_custom.go index f9e70033b..f4bf7f611 100644 --- a/db/sqlc/db_custom.go +++ b/db/sqlc/db_custom.go @@ -1,5 +1,9 @@ package sqlc +import ( + "context" +) + // BackendType is an enum that represents the type of database backend we're // using. type BackendType uint8 @@ -44,3 +48,19 @@ func NewSqlite(db DBTX) *Queries { func NewPostgres(db DBTX) *Queries { return &Queries{db: &wrappedTX{db, BackendTypePostgres}} } + +// CustomQueries defines a set of custom queries that we define in addition +// to the ones generated by sqlc. +type CustomQueries interface { + // CountActions returns the number of actions that match the provided + // ActionQueryParams. + CountActions(ctx context.Context, arg ActionQueryParams) (int64, error) + + // ListActions retrieves a list of actions based on the provided + // ListActionsParams. + ListActions(ctx context.Context, + arg ListActionsParams) ([]Action, error) + + // Backend returns the type of the database backend used. + Backend() BackendType +} diff --git a/db/sqlc/migrations/000005_actions.down.sql b/db/sqlc/migrations/000005_actions.down.sql new file mode 100644 index 000000000..28bdfb6b5 --- /dev/null +++ b/db/sqlc/migrations/000005_actions.down.sql @@ -0,0 +1,5 @@ +DROP INDEX IF NOT EXISTS actions_state_idx; +DROP INDEX IF NOT EXISTS actions_session_id_idx; +DROP INDEX IF NOT EXISTS actions_feature_name_idx; +DROP INDEX IF NOT EXISTS actions_created_at_idx; +DROP TABLE IF EXISTS actions; \ No newline at end of file diff --git a/db/sqlc/migrations/000005_actions.up.sql b/db/sqlc/migrations/000005_actions.up.sql new file mode 100644 index 000000000..44596af3c --- /dev/null +++ b/db/sqlc/migrations/000005_actions.up.sql @@ -0,0 +1,52 @@ +CREATE TABLE IF NOT EXISTS actions( + -- The ID of the action. + id INTEGER PRIMARY KEY, + + -- The session ID of the session that this action is associated with. + -- This may be null for actions that are not coupled to a session. + session_id BIGINT REFERENCES sessions(id) ON DELETE CASCADE, + + -- The account ID of the account that this action is associated with. + -- This may be null for actions that are not coupled to an account. + account_id BIGINT REFERENCES accounts(id) ON DELETE CASCADE, + + -- An ID derived from the macaroon used to perform the action. + macaroon_identifier BLOB, + + -- The name of the entity who performed the action. + actor_name TEXT, + + -- The name of the feature that the action is being performed by. + feature_name TEXT, + + -- Meta info detailing what caused this action to be executed. + action_trigger TEXT, + + -- Meta info detailing what the intended outcome of this action will be. + intent TEXT, + + -- Extra structured JSON data that an actor can send along with the + -- action as json. + structured_json_data BLOB, + + -- The method URI that was called. + rpc_method TEXT NOT NULL, + + -- The method parameters of the request in JSON form. + rpc_params_json BLOB, + + -- The time at which this action was created. + created_at TIMESTAMP NOT NULL, + + -- The current state of the action. + action_state SMALLINT NOT NULL, + + -- Human-readable reason for why the action failed. + -- It will only be set if state is ActionStateError (3). + error_reason TEXT +); + +CREATE INDEX IF NOT EXISTS actions_state_idx ON actions(action_state); +CREATE INDEX IF NOT EXISTS actions_session_id_idx ON actions(session_id); +CREATE INDEX IF NOT EXISTS actions_feature_name_idx ON actions(feature_name); +CREATE INDEX IF NOT EXISTS actions_created_at_idx ON actions(created_at); \ No newline at end of file diff --git a/db/sqlc/models.go b/db/sqlc/models.go index ea9242bed..357360c9e 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -37,6 +37,23 @@ type AccountPayment struct { FullAmountMsat int64 } +type Action struct { + ID int64 + SessionID sql.NullInt64 + AccountID sql.NullInt64 + MacaroonIdentifier []byte + ActorName sql.NullString + FeatureName sql.NullString + ActionTrigger sql.NullString + Intent sql.NullString + StructuredJsonData []byte + RpcMethod string + RpcParamsJson []byte + CreatedAt time.Time + ActionState int16 + ErrorReason sql.NullString +} + type Feature struct { ID int64 Name string diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index a0a9d122d..df89d0898 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -46,6 +46,7 @@ type Querier interface { GetSessionPrivacyFlags(ctx context.Context, sessionID int64) ([]SessionPrivacyFlag, error) GetSessionsInGroup(ctx context.Context, groupID sql.NullInt64) ([]Session, error) InsertAccount(ctx context.Context, arg InsertAccountParams) (int64, error) + InsertAction(ctx context.Context, arg InsertActionParams) (int64, error) InsertKVStoreRecord(ctx context.Context, arg InsertKVStoreRecordParams) error InsertPrivacyPair(ctx context.Context, arg InsertPrivacyPairParams) error InsertSession(ctx context.Context, arg InsertSessionParams) (int64, error) @@ -60,6 +61,7 @@ type Querier interface { ListSessionsByState(ctx context.Context, state int16) ([]Session, error) ListSessionsByType(ctx context.Context, type_ int16) ([]Session, error) SetAccountIndex(ctx context.Context, arg SetAccountIndexParams) error + SetActionState(ctx context.Context, arg SetActionStateParams) error SetSessionGroupID(ctx context.Context, arg SetSessionGroupIDParams) error SetSessionRemotePublicKey(ctx context.Context, arg SetSessionRemotePublicKeyParams) error SetSessionRevokedAt(ctx context.Context, arg SetSessionRevokedAtParams) error diff --git a/db/sqlc/queries/actions.sql b/db/sqlc/queries/actions.sql new file mode 100644 index 000000000..2a966022d --- /dev/null +++ b/db/sqlc/queries/actions.sql @@ -0,0 +1,15 @@ +-- name: InsertAction :one +INSERT INTO actions ( + session_id, account_id, macaroon_identifier, actor_name, feature_name, action_trigger, + intent, structured_json_data, rpc_method, rpc_params_json, created_at, + action_state, error_reason +) VALUES ( + $1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, $11, $12, $13 +) RETURNING id; + +-- name: SetActionState :exec +UPDATE actions +SET action_state = $1, + error_reason = $2 +WHERE id = $3; diff --git a/firewalldb/actions_sql.go b/firewalldb/actions_sql.go new file mode 100644 index 000000000..75c9d0a6d --- /dev/null +++ b/firewalldb/actions_sql.go @@ -0,0 +1,418 @@ +package firewalldb + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math" + + "github.com/lightninglabs/lightning-terminal/accounts" + "github.com/lightninglabs/lightning-terminal/db" + "github.com/lightninglabs/lightning-terminal/db/sqlc" + "github.com/lightninglabs/lightning-terminal/session" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/sqldb" +) + +// SQLAccountQueries is a subset of the sqlc.Queries interface that can be used +// to interact with the accounts table. +type SQLAccountQueries interface { + GetAccount(ctx context.Context, id int64) (sqlc.Account, error) + GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error) +} + +// SQLActionQueries is a subset of the sqlc.Queries interface that can be used +// to interact with action related tables. +// +//nolint:lll +type SQLActionQueries interface { + SQLSessionQueries + SQLAccountQueries + + InsertAction(ctx context.Context, arg sqlc.InsertActionParams) (int64, error) + SetActionState(ctx context.Context, arg sqlc.SetActionStateParams) error + ListActions(ctx context.Context, arg sqlc.ListActionsParams) ([]sqlc.Action, error) + CountActions(ctx context.Context, arg sqlc.ActionQueryParams) (int64, error) +} + +// sqlActionLocator helps us find an action in the SQL DB. +type sqlActionLocator struct { + // id is the DB level ID of the action. + id int64 +} + +func (s *sqlActionLocator) isActionLocator() {} + +// A compile-time check to ensure sqlActionLocator implements the ActionLocator +// interface. +var _ ActionLocator = (*sqlActionLocator)(nil) + +// GetActionsReadDB is a method on DB that constructs an ActionsReadDB. +// +// NOTE: This is part of the ActionDB interface. +func (s *SQLDB) GetActionsReadDB(groupID session.ID, + featureName string) ActionsReadDB { + + return &allActionsReadDB{ + db: s, + groupID: groupID, + featureName: featureName, + } +} + +// AddAction persists the given action to the database. +// +// NOTE: This is a part of the ActionDB interface. +func (s *SQLDB) AddAction(ctx context.Context, + req *AddActionReq) (ActionLocator, error) { + + var ( + writeTxOpts db.QueriesTxOptions + locator sqlActionLocator + + actor = sql.NullString{ + String: req.ActorName, + Valid: req.ActorName != "", + } + feature = sql.NullString{ + String: req.FeatureName, + Valid: req.FeatureName != "", + } + trigger = sql.NullString{ + String: req.Trigger, + Valid: req.Trigger != "", + } + intent = sql.NullString{ + String: req.Intent, + Valid: req.Intent != "", + } + ) + + err := s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + // Do best effort to see if this action is linked to a session, + // and/or an action or none. + var ( + sessionID sql.NullInt64 + accountID sql.NullInt64 + ) + + // First check session DB. + var sessErr error + req.SessionID.WhenSome(func(alias session.ID) { + sessID, err := db.GetSessionIDByAlias(ctx, alias[:]) + if errors.Is(err, sql.ErrNoRows) { + sessErr = session.ErrSessionNotFound + return + } else if err != nil { + sessErr = err + return + } + + sessionID = sqldb.SQLInt64(sessID) + }) + if sessErr != nil { + return sessErr + } + + // If an account ID was provided, then it must exist in our DB. + var getAcctErr error + req.AccountID.WhenSome(func(alias accounts.AccountID) { + aliasInt, err := alias.ToInt64() + if err != nil { + getAcctErr = err + return + } + + acctID, err := db.GetAccountIDByAlias(ctx, aliasInt) + if errors.Is(err, sql.ErrNoRows) { + getAcctErr = accounts.ErrAccNotFound + return + } else if err != nil { + getAcctErr = err + return + } + + accountID = sqldb.SQLInt64(acctID) + }) + if getAcctErr != nil { + return getAcctErr + } + + var macID []byte + req.MacaroonIdentifier.WhenSome(func(id [4]byte) { + macID = id[:] + }) + + id, err := db.InsertAction(ctx, sqlc.InsertActionParams{ + SessionID: sessionID, + AccountID: accountID, + ActorName: actor, + MacaroonIdentifier: macID, + FeatureName: feature, + ActionTrigger: trigger, + Intent: intent, + StructuredJsonData: []byte(req.StructuredJsonData), + RpcMethod: req.RPCMethod, + RpcParamsJson: req.RPCParamsJson, + CreatedAt: s.clock.Now().UTC(), + ActionState: int16(ActionStateInit), + }) + if err != nil { + return err + } + + locator = sqlActionLocator{ + id: id, + } + + return nil + }) + if err != nil { + return nil, err + } + + return &locator, nil +} + +// SetActionState finds the action specified by the ActionLocator and sets its +// state to the given state. +// +// NOTE: This is a part of the ActionDB interface. +func (s *SQLDB) SetActionState(ctx context.Context, al ActionLocator, + state ActionState, errReason string) error { + + if errReason != "" && state != ActionStateError { + return fmt.Errorf("error reason should only be set for " + + "ActionStateError") + } + + locator, ok := al.(*sqlActionLocator) + if !ok { + return fmt.Errorf("expected sqlActionLocator, got %T", al) + } + + var writeTxOpts db.QueriesTxOptions + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + return db.SetActionState(ctx, sqlc.SetActionStateParams{ + ID: locator.id, + ActionState: int16(state), + ErrorReason: sql.NullString{ + String: errReason, + Valid: errReason != "", + }, + }) + }) +} + +// ListActions returns a list of Actions. The query IndexOffset and MaxNum +// params can be used to control the number of actions returned. +// ListActionOptions may be used to filter on specific Action values. The return +// values are the list of actions, the last index and the total count (iff +// query.CountTotal is set). +// +// NOTE: This is part of the ActionDB interface. +func (s *SQLDB) ListActions(ctx context.Context, + query *ListActionsQuery, options ...ListActionOption) ([]*Action, + uint64, uint64, error) { + + opts := newListActionOptions() + for _, o := range options { + o(opts) + } + + var ( + readTxOpts = db.NewQueryReadTx() + actions []*Action + lastIndex uint64 + totalCount int64 + ) + err := s.db.ExecTx(ctx, &readTxOpts, func(db SQLQueries) error { + var ( + actorName = sql.NullString{ + String: opts.actorName, + Valid: opts.actorName != "", + } + feature = sql.NullString{ + String: opts.featureName, + Valid: opts.featureName != "", + } + rpcMethod = sql.NullString{ + String: opts.methodName, + Valid: opts.methodName != "", + } + actionState = sql.NullInt16{ + Int16: int16(opts.state), + Valid: opts.state != 0, + } + startTime = sql.NullTime{ + Time: opts.startTime, + Valid: !opts.startTime.IsZero(), + } + endTime = sql.NullTime{ + Time: opts.endTime, + Valid: !opts.endTime.IsZero(), + } + ) + + var sessionID sql.NullInt64 + if opts.sessionID != session.EmptyID { + sID, err := db.GetSessionIDByAlias( + ctx, opts.sessionID[:], + ) + if errors.Is(err, sql.ErrNoRows) { + return session.ErrSessionNotFound + } else if err != nil { + return fmt.Errorf("unable to get DB ID for "+ + "legacy session ID %x: %w", + opts.sessionID, err) + } + + sessionID = sqldb.SQLInt64(sID) + } + + var groupID sql.NullInt64 + if opts.groupID != session.EmptyID { + gID, err := db.GetSessionIDByAlias(ctx, opts.groupID[:]) + if errors.Is(err, sql.ErrNoRows) { + return session.ErrUnknownGroup + } else if err != nil { + return fmt.Errorf("unable to get DB ID for "+ + "legacy group ID %x: %w", opts.groupID, + err) + } + + groupID = sqldb.SQLInt64(gID) + } + + var ( + dbActions []sqlc.Action + err error + ) + actionQueryParams := sqlc.ActionQueryParams{ + SessionID: sessionID, + GroupID: groupID, + FeatureName: feature, + ActorName: actorName, + RpcMethod: rpcMethod, + State: actionState, + EndTime: endTime, + StartTime: startTime, + } + queryParams := sqlc.ListActionsParams{ + ActionQueryParams: actionQueryParams, + Reversed: false, + } + if query != nil { + queryParams.Reversed = query.Reversed + queryParams.Pagination = &sqlc.Pagination{ + NumLimit: func() int32 { + if query.MaxNum == 0 { + return int32(math.MaxInt32) + } + + return int32(query.MaxNum) + }(), + NumOffset: int32(query.IndexOffset), + } + } + + dbActions, err = db.ListActions(ctx, queryParams) + if err != nil { + return fmt.Errorf("unable to list actions: %w", err) + } + + // If pagination was used, then the number of results returned + // won't necessarily match the total number of actions that + // match the query. So, if pagination was used and the CountAll + // flag is set, then we need to count the total number of + // actions that match the query. + if query != nil && query.CountAll { + totalCount, err = db.CountActions( + ctx, actionQueryParams, + ) + if err != nil { + return fmt.Errorf("unable to count actions: %w", + err) + } + } + + actions = make([]*Action, len(dbActions)) + for i, dbAction := range dbActions { + action, err := unmarshalAction(ctx, db, dbAction) + if err != nil { + return fmt.Errorf("unable to unmarshal "+ + "action: %w", err) + } + + actions[i] = action + lastIndex = uint64(dbAction.ID) + } + + return nil + }) + + return actions, lastIndex, uint64(totalCount), err +} + +func unmarshalAction(ctx context.Context, db SQLActionQueries, + dbAction sqlc.Action) (*Action, error) { + + var legacySessID fn.Option[session.ID] + if dbAction.SessionID.Valid { + legacySessIDB, err := db.GetAliasBySessionID( + ctx, dbAction.SessionID.Int64, + ) + if err != nil { + return nil, fmt.Errorf("unable to get legacy "+ + "session ID for session ID %d: %w", + dbAction.SessionID.Int64, err) + } + + sessID, err := session.IDFromBytes(legacySessIDB) + if err != nil { + return nil, err + } + + legacySessID = fn.Some(sessID) + } + + var legacyAcctID fn.Option[accounts.AccountID] + if dbAction.AccountID.Valid { + acct, err := db.GetAccount(ctx, dbAction.AccountID.Int64) + if err != nil { + return nil, err + } + + acctID, err := accounts.AccountIDFromInt64(acct.Alias) + if err != nil { + return nil, fmt.Errorf("unable to get account ID: %w", + err) + } + + legacyAcctID = fn.Some(acctID) + } + + var macID fn.Option[[4]byte] + if len(dbAction.MacaroonIdentifier) > 0 { + macID = fn.Some([4]byte(dbAction.MacaroonIdentifier)) + } + + return &Action{ + AddActionReq: AddActionReq{ + MacaroonIdentifier: macID, + AccountID: legacyAcctID, + SessionID: legacySessID, + ActorName: dbAction.ActorName.String, + FeatureName: dbAction.FeatureName.String, + Trigger: dbAction.ActionTrigger.String, + Intent: dbAction.Intent.String, + StructuredJsonData: string(dbAction.StructuredJsonData), + RPCMethod: dbAction.RpcMethod, + RPCParamsJson: dbAction.RpcParamsJson, + }, + AttemptedAt: dbAction.CreatedAt, + State: ActionState(dbAction.ActionState), + ErrorReason: dbAction.ErrorReason.String, + }, nil +} diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index a355bede0..c27e53e96 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -24,18 +24,17 @@ func TestActionStorage(t *testing.T) { ctx := context.Background() clock := clock.NewTestClock(testTime1) - sessDB := session.NewTestDB(t, clock) accountsDB := accounts.NewTestDB(t, clock) + sessDB := session.NewTestDBWithAccounts(t, clock, accountsDB) - db, err := NewBoltDB(t.TempDir(), "test.db", sessDB, accountsDB, clock) - require.NoError(t, err) + db := NewTestDBWithSessionsAndAccounts(t, sessDB, accountsDB, clock) t.Cleanup(func() { _ = db.Close() }) // Assert that attempting to add an action for a session that does not // exist returns an error. - _, err = db.AddAction(ctx, &AddActionReq{ + _, err := db.AddAction(ctx, &AddActionReq{ SessionID: fn.Some(session.ID{1, 2, 3, 4}), }) require.ErrorIs(t, err, session.ErrSessionNotFound) @@ -194,13 +193,11 @@ func TestActionStorage(t *testing.T) { func TestListActions(t *testing.T) { t.Parallel() - tmpDir := t.TempDir() ctx := context.Background() clock := clock.NewDefaultClock() sessDB := session.NewTestDB(t, clock) - db, err := NewBoltDB(tmpDir, "test.db", sessDB, nil, clock) - require.NoError(t, err) + db := NewTestDBWithSessions(t, sessDB, clock) t.Cleanup(func() { _ = db.Close() }) @@ -468,8 +465,7 @@ func TestListGroupActions(t *testing.T) { State: ActionStateInit, } - db, err := NewBoltDB(t.TempDir(), "test.db", sessDB, nil, clock) - require.NoError(t, err) + db := NewTestDBWithSessions(t, sessDB, clock) t.Cleanup(func() { _ = db.Close() }) @@ -514,22 +510,3 @@ func TestListGroupActions(t *testing.T) { assertEqualActions(t, action2, al[0]) assertEqualActions(t, action1, al[1]) } - -func assertEqualActions(t *testing.T, expected, got *Action) { - // Accounts are not explicitly linked in our bbolt DB implementation. - got.AccountID = expected.AccountID - - expectedAttemptedAt := expected.AttemptedAt - actualAttemptedAt := got.AttemptedAt - - expected.AttemptedAt = time.Time{} - got.AttemptedAt = time.Time{} - - require.Equal(t, expected, got) - require.Equal(t, expectedAttemptedAt.Unix(), actualAttemptedAt.Unix()) - - expected.AttemptedAt = expectedAttemptedAt - got.AttemptedAt = actualAttemptedAt - - got.AccountID = fn.None[accounts.AccountID]() -} diff --git a/firewalldb/db.go b/firewalldb/db.go index 8b913b69d..b8d9ed06f 100644 --- a/firewalldb/db.go +++ b/firewalldb/db.go @@ -19,6 +19,7 @@ var ( type firewallDBs interface { RulesDB PrivacyMapper + ActionDB } // DB manages the firewall rules database. diff --git a/firewalldb/kvstores_sql.go b/firewalldb/kvstores_sql.go index e7e1e7da5..0c3df2ddb 100644 --- a/firewalldb/kvstores_sql.go +++ b/firewalldb/kvstores_sql.go @@ -13,12 +13,6 @@ import ( "github.com/lightningnetwork/lnd/fn" ) -// SQLSessionQueries is a subset of the sqlc.Queries interface that can be used -// to interact with the session table. -type SQLSessionQueries interface { - GetSessionIDByAlias(ctx context.Context, legacyID []byte) (int64, error) -} - // SQLKVStoreQueries is a subset of the sqlc.Queries interface that can be // used to interact with the kvstore tables. // diff --git a/firewalldb/sql_store.go b/firewalldb/sql_store.go index 369920d63..f17010f2c 100644 --- a/firewalldb/sql_store.go +++ b/firewalldb/sql_store.go @@ -8,11 +8,19 @@ import ( "github.com/lightningnetwork/lnd/clock" ) +// SQLSessionQueries is a subset of the sqlc.Queries interface that can be used +// to interact with the session table. +type SQLSessionQueries interface { + GetSessionIDByAlias(ctx context.Context, legacyID []byte) (int64, error) + GetAliasBySessionID(ctx context.Context, id int64) ([]byte, error) +} + // SQLQueries is a subset of the sqlc.Queries interface that can be used to // interact with various firewalldb tables. type SQLQueries interface { SQLKVStoreQueries SQLPrivacyPairQueries + SQLActionQueries } // BatchedSQLQueries is a version of the SQLQueries that's capable of batched @@ -39,6 +47,10 @@ type SQLDB struct { // interface. var _ RulesDB = (*SQLDB)(nil) +// A compile-time assertion to ensure that SQLDB implements the ActionsDB +// interface. +var _ ActionDB = (*SQLDB)(nil) + // NewSQLDB creates a new SQLStore instance given an open SQLQueries // storage backend. func NewSQLDB(sqlDB *db.BaseDB, clock clock.Clock) *SQLDB { diff --git a/firewalldb/test_kvdb.go b/firewalldb/test_kvdb.go index 659292702..6f7a49aa3 100644 --- a/firewalldb/test_kvdb.go +++ b/firewalldb/test_kvdb.go @@ -5,7 +5,9 @@ package firewalldb import ( "testing" + "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" "github.com/stretchr/testify/require" ) @@ -28,6 +30,16 @@ func NewTestDBWithSessions(t *testing.T, sessStore SessionDB, return newDBFromPathWithSessions(t, t.TempDir(), sessStore, nil, clock) } +// NewTestDBWithSessionsAndAccounts creates a new test BoltDB Store with access +// to an existing sessions DB and accounts DB. +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessStore SessionDB, + acctStore AccountsDB, clock clock.Clock) *BoltDB { + + return newDBFromPathWithSessions( + t, t.TempDir(), sessStore, acctStore, clock, + ) +} + func newDBFromPathWithSessions(t *testing.T, dbPath string, sessStore SessionDB, acctStore AccountsDB, clock clock.Clock) *BoltDB { @@ -40,3 +52,11 @@ func newDBFromPathWithSessions(t *testing.T, dbPath string, return store } + +func assertEqualActions(t *testing.T, expected, got *Action) { + // Accounts are not explicitly linked in our bbolt DB implementation. + got.AccountID = expected.AccountID + require.Equal(t, expected, got) + + got.AccountID = fn.None[accounts.AccountID]() +} diff --git a/firewalldb/test_sql.go b/firewalldb/test_sql.go index 947ff1491..03dcfbebf 100644 --- a/firewalldb/test_sql.go +++ b/firewalldb/test_sql.go @@ -4,7 +4,9 @@ package firewalldb import ( "testing" + "time" + "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/session" "github.com/lightningnetwork/lnd/clock" "github.com/stretchr/testify/require" @@ -20,3 +22,33 @@ func NewTestDBWithSessions(t *testing.T, sessionStore session.Store, return NewSQLDB(sessions.BaseDB, clock) } + +// NewTestDBWithSessionsAndAccounts creates a new test SQLDB Store with access +// to an existing sessions DB and accounts DB. +func NewTestDBWithSessionsAndAccounts(t *testing.T, sessionStore SessionDB, + acctStore AccountsDB, clock clock.Clock) *SQLDB { + + sessions, ok := sessionStore.(*session.SQLStore) + require.True(t, ok) + + accounts, ok := acctStore.(*accounts.SQLStore) + require.True(t, ok) + + require.Equal(t, accounts.BaseDB, sessions.BaseDB) + + return NewSQLDB(sessions.BaseDB, clock) +} + +func assertEqualActions(t *testing.T, expected, got *Action) { + expectedAttemptedAt := expected.AttemptedAt + actualAttemptedAt := got.AttemptedAt + + expected.AttemptedAt = time.Time{} + got.AttemptedAt = time.Time{} + + require.Equal(t, expected, got) + require.Equal(t, expectedAttemptedAt.Unix(), actualAttemptedAt.Unix()) + + expected.AttemptedAt = expectedAttemptedAt + got.AttemptedAt = actualAttemptedAt +} diff --git a/terminal.go b/terminal.go index e08e9a25f..7e4d552c7 100644 --- a/terminal.go +++ b/terminal.go @@ -239,8 +239,7 @@ type stores struct { accounts accounts.Store sessions session.Store - firewall *firewalldb.DB - firewallBolt *firewalldb.BoltDB + firewall *firewalldb.DB // closeFns holds various callbacks that can be used to close any open // stores in the stores struct. @@ -531,7 +530,7 @@ func (g *LightningTerminal) start(ctx context.Context) error { superMacBaker: superMacBaker, firstConnectionDeadline: g.cfg.FirstLNCConnDeadline, permMgr: g.permsMgr, - actionsDB: g.stores.firewallBolt, + actionsDB: g.stores.firewall, autopilot: g.autopilotClient, ruleMgrs: g.ruleMgrs, privMap: g.stores.firewall, @@ -1093,7 +1092,7 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, } requestLogger, err := firewall.NewRequestLogger( - g.cfg.Firewall.RequestLogger, g.stores.firewallBolt, + g.cfg.Firewall.RequestLogger, g.stores.firewall, ) if err != nil { return fmt.Errorf("error creating new request logger") @@ -1112,7 +1111,7 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, if !g.cfg.Autopilot.Disable { ruleEnforcer := firewall.NewRuleEnforcer( - g.stores.firewall, g.stores.firewallBolt, + g.stores.firewall, g.stores.firewall, g.stores.sessions, g.autopilotClient.ListFeaturePerms, g.permsMgr, g.lndClient.NodePubkey,