From 4f8c3aec16445bdb381bad0c8536aed411a8b135 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Sat, 19 Apr 2025 14:21:49 +0200 Subject: [PATCH] firewalldb+rpcserver: refactor ListActions Here we move the filter logic behind the interface so that our sql implementation can make use of indexes. --- firewalldb/action_paginator.go | 4 +- firewalldb/actions.go | 107 +++++++++++++++++++++++--- firewalldb/actions_kvdb.go | 113 ++++++++++++++++++++++----- firewalldb/actions_test.go | 68 +++++++++-------- session/interface.go | 3 + session_rpcserver.go | 136 +++++++++++++-------------------- 6 files changed, 284 insertions(+), 147 deletions(-) diff --git a/firewalldb/action_paginator.go b/firewalldb/action_paginator.go index 5e492f55e..d2bd3d2f6 100644 --- a/firewalldb/action_paginator.go +++ b/firewalldb/action_paginator.go @@ -16,7 +16,7 @@ type actionPaginator struct { // filterFn is the filter function which we are using to determine which // actions should be included in the return list. - filterFn ListActionsFilterFn + filterFn listActionsFilterFn // readAction is a closure which we use to read an action from the db // given a key value pair. @@ -32,7 +32,7 @@ type actionPaginator struct { // cfg.CountAll is set). func paginateActions(cfg *ListActionsQuery, c kvdb.RCursor, readAction func(k, v []byte) (*Action, error), - filterFn ListActionsFilterFn) ([]*Action, uint64, uint64, error) { + filterFn listActionsFilterFn) ([]*Action, uint64, uint64, error) { if cfg == nil { cfg = &ListActionsQuery{} diff --git a/firewalldb/actions.go b/firewalldb/actions.go index 57e43e0d6..493220e87 100644 --- a/firewalldb/actions.go +++ b/firewalldb/actions.go @@ -72,7 +72,7 @@ type Action struct { } // ListActionsQuery can be used to tweak the query to ListActions and -// ListSessionActions. +// listSessionActions. type ListActionsQuery struct { // IndexOffset is index of the action to inspect. IndexOffset uint64 @@ -91,6 +91,93 @@ type ListActionsQuery struct { CountAll bool } +// listActionsOptions holds the options that can be used to filter the actions +// that are returned by the ListActions method. +type listActionOptions struct { + sessionID session.ID + groupID session.ID + featureName string + actorName string + methodName string + state ActionState + endTime time.Time + startTime time.Time +} + +// newListActionOptions creates a new listActionOptions instance with default +// query values. +func newListActionOptions() *listActionOptions { + return &listActionOptions{} +} + +// ListActionOption is a functional option that can be used to tweak the +// behaviour of the ListActions method. +type ListActionOption func(*listActionOptions) + +// WithActionSessionID is a ListActionOption that can be used to select all +// Actions performed under the given session ID. +func WithActionSessionID(sessionID session.ID) ListActionOption { + return func(o *listActionOptions) { + o.sessionID = sessionID + } +} + +// WithActionGroupID is a ListActionOption that can be used to select all +// Actions performed under the give group ID. +func WithActionGroupID(groupID session.ID) ListActionOption { + return func(o *listActionOptions) { + o.groupID = groupID + } +} + +// WithActionStartTime is a ListActionOption that can be used to select all +// Actions that were attempted after the given time. +func WithActionStartTime(startTime time.Time) ListActionOption { + return func(o *listActionOptions) { + o.startTime = startTime + } +} + +// WithActionEndTime is a ListActionOption that can be used to select all +// Actions that were attempted before the given time. +func WithActionEndTime(endTime time.Time) ListActionOption { + return func(o *listActionOptions) { + o.endTime = endTime + } +} + +// WithActionFeatureName is a ListActionOption that can be used to select all +// Actions that were performed by the given feature. +func WithActionFeatureName(featureName string) ListActionOption { + return func(o *listActionOptions) { + o.featureName = featureName + } +} + +// WithActionActorName is a ListActionOption that can be used to select all +// Actions that were performed by the given actor. +func WithActionActorName(actorName string) ListActionOption { + return func(o *listActionOptions) { + o.actorName = actorName + } +} + +// WithActionMethodName is a ListActionOption that can be used to select all +// Actions that called the given RPC method. +func WithActionMethodName(methodName string) ListActionOption { + return func(o *listActionOptions) { + o.methodName = methodName + } +} + +// WithActionState is a ListActionOption that can be used to select all Actions +// that are in the given state. +func WithActionState(state ActionState) ListActionOption { + return func(o *listActionOptions) { + o.state = state + } +} + // ActionsWriteDB is an abstraction over the Actions DB that will allow a // caller to add new actions as well as change the values of an existing action. type ActionsWriteDB interface { @@ -174,10 +261,10 @@ var _ ActionsListDB = (*groupActionsReadDB)(nil) func (s *groupActionsReadDB) ListActions(ctx context.Context) ([]*RuleAction, error) { - sessionActions, err := s.db.ListGroupActions( - ctx, s.groupID, func(a *Action, _ bool) (bool, bool) { - return a.State == ActionStateDone, true - }, + sessionActions, _, _, err := s.db.ListActions( + ctx, nil, + WithActionGroupID(s.groupID), + WithActionState(ActionStateDone), ) if err != nil { return nil, err @@ -205,11 +292,11 @@ var _ ActionsListDB = (*groupFeatureActionsReadDB)(nil) func (a *groupFeatureActionsReadDB) ListActions(ctx context.Context) ( []*RuleAction, error) { - featureActions, err := a.db.ListGroupActions( - ctx, a.groupID, func(action *Action, _ bool) (bool, bool) { - return action.State == ActionStateDone && - action.FeatureName == a.featureName, true - }, + featureActions, _, _, err := a.db.ListActions( + ctx, nil, + WithActionGroupID(a.groupID), + WithActionState(ActionStateDone), + WithActionFeatureName(a.featureName), ) if err != nil { return nil, err diff --git a/firewalldb/actions_kvdb.go b/firewalldb/actions_kvdb.go index d92f95543..73fd8a61b 100644 --- a/firewalldb/actions_kvdb.go +++ b/firewalldb/actions_kvdb.go @@ -198,19 +198,87 @@ func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState, }) } -// ListActionsFilterFn defines a function that can be used to determine if an -// action should be included in a set of results or not. The reversed parameter -// indicates if the actions are being traversed in reverse order or not. -// The first return boolean indicates if the action should be included or not -// and the second one indicates if the iteration should be stopped or not. -type ListActionsFilterFn func(a *Action, reversed bool) (bool, bool) +// ListActions returns a list of Actions. The query IndexOffset and MaxNum +// params can be used to control the number of actions returned. +// ListActionOptions may be used to filter on specific Action values. The return +// values are the list of actions, the last index and the total count (iff +// query.CountTotal is set). +func (db *BoltDB) ListActions(ctx context.Context, query *ListActionsQuery, + options ...ListActionOption) ([]*Action, uint64, uint64, error) { + + opts := newListActionOptions() + for _, o := range options { + o(opts) + } + + filterFn := func(a *Action, reversed bool) (bool, bool) { + timeStamp := a.AttemptedAt + if !opts.endTime.IsZero() { + // If actions are being considered in order and the + // timestamp of this action exceeds the given end + // timestamp, then there is no need to continue + // traversing. + if !reversed && timeStamp.After(opts.endTime) { + return false, false + } + + // If the actions are in reverse order and the timestamp + // comes after the end timestamp, then the actions is + // not included but the search can continue. + if reversed && timeStamp.After(opts.endTime) { + return false, true + } + } + + if !opts.startTime.IsZero() { + // If actions are being considered in order and the + // timestamp of this action comes before the given start + // timestamp, then the action is not included but the + // search can continue. + if !reversed && timeStamp.Before(opts.startTime) { + return false, true + } + + // If the actions are in reverse order and the timestamp + // comes before the start timestamp, then there is no + // need to continue traversing. + if reversed && timeStamp.Before(opts.startTime) { + return false, false + } + } + + if opts.featureName != "" && a.FeatureName != opts.featureName { + return false, true + } + + if opts.actorName != "" && a.ActorName != opts.actorName { + return false, true + } -// ListActions returns a list of Actions that pass the filterFn requirements. -// The indexOffset and maxNum params can be used to control the number of -// actions returned. The return values are the list of actions, the last index -// and the total count (iff query.CountTotal is set). -func (db *BoltDB) ListActions(filterFn ListActionsFilterFn, - query *ListActionsQuery) ([]*Action, uint64, uint64, error) { + if opts.methodName != "" && a.RPCMethod != opts.methodName { + return false, true + } + + if opts.state != ActionStateUnknown && a.State != opts.state { + return false, true + } + + return true, true + } + + if opts.sessionID != session.EmptyID { + return db.listSessionActions( + opts.sessionID, filterFn, query, + ) + } + if opts.groupID != session.EmptyID { + actions, err := db.listGroupActions(ctx, opts.groupID, filterFn) + if err != nil { + return nil, 0, 0, err + } + + return actions, 0, uint64(len(actions)), nil + } var ( actions []*Action @@ -242,7 +310,6 @@ func (db *BoltDB) ListActions(filterFn ListActionsFilterFn, if err != nil { return nil, err } - return getAction(actionsBucket, locator) } @@ -255,14 +322,20 @@ func (db *BoltDB) ListActions(filterFn ListActionsFilterFn, if err != nil { return nil, 0, 0, err } - return actions, lastIndex, totalCount, nil } -// ListSessionActions returns a list of the given session's Actions that pass +// listActionsFilterFn defines a function that can be used to determine if an +// action should be included in a set of results or not. The reversed parameter +// indicates if the actions are being traversed in reverse order or not. +// The first return boolean indicates if the action should be included or not +// and the second one indicates if the iteration should be continued or not. +type listActionsFilterFn func(a *Action, reversed bool) (bool, bool) + +// listSessionActions returns a list of the given session's Actions that pass // the filterFn requirements. -func (db *BoltDB) ListSessionActions(sessionID session.ID, - filterFn ListActionsFilterFn, query *ListActionsQuery) ([]*Action, +func (db *BoltDB) listSessionActions(sessionID session.ID, + filterFn listActionsFilterFn, query *ListActionsQuery) ([]*Action, uint64, uint64, error) { var ( @@ -303,12 +376,12 @@ func (db *BoltDB) ListSessionActions(sessionID session.ID, return actions, lastIndex, totalCount, nil } -// ListGroupActions returns a list of the given session group's Actions that +// listGroupActions returns a list of the given session group's Actions that // pass the filterFn requirements. // // TODO: update to allow for pagination. -func (db *BoltDB) ListGroupActions(ctx context.Context, groupID session.ID, - filterFn ListActionsFilterFn) ([]*Action, error) { +func (db *BoltDB) listGroupActions(ctx context.Context, groupID session.ID, + filterFn listActionsFilterFn) ([]*Action, error) { if filterFn == nil { filterFn = func(a *Action, reversed bool) (bool, bool) { diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index da5dff147..f09e9f640 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -42,6 +42,7 @@ var ( // TestActionStorage tests that the ActionsListDB CRUD logic. func TestActionStorage(t *testing.T) { tmpDir := t.TempDir() + ctx := context.Background() db, err := NewBoltDB(tmpDir, "test.db", nil) require.NoError(t, err) @@ -49,20 +50,18 @@ func TestActionStorage(t *testing.T) { _ = db.Close() }) - actionsStateFilterFn := func(state ActionState) ListActionsFilterFn { - return func(a *Action, _ bool) (bool, bool) { - return a.State == state, true - } - } - - actions, _, _, err := db.ListSessionActions( - sessionID1, actionsStateFilterFn(ActionStateDone), nil, + actions, _, _, err := db.ListActions( + ctx, nil, + WithActionSessionID(sessionID1), + WithActionState(ActionStateDone), ) require.NoError(t, err) require.Len(t, actions, 0) - actions, _, _, err = db.ListSessionActions( - sessionID2, actionsStateFilterFn(ActionStateDone), nil, + actions, _, _, err = db.ListActions( + ctx, nil, + WithActionSessionID(sessionID2), + WithActionState(ActionStateDone), ) require.NoError(t, err) require.Len(t, actions, 0) @@ -75,15 +74,19 @@ func TestActionStorage(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(1), id) - actions, _, _, err = db.ListSessionActions( - sessionID1, actionsStateFilterFn(ActionStateDone), nil, + actions, _, _, err = db.ListActions( + ctx, nil, + WithActionSessionID(sessionID1), + WithActionState(ActionStateDone), ) require.NoError(t, err) require.Len(t, actions, 1) require.Equal(t, action1, actions[0]) - actions, _, _, err = db.ListSessionActions( - sessionID2, actionsStateFilterFn(ActionStateDone), nil, + actions, _, _, err = db.ListActions( + ctx, nil, + WithActionSessionID(sessionID2), + WithActionState(ActionStateDone), ) require.NoError(t, err) require.Len(t, actions, 0) @@ -96,8 +99,10 @@ func TestActionStorage(t *testing.T) { ) require.NoError(t, err) - actions, _, _, err = db.ListSessionActions( - sessionID2, actionsStateFilterFn(ActionStateDone), nil, + actions, _, _, err = db.ListActions( + ctx, nil, + WithActionSessionID(sessionID2), + WithActionState(ActionStateDone), ) require.NoError(t, err) require.Len(t, actions, 1) @@ -136,8 +141,10 @@ func TestActionStorage(t *testing.T) { ) require.NoError(t, err) - actions, _, _, err = db.ListSessionActions( - sessionID2, actionsStateFilterFn(ActionStateError), nil, + actions, _, _, err = db.ListActions( + ctx, nil, + WithActionSessionID(sessionID2), + WithActionState(ActionStateError), ) require.NoError(t, err) require.Len(t, actions, 1) @@ -150,6 +157,7 @@ func TestActionStorage(t *testing.T) { // TODO(elle): cover more test cases here. func TestListActions(t *testing.T) { tmpDir := t.TempDir() + ctx := context.Background() db, err := NewBoltDB(tmpDir, "test.db", nil) require.NoError(t, err) @@ -201,7 +209,7 @@ func TestListActions(t *testing.T) { addAction(sessionID1) addAction(sessionID2) - actions, lastIndex, totalCount, err := db.ListActions(nil, nil) + actions, lastIndex, totalCount, err := db.ListActions(ctx, nil) require.NoError(t, err) require.Len(t, actions, 5) require.EqualValues(t, 5, lastIndex) @@ -218,7 +226,7 @@ func TestListActions(t *testing.T) { Reversed: true, } - actions, lastIndex, totalCount, err = db.ListActions(nil, query) + actions, lastIndex, totalCount, err = db.ListActions(ctx, query) require.NoError(t, err) require.Len(t, actions, 5) require.EqualValues(t, 1, lastIndex) @@ -232,7 +240,7 @@ func TestListActions(t *testing.T) { }) actions, lastIndex, totalCount, err = db.ListActions( - nil, &ListActionsQuery{ + ctx, &ListActionsQuery{ CountAll: true, }, ) @@ -249,7 +257,7 @@ func TestListActions(t *testing.T) { }) actions, lastIndex, totalCount, err = db.ListActions( - nil, &ListActionsQuery{ + ctx, &ListActionsQuery{ CountAll: true, Reversed: true, }, @@ -272,7 +280,7 @@ func TestListActions(t *testing.T) { addAction(sessionID1) addAction(sessionID2) - actions, lastIndex, totalCount, err = db.ListActions(nil, nil) + actions, lastIndex, totalCount, err = db.ListActions(ctx, nil) require.NoError(t, err) require.Len(t, actions, 10) require.EqualValues(t, 10, lastIndex) @@ -291,7 +299,7 @@ func TestListActions(t *testing.T) { }) actions, lastIndex, totalCount, err = db.ListActions( - nil, &ListActionsQuery{ + ctx, &ListActionsQuery{ MaxNum: 3, CountAll: true, }, @@ -307,7 +315,7 @@ func TestListActions(t *testing.T) { }) actions, lastIndex, totalCount, err = db.ListActions( - nil, &ListActionsQuery{ + ctx, &ListActionsQuery{ MaxNum: 3, IndexOffset: 3, }, @@ -323,7 +331,7 @@ func TestListActions(t *testing.T) { }) actions, lastIndex, totalCount, err = db.ListActions( - nil, &ListActionsQuery{ + ctx, &ListActionsQuery{ MaxNum: 3, IndexOffset: 3, CountAll: true, @@ -340,7 +348,7 @@ func TestListActions(t *testing.T) { }) } -// TestListGroupActions tests that the ListGroupActions correctly returns all +// TestListGroupActions tests that the listGroupActions correctly returns all // actions in a particular session group. func TestListGroupActions(t *testing.T) { t.Parallel() @@ -360,7 +368,7 @@ func TestListGroupActions(t *testing.T) { }) // There should not be any actions in group 1 yet. - al, err := db.ListGroupActions(ctx, group1, nil) + al, _, _, err := db.ListActions(ctx, nil, WithActionGroupID(group1)) require.NoError(t, err) require.Empty(t, al) @@ -369,7 +377,7 @@ func TestListGroupActions(t *testing.T) { require.NoError(t, err) // There should now be one action in the group. - al, err = db.ListGroupActions(ctx, group1, nil) + al, _, _, err = db.ListActions(ctx, nil, WithActionGroupID(group1)) require.NoError(t, err) require.Len(t, al, 1) require.Equal(t, sessionID1, al[0].SessionID) @@ -379,7 +387,7 @@ func TestListGroupActions(t *testing.T) { require.NoError(t, err) // There should now be actions in the group. - al, err = db.ListGroupActions(ctx, group1, nil) + al, _, _, err = db.ListActions(ctx, nil, WithActionGroupID(group1)) require.NoError(t, err) require.Len(t, al, 2) require.Equal(t, sessionID1, al[0].SessionID) diff --git a/session/interface.go b/session/interface.go index c34264c0d..36b5075fe 100644 --- a/session/interface.go +++ b/session/interface.go @@ -14,6 +14,9 @@ import ( "gopkg.in/macaroon.v2" ) +// EmptyID is an empty session ID. +var EmptyID ID + // Type represents the type of session. type Type uint8 diff --git a/session_rpcserver.go b/session_rpcserver.go index ffb8e6b49..34eb4aa7e 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -668,69 +668,6 @@ func (s *sessionRpcServer) ListActions(ctx context.Context, req.MaxNumActions = 100 } - // Build a filter function based on the request values. - filterFn := func(a *firewalldb.Action, reversed bool) (bool, bool) { - timeStamp := uint64(a.AttemptedAt.Unix()) - if req.EndTimestamp != 0 { - // If actions are being considered in order and the - // timestamp of this action exceeds the given end - // timestamp, then there is no need to continue - // traversing. - if !reversed && timeStamp > req.EndTimestamp { - return false, false - } - - // If the actions are in reverse order and the timestamp - // comes after the end timestamp, then the actions is - // not included but the search can continue. - if reversed && timeStamp > req.EndTimestamp { - return false, true - } - } - - if req.StartTimestamp != 0 { - // If actions are being considered in order and the - // timestamp of this action comes before the given start - // timestamp, then the action is not included but the - // search can continue. - if !reversed && timeStamp < req.StartTimestamp { - return false, true - } - - // If the actions are in reverse order and the timestamp - // comes before the start timestamp, then there is no - // need to continue traversing. - if reversed && timeStamp < req.StartTimestamp { - return false, false - } - } - - if req.FeatureName != "" && a.FeatureName != req.FeatureName { - return false, true - } - - if req.ActorName != "" && a.ActorName != req.ActorName { - return false, true - } - - if req.MethodName != "" && a.RPCMethod != req.MethodName { - return false, true - } - - if req.State != 0 { - s, err := marshalActionState(a.State) - if err != nil { - return false, true - } - - if s != req.State { - return false, true - } - } - - return true, true - } - query := &firewalldb.ListActionsQuery{ IndexOffset: req.IndexOffset, MaxNum: req.MaxNumActions, @@ -738,43 +675,55 @@ func (s *sessionRpcServer) ListActions(ctx context.Context, CountAll: req.CountTotal, } + state, err := unmarshalActionState(req.State) + if err != nil { + return nil, err + } + var ( - db = s.cfg.actionsDB - actions []*firewalldb.Action - lastIndex uint64 - totalCount uint64 - err error + listOptions = []firewalldb.ListActionOption{ + firewalldb.WithActionFeatureName(req.FeatureName), + firewalldb.WithActionActorName(req.ActorName), + firewalldb.WithActionMethodName(req.MethodName), + firewalldb.WithActionState(state), + } + addOption = func(opt firewalldb.ListActionOption) { + listOptions = append(listOptions, opt) + } ) if req.SessionId != nil { sessionID, err := session.IDFromBytes(req.SessionId) if err != nil { return nil, err } - - actions, lastIndex, totalCount, err = db.ListSessionActions( - sessionID, filterFn, query, - ) - if err != nil { - return nil, err - } + addOption(firewalldb.WithActionSessionID(sessionID)) } else if req.GroupId != nil { groupID, err := session.IDFromBytes(req.GroupId) if err != nil { return nil, err } + addOption(firewalldb.WithActionGroupID(groupID)) + } - actions, err = db.ListGroupActions(ctx, groupID, filterFn) - if err != nil { - return nil, err - } - } else { - actions, lastIndex, totalCount, err = db.ListActions( - filterFn, query, + if req.EndTimestamp != 0 { + addOption(firewalldb.WithActionEndTime( + time.Unix(int64(req.EndTimestamp), 0)), ) - if err != nil { - return nil, err - } } + + if req.StartTimestamp != 0 { + addOption(firewalldb.WithActionStartTime( + time.Unix(int64(req.StartTimestamp), 0)), + ) + } + + actions, lastIndex, totalCount, err := s.cfg.actionsDB.ListActions( + ctx, query, listOptions..., + ) + if err != nil { + return nil, err + } + resp := make([]*litrpc.Action, len(actions)) for i, a := range actions { state, err := marshalActionState(a.State) @@ -1664,3 +1613,20 @@ func marshalActionState(state firewalldb.ActionState) (litrpc.ActionState, return 0, fmt.Errorf("unknown state <%d>", state) } } + +func unmarshalActionState(state litrpc.ActionState) (firewalldb.ActionState, + error) { + + switch state { + case litrpc.ActionState_STATE_UNKNOWN: + return firewalldb.ActionStateUnknown, nil + case litrpc.ActionState_STATE_PENDING: + return firewalldb.ActionStateInit, nil + case litrpc.ActionState_STATE_DONE: + return firewalldb.ActionStateDone, nil + case litrpc.ActionState_STATE_ERROR: + return firewalldb.ActionStateError, nil + default: + return 0, fmt.Errorf("unknown state <%d>", state) + } +}