From ec3b38f6075dcf714732935a261c225aa2edba0d Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 19 Apr 2025 14:24:32 +0200 Subject: [PATCH 1/4] firewalldb: put Action DB methods behind an interface So that we can easily add a different implementation and swop them out later. --- firewalldb/actions.go | 2 +- firewalldb/interface.go | 24 ++++++++++++++++++++++++ session_rpcserver.go | 2 +- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/firewalldb/actions.go b/firewalldb/actions.go index 493220e87..825a544a1 100644 --- a/firewalldb/actions.go +++ b/firewalldb/actions.go @@ -230,7 +230,7 @@ func (db *BoltDB) GetActionsReadDB(groupID session.ID, // allActionsReadDb is an implementation of the ActionsReadDB. type allActionsReadDB struct { - db *BoltDB + db ActionDB groupID session.ID featureName string } diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 401b3b8d6..9ef4eb0e8 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -100,3 +100,27 @@ type PrivacyMapper interface { // given group ID key. PrivacyDB(groupID session.ID) PrivacyMapDB } + +// ActionDB is an interface that abstracts the database operations needed for +// the Action persistence and querying. +type ActionDB interface { + // AddAction persists the given action to the database. + AddAction(action *Action) (uint64, error) + + // SetActionState finds the action specified by the ActionLocator and + // sets its state to the given state. + SetActionState(al *ActionLocator, state ActionState, + errReason string) error + + // ListActions returns a list of Actions that pass the filterFn + // requirements. The query IndexOffset and MaxNum params can be used to + // control the number of actions returned. The return values are the + // list of actions, the last index and the total count (iff + // query.CountTotal is set). + ListActions(ctx context.Context, query *ListActionsQuery, + options ...ListActionOption) ([]*Action, uint64, uint64, error) + + // GetActionsReadDB produces an ActionReadDB using the given group ID + // and feature name. + GetActionsReadDB(groupID session.ID, featureName string) ActionsReadDB +} diff --git a/session_rpcserver.go b/session_rpcserver.go index 34eb4aa7e..40d649d1d 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -63,7 +63,7 @@ type sessionRpcServerConfig struct { superMacBaker litmac.Baker firstConnectionDeadline time.Duration permMgr *perms.Manager - actionsDB *firewalldb.BoltDB + actionsDB firewalldb.ActionDB autopilot autopilotserver.Autopilot ruleMgrs rules.ManagerSet privMap firewalldb.PrivacyMapper From c1ee88462672da93a55b7f559dc59fee3fe33689 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 19 Apr 2025 14:28:33 +0200 Subject: [PATCH 2/4] multi: let most ActionDB methods take a context --- firewall/request_logger.go | 14 +++++++------- firewall/rule_enforcer.go | 12 ++++++++---- firewalldb/actions.go | 6 +++--- firewalldb/actions_kvdb.go | 6 +++--- firewalldb/actions_test.go | 18 +++++++++--------- firewalldb/interface.go | 6 +++--- terminal.go | 6 ++++-- 7 files changed, 37 insertions(+), 31 deletions(-) diff --git a/firewall/request_logger.go b/firewall/request_logger.go index 7c7abb043..72ed212d8 100644 --- a/firewall/request_logger.go +++ b/firewall/request_logger.go @@ -128,7 +128,7 @@ func (r *RequestLogger) CustomCaveatName() string { // Intercept processes an RPC middleware interception request and returns the // interception result which either accepts or rejects the intercepted message. -func (r *RequestLogger) Intercept(_ context.Context, +func (r *RequestLogger) Intercept(ctx context.Context, req *lnrpc.RPCMiddlewareRequest) (*lnrpc.RPCMiddlewareResponse, error) { ri, err := NewInfoFromRequest(req) @@ -156,7 +156,7 @@ func (r *RequestLogger) Intercept(_ context.Context, // Parse incoming requests and act on them. case MWRequestTypeRequest: - return mid.RPCErr(req, r.addNewAction(ri, withPayloadData)) + return mid.RPCErr(req, r.addNewAction(ctx, ri, withPayloadData)) // Parse and possibly manipulate outgoing responses. case MWRequestTypeResponse: @@ -170,7 +170,7 @@ func (r *RequestLogger) Intercept(_ context.Context, } return mid.RPCErr( - req, r.MarkAction(ri.RequestID, state, errReason), + req, r.MarkAction(ctx, ri.RequestID, state, errReason), ) default: @@ -179,7 +179,7 @@ func (r *RequestLogger) Intercept(_ context.Context, } // addNewAction persists the new action to the db. -func (r *RequestLogger) addNewAction(ri *RequestInfo, +func (r *RequestLogger) addNewAction(ctx context.Context, ri *RequestInfo, withPayloadData bool) error { // If no macaroon is provided, then an empty 4-byte array is used as the @@ -223,7 +223,7 @@ func (r *RequestLogger) addNewAction(ri *RequestInfo, } } - id, err := r.actionsDB.AddAction(action) + id, err := r.actionsDB.AddAction(ctx, action) if err != nil { return err } @@ -240,7 +240,7 @@ func (r *RequestLogger) addNewAction(ri *RequestInfo, // MarkAction can be used to set the state of an action identified by the given // requestID. -func (r *RequestLogger) MarkAction(reqID uint64, +func (r *RequestLogger) MarkAction(ctx context.Context, reqID uint64, state firewalldb.ActionState, errReason string) error { r.mu.Lock() @@ -252,5 +252,5 @@ func (r *RequestLogger) MarkAction(reqID uint64, } delete(r.reqIDToAction, reqID) - return r.actionsDB.SetActionState(actionLocator, state, errReason) + return r.actionsDB.SetActionState(ctx, actionLocator, state, errReason) } diff --git a/firewall/rule_enforcer.go b/firewall/rule_enforcer.go index 472143f05..35f92c534 100644 --- a/firewall/rule_enforcer.go +++ b/firewall/rule_enforcer.go @@ -32,8 +32,9 @@ type RuleEnforcer struct { ruleDB firewalldb.RulesDB actionsDB firewalldb.ActionReadDBGetter sessionDB firewalldb.SessionDB - markActionErrored func(reqID uint64, reason string) error - privMapDB firewalldb.PrivacyMapper + markActionErrored func(ctx context.Context, reqID uint64, + reason string) error + privMapDB firewalldb.PrivacyMapper permsMgr *perms.Manager getFeaturePerms featurePerms @@ -63,7 +64,8 @@ func NewRuleEnforcer(ruleDB firewalldb.RulesDB, routerClient lndclient.RouterClient, lndClient lndclient.LightningClient, lndConnID string, ruleMgrs rules.ManagerSet, - markActionErrored func(reqID uint64, reason string) error, + markActionErrored func(ctx context.Context, reqID uint64, + reason string) error, privMap firewalldb.PrivacyMapper) *RuleEnforcer { return &RuleEnforcer{ @@ -164,7 +166,9 @@ func (r *RuleEnforcer) Intercept(ctx context.Context, replacement, err := r.handleRequest(ctx, ri) if err != nil { - dbErr := r.markActionErrored(ri.RequestID, err.Error()) + dbErr := r.markActionErrored( + ctx, ri.RequestID, err.Error(), + ) if dbErr != nil { log.Error("could not mark action for "+ "request ID %d as Errored: %v", diff --git a/firewalldb/actions.go b/firewalldb/actions.go index 825a544a1..4743c85f5 100644 --- a/firewalldb/actions.go +++ b/firewalldb/actions.go @@ -181,9 +181,9 @@ func WithActionState(state ActionState) ListActionOption { // ActionsWriteDB is an abstraction over the Actions DB that will allow a // caller to add new actions as well as change the values of an existing action. type ActionsWriteDB interface { - AddAction(action *Action) (uint64, error) - SetActionState(al *ActionLocator, state ActionState, - errReason string) error + AddAction(ctx context.Context, action *Action) (uint64, error) + SetActionState(ctx context.Context, al *ActionLocator, + state ActionState, errReason string) error } // RuleAction represents a method call that was performed at a certain time at diff --git a/firewalldb/actions_kvdb.go b/firewalldb/actions_kvdb.go index 73fd8a61b..4b305c5f5 100644 --- a/firewalldb/actions_kvdb.go +++ b/firewalldb/actions_kvdb.go @@ -53,7 +53,7 @@ var ( ) // AddAction serialises and adds an Action to the DB under the given sessionID. -func (db *BoltDB) AddAction(action *Action) (uint64, error) { +func (db *BoltDB) AddAction(_ context.Context, action *Action) (uint64, error) { var buf bytes.Buffer if err := SerializeAction(&buf, action); err != nil { return 0, err @@ -167,8 +167,8 @@ func getAction(actionsBkt *bbolt.Bucket, al *ActionLocator) (*Action, error) { // SetActionState finds the action specified by the ActionLocator and sets its // state to the given state. -func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState, - errorReason string) error { +func (db *BoltDB) SetActionState(_ context.Context, al *ActionLocator, + state ActionState, errorReason string) error { if errorReason != "" && state != ActionStateError { return fmt.Errorf("error reason should only be set for " + diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index f09e9f640..596d9a76c 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -66,11 +66,11 @@ func TestActionStorage(t *testing.T) { require.NoError(t, err) require.Len(t, actions, 0) - id, err := db.AddAction(action1) + id, err := db.AddAction(ctx, action1) require.NoError(t, err) require.Equal(t, uint64(1), id) - id, err = db.AddAction(action2) + id, err = db.AddAction(ctx, action2) require.NoError(t, err) require.Equal(t, uint64(1), id) @@ -92,7 +92,7 @@ func TestActionStorage(t *testing.T) { require.Len(t, actions, 0) err = db.SetActionState( - &ActionLocator{ + ctx, &ActionLocator{ SessionID: sessionID2, ActionID: uint64(1), }, ActionStateDone, "", @@ -109,7 +109,7 @@ func TestActionStorage(t *testing.T) { action2.State = ActionStateDone require.Equal(t, action2, actions[0]) - id, err = db.AddAction(action1) + id, err = db.AddAction(ctx, action1) require.NoError(t, err) require.Equal(t, uint64(2), id) @@ -125,7 +125,7 @@ func TestActionStorage(t *testing.T) { // Try set an error reason for a non Errored state. err = db.SetActionState( - &ActionLocator{ + ctx, &ActionLocator{ SessionID: sessionID2, ActionID: uint64(1), }, ActionStateDone, "hello", @@ -134,7 +134,7 @@ func TestActionStorage(t *testing.T) { // Now try move the action to errored with a reason. err = db.SetActionState( - &ActionLocator{ + ctx, &ActionLocator{ SessionID: sessionID2, ActionID: uint64(1), }, ActionStateError, "fail whale", @@ -184,7 +184,7 @@ func TestListActions(t *testing.T) { State: ActionStateDone, } - _, err := db.AddAction(action) + _, err := db.AddAction(ctx, action) require.NoError(t, err) } @@ -373,7 +373,7 @@ func TestListGroupActions(t *testing.T) { require.Empty(t, al) // Add an action under session 1. - _, err = db.AddAction(action1) + _, err = db.AddAction(ctx, action1) require.NoError(t, err) // There should now be one action in the group. @@ -383,7 +383,7 @@ func TestListGroupActions(t *testing.T) { require.Equal(t, sessionID1, al[0].SessionID) // Add an action under session 2. - _, err = db.AddAction(action2) + _, err = db.AddAction(ctx, action2) require.NoError(t, err) // There should now be actions in the group. diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 9ef4eb0e8..bb057a69a 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -105,12 +105,12 @@ type PrivacyMapper interface { // the Action persistence and querying. type ActionDB interface { // AddAction persists the given action to the database. - AddAction(action *Action) (uint64, error) + AddAction(ctx context.Context, action *Action) (uint64, error) // SetActionState finds the action specified by the ActionLocator and // sets its state to the given state. - SetActionState(al *ActionLocator, state ActionState, - errReason string) error + SetActionState(ctx context.Context, al *ActionLocator, + state ActionState, errReason string) error // ListActions returns a list of Actions that pass the filterFn // requirements. The query IndexOffset and MaxNum params can be used to diff --git a/terminal.go b/terminal.go index 9b35e0a66..9f9b6ef49 100644 --- a/terminal.go +++ b/terminal.go @@ -1118,9 +1118,11 @@ func (g *LightningTerminal) startInternalSubServers(ctx context.Context, g.permsMgr, g.lndClient.NodePubkey, g.lndClient.Router, g.lndClient.Client, g.lndConnID, g.ruleMgrs, - func(reqID uint64, reason string) error { + func(ctx context.Context, reqID uint64, + reason string) error { + return requestLogger.MarkAction( - reqID, firewalldb.ActionStateError, + ctx, reqID, firewalldb.ActionStateError, reason, ) }, g.stores.firewall, From 2cb10d28b1c5d1c205cb0c969de9579a27538fd9 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 19 Apr 2025 14:48:54 +0200 Subject: [PATCH 3/4] firewalldb: abstract ActionLocator The current ActionLocator is very specific to how actions are stored in the bbolt db. In our SQL implementation, we will simply have an auto-incrementing int64 that we will use as our locator for any action. In preparation for this, we make ActionLocator an abstract interface and implement our bbolt version of it. --- firewall/request_logger.go | 11 ++---- firewalldb/actions.go | 9 ++--- firewalldb/actions_kvdb.go | 76 ++++++++++++++++++++++++-------------- firewalldb/actions_test.go | 30 +++------------ firewalldb/interface.go | 4 +- 5 files changed, 64 insertions(+), 66 deletions(-) diff --git a/firewall/request_logger.go b/firewall/request_logger.go index 72ed212d8..ad602c48d 100644 --- a/firewall/request_logger.go +++ b/firewall/request_logger.go @@ -53,7 +53,7 @@ type RequestLogger struct { // be used to find the corresponding action. This is used so that // requests and responses can be easily linked. The mu mutex must be // used when accessing this map. - reqIDToAction map[uint64]*firewalldb.ActionLocator + reqIDToAction map[uint64]firewalldb.ActionLocator mu sync.Mutex } @@ -105,7 +105,7 @@ func NewRequestLogger(cfg *RequestLoggerConfig, return &RequestLogger{ shouldLogAction: shouldLogAction, actionsDB: actionsDB, - reqIDToAction: make(map[uint64]*firewalldb.ActionLocator), + reqIDToAction: make(map[uint64]firewalldb.ActionLocator), }, nil } @@ -223,16 +223,13 @@ func (r *RequestLogger) addNewAction(ctx context.Context, ri *RequestInfo, } } - id, err := r.actionsDB.AddAction(ctx, action) + locator, err := r.actionsDB.AddAction(ctx, action) if err != nil { return err } r.mu.Lock() - r.reqIDToAction[ri.RequestID] = &firewalldb.ActionLocator{ - SessionID: sessionID, - ActionID: id, - } + r.reqIDToAction[ri.RequestID] = locator r.mu.Unlock() return nil diff --git a/firewalldb/actions.go b/firewalldb/actions.go index 4743c85f5..8f81b02c8 100644 --- a/firewalldb/actions.go +++ b/firewalldb/actions.go @@ -181,8 +181,8 @@ func WithActionState(state ActionState) ListActionOption { // ActionsWriteDB is an abstraction over the Actions DB that will allow a // caller to add new actions as well as change the values of an existing action. type ActionsWriteDB interface { - AddAction(ctx context.Context, action *Action) (uint64, error) - SetActionState(ctx context.Context, al *ActionLocator, + AddAction(ctx context.Context, action *Action) (ActionLocator, error) + SetActionState(ctx context.Context, al ActionLocator, state ActionState, errReason string) error } @@ -318,7 +318,6 @@ func actionToRulesAction(a *Action) *RuleAction { } // ActionLocator helps us find an action in the database. -type ActionLocator struct { - SessionID session.ID - ActionID uint64 +type ActionLocator interface { + isActionLocator() } diff --git a/firewalldb/actions_kvdb.go b/firewalldb/actions_kvdb.go index 4b305c5f5..5f5615953 100644 --- a/firewalldb/actions_kvdb.go +++ b/firewalldb/actions_kvdb.go @@ -53,13 +53,15 @@ var ( ) // AddAction serialises and adds an Action to the DB under the given sessionID. -func (db *BoltDB) AddAction(_ context.Context, action *Action) (uint64, error) { +func (db *BoltDB) AddAction(_ context.Context, action *Action) (ActionLocator, + error) { + var buf bytes.Buffer if err := SerializeAction(&buf, action); err != nil { - return 0, err + return nil, err } - var id uint64 + var locator kvdbActionLocator err := db.DB.Update(func(tx *bbolt.Tx) error { mainActionsBucket, err := getBucket(tx, actionsBucketKey) if err != nil { @@ -82,7 +84,6 @@ func (db *BoltDB) AddAction(_ context.Context, action *Action) (uint64, error) { if err != nil { return err } - id = nextActionIndex var actionIndex [8]byte byteOrder.PutUint64(actionIndex[:], nextActionIndex) @@ -101,9 +102,9 @@ func (db *BoltDB) AddAction(_ context.Context, action *Action) (uint64, error) { return err } - locator := ActionLocator{ - SessionID: action.SessionID, - ActionID: nextActionIndex, + locator = kvdbActionLocator{ + sessionID: action.SessionID, + actionID: nextActionIndex, } var buf bytes.Buffer @@ -117,13 +118,25 @@ func (db *BoltDB) AddAction(_ context.Context, action *Action) (uint64, error) { return actionsIndexBucket.Put(seqNoBytes[:], buf.Bytes()) }) if err != nil { - return 0, err + return nil, err } - return id, nil + return &locator, nil +} + +// kvdbActionLocator helps us find an action in a KVDB database. +type kvdbActionLocator struct { + sessionID session.ID + actionID uint64 } -func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error { +// A compile-time check to ensure kvdbActionLocator implements the ActionLocator +// interface. +var _ ActionLocator = (*kvdbActionLocator)(nil) + +func (al *kvdbActionLocator) isActionLocator() {} + +func putAction(tx *bbolt.Tx, al *kvdbActionLocator, a *Action) error { var buf bytes.Buffer if err := SerializeAction(&buf, a); err != nil { return err @@ -139,35 +152,37 @@ func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error { return ErrNoSuchKeyFound } - sessBucket := actionsBucket.Bucket(al.SessionID[:]) + sessBucket := actionsBucket.Bucket(al.sessionID[:]) if sessBucket == nil { return fmt.Errorf("session bucket for session ID %x does not "+ - "exist", al.SessionID) + "exist", al.sessionID) } var id [8]byte - binary.BigEndian.PutUint64(id[:], al.ActionID) + binary.BigEndian.PutUint64(id[:], al.actionID) return sessBucket.Put(id[:], buf.Bytes()) } -func getAction(actionsBkt *bbolt.Bucket, al *ActionLocator) (*Action, error) { - sessBucket := actionsBkt.Bucket(al.SessionID[:]) +func getAction(actionsBkt *bbolt.Bucket, al *kvdbActionLocator) (*Action, + error) { + + sessBucket := actionsBkt.Bucket(al.sessionID[:]) if sessBucket == nil { return nil, fmt.Errorf("session bucket for session ID "+ - "%x does not exist", al.SessionID) + "%x does not exist", al.sessionID) } var id [8]byte - binary.BigEndian.PutUint64(id[:], al.ActionID) + binary.BigEndian.PutUint64(id[:], al.actionID) actionBytes := sessBucket.Get(id[:]) - return DeserializeAction(bytes.NewReader(actionBytes), al.SessionID) + return DeserializeAction(bytes.NewReader(actionBytes), al.sessionID) } // SetActionState finds the action specified by the ActionLocator and sets its // state to the given state. -func (db *BoltDB) SetActionState(_ context.Context, al *ActionLocator, +func (db *BoltDB) SetActionState(_ context.Context, al ActionLocator, state ActionState, errorReason string) error { if errorReason != "" && state != ActionStateError { @@ -175,6 +190,11 @@ func (db *BoltDB) SetActionState(_ context.Context, al *ActionLocator, "ActionStateError") } + locator, ok := al.(*kvdbActionLocator) + if !ok { + return fmt.Errorf("expected kvdbActionLocator, got %T", al) + } + return db.DB.Update(func(tx *bbolt.Tx) error { mainActionsBucket, err := getBucket(tx, actionsBucketKey) if err != nil { @@ -186,7 +206,7 @@ func (db *BoltDB) SetActionState(_ context.Context, al *ActionLocator, return ErrNoSuchKeyFound } - action, err := getAction(actionsBucket, al) + action, err := getAction(actionsBucket, locator) if err != nil { return err } @@ -194,7 +214,7 @@ func (db *BoltDB) SetActionState(_ context.Context, al *ActionLocator, action.State = state action.ErrorReason = errorReason - return putAction(tx, al, action) + return putAction(tx, locator, action) }) } @@ -540,14 +560,14 @@ func DeserializeAction(r io.Reader, sessionID session.ID) (*Action, error) { // serializeActionLocator binary serializes the given ActionLocator to the // writer using the tlv format. -func serializeActionLocator(w io.Writer, al *ActionLocator) error { +func serializeActionLocator(w io.Writer, al *kvdbActionLocator) error { if al == nil { return fmt.Errorf("action locator cannot be nil") } var ( - sessionID = al.SessionID[:] - actionID = al.ActionID + sessionID = al.sessionID[:] + actionID = al.actionID ) tlvRecords := []tlv.Record{ @@ -565,7 +585,7 @@ func serializeActionLocator(w io.Writer, al *ActionLocator) error { // deserializeActionLocator deserializes an ActionLocator from the given reader, // expecting the data to be encoded in the tlv format. -func deserializeActionLocator(r io.Reader) (*ActionLocator, error) { +func deserializeActionLocator(r io.Reader) (*kvdbActionLocator, error) { var ( sessionID []byte actionID uint64 @@ -588,8 +608,8 @@ func deserializeActionLocator(r io.Reader) (*ActionLocator, error) { return nil, err } - return &ActionLocator{ - SessionID: id, - ActionID: actionID, + return &kvdbActionLocator{ + sessionID: id, + actionID: actionID, }, nil } diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index 596d9a76c..541e37e95 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -66,13 +66,11 @@ func TestActionStorage(t *testing.T) { require.NoError(t, err) require.Len(t, actions, 0) - id, err := db.AddAction(ctx, action1) + _, err = db.AddAction(ctx, action1) require.NoError(t, err) - require.Equal(t, uint64(1), id) - id, err = db.AddAction(ctx, action2) + locator2, err := db.AddAction(ctx, action2) require.NoError(t, err) - require.Equal(t, uint64(1), id) actions, _, _, err = db.ListActions( ctx, nil, @@ -91,12 +89,7 @@ func TestActionStorage(t *testing.T) { require.NoError(t, err) require.Len(t, actions, 0) - err = db.SetActionState( - ctx, &ActionLocator{ - SessionID: sessionID2, - ActionID: uint64(1), - }, ActionStateDone, "", - ) + err = db.SetActionState(ctx, locator2, ActionStateDone, "") require.NoError(t, err) actions, _, _, err = db.ListActions( @@ -109,9 +102,8 @@ func TestActionStorage(t *testing.T) { action2.State = ActionStateDone require.Equal(t, action2, actions[0]) - id, err = db.AddAction(ctx, action1) + _, err = db.AddAction(ctx, action1) require.NoError(t, err) - require.Equal(t, uint64(2), id) // Check that providing no session id and no filter function returns // all the actions. @@ -124,21 +116,11 @@ func TestActionStorage(t *testing.T) { require.Len(t, actions, 3) // Try set an error reason for a non Errored state. - err = db.SetActionState( - ctx, &ActionLocator{ - SessionID: sessionID2, - ActionID: uint64(1), - }, ActionStateDone, "hello", - ) + err = db.SetActionState(ctx, locator2, ActionStateDone, "hello") require.Error(t, err) // Now try move the action to errored with a reason. - err = db.SetActionState( - ctx, &ActionLocator{ - SessionID: sessionID2, - ActionID: uint64(1), - }, ActionStateError, "fail whale", - ) + err = db.SetActionState(ctx, locator2, ActionStateError, "fail whale") require.NoError(t, err) actions, _, _, err = db.ListActions( diff --git a/firewalldb/interface.go b/firewalldb/interface.go index bb057a69a..a024ce8eb 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -105,11 +105,11 @@ type PrivacyMapper interface { // the Action persistence and querying. type ActionDB interface { // AddAction persists the given action to the database. - AddAction(ctx context.Context, action *Action) (uint64, error) + AddAction(ctx context.Context, action *Action) (ActionLocator, error) // SetActionState finds the action specified by the ActionLocator and // sets its state to the given state. - SetActionState(ctx context.Context, al *ActionLocator, + SetActionState(ctx context.Context, al ActionLocator, state ActionState, errReason string) error // ListActions returns a list of Actions that pass the filterFn From 53fc1d268d029d696087bff5bbde327d599bbf85 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 27 Dec 2024 14:43:16 +0200 Subject: [PATCH 4/4] firewalldb: assert actions equal helper Use a helper to compare Actions and prepare the timestamp comparisons so that they are ready for our SQL implementation of the Actions DB. --- firewalldb/actions_test.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index 541e37e95..26f33596d 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -79,7 +79,7 @@ func TestActionStorage(t *testing.T) { ) require.NoError(t, err) require.Len(t, actions, 1) - require.Equal(t, action1, actions[0]) + assertEqualActions(t, action1, actions[0]) actions, _, _, err = db.ListActions( ctx, nil, @@ -100,7 +100,7 @@ func TestActionStorage(t *testing.T) { require.NoError(t, err) require.Len(t, actions, 1) action2.State = ActionStateDone - require.Equal(t, action2, actions[0]) + assertEqualActions(t, action2, actions[0]) _, err = db.AddAction(ctx, action1) require.NoError(t, err) @@ -132,7 +132,7 @@ func TestActionStorage(t *testing.T) { require.Len(t, actions, 1) action2.State = ActionStateError action2.ErrorReason = "fail whale" - require.Equal(t, action2, actions[0]) + assertEqualActions(t, action2, actions[0]) } // TestListActions tests some ListAction options. @@ -375,3 +375,17 @@ func TestListGroupActions(t *testing.T) { require.Equal(t, sessionID1, al[0].SessionID) require.Equal(t, sessionID2, al[1].SessionID) } + +func assertEqualActions(t *testing.T, expected, got *Action) { + expectedAttemptedAt := expected.AttemptedAt + actualAttemptedAt := got.AttemptedAt + + expected.AttemptedAt = time.Time{} + got.AttemptedAt = time.Time{} + + require.Equal(t, expected, got) + require.Equal(t, expectedAttemptedAt.Unix(), actualAttemptedAt.Unix()) + + expected.AttemptedAt = expectedAttemptedAt + got.AttemptedAt = actualAttemptedAt +}