Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions firewall/request_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type RequestLogger struct {
// be used to find the corresponding action. This is used so that
// requests and responses can be easily linked. The mu mutex must be
// used when accessing this map.
reqIDToAction map[uint64]*firewalldb.ActionLocator
reqIDToAction map[uint64]firewalldb.ActionLocator
mu sync.Mutex
}

Expand Down Expand Up @@ -105,7 +105,7 @@ func NewRequestLogger(cfg *RequestLoggerConfig,
return &RequestLogger{
shouldLogAction: shouldLogAction,
actionsDB: actionsDB,
reqIDToAction: make(map[uint64]*firewalldb.ActionLocator),
reqIDToAction: make(map[uint64]firewalldb.ActionLocator),
}, nil
}

Expand Down Expand Up @@ -223,16 +223,13 @@ func (r *RequestLogger) addNewAction(ctx context.Context, ri *RequestInfo,
}
}

id, err := r.actionsDB.AddAction(ctx, action)
locator, err := r.actionsDB.AddAction(ctx, action)
if err != nil {
return err
}

r.mu.Lock()
r.reqIDToAction[ri.RequestID] = &firewalldb.ActionLocator{
SessionID: sessionID,
ActionID: id,
}
r.reqIDToAction[ri.RequestID] = locator
r.mu.Unlock()

return nil
Expand Down
9 changes: 4 additions & 5 deletions firewalldb/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ func WithActionState(state ActionState) ListActionOption {
// ActionsWriteDB is an abstraction over the Actions DB that will allow a
// caller to add new actions as well as change the values of an existing action.
type ActionsWriteDB interface {
AddAction(ctx context.Context, action *Action) (uint64, error)
SetActionState(ctx context.Context, al *ActionLocator,
AddAction(ctx context.Context, action *Action) (ActionLocator, error)
SetActionState(ctx context.Context, al ActionLocator,
state ActionState, errReason string) error
}

Expand Down Expand Up @@ -318,7 +318,6 @@ func actionToRulesAction(a *Action) *RuleAction {
}

// ActionLocator helps us find an action in the database.
type ActionLocator struct {
SessionID session.ID
ActionID uint64
type ActionLocator interface {
isActionLocator()
}
76 changes: 48 additions & 28 deletions firewalldb/actions_kvdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ var (
)

// AddAction serialises and adds an Action to the DB under the given sessionID.
func (db *BoltDB) AddAction(_ context.Context, action *Action) (uint64, error) {
func (db *BoltDB) AddAction(_ context.Context, action *Action) (ActionLocator,
error) {

var buf bytes.Buffer
if err := SerializeAction(&buf, action); err != nil {
return 0, err
return nil, err
}

var id uint64
var locator kvdbActionLocator
err := db.DB.Update(func(tx *bbolt.Tx) error {
mainActionsBucket, err := getBucket(tx, actionsBucketKey)
if err != nil {
Expand All @@ -82,7 +84,6 @@ func (db *BoltDB) AddAction(_ context.Context, action *Action) (uint64, error) {
if err != nil {
return err
}
id = nextActionIndex

var actionIndex [8]byte
byteOrder.PutUint64(actionIndex[:], nextActionIndex)
Expand All @@ -101,9 +102,9 @@ func (db *BoltDB) AddAction(_ context.Context, action *Action) (uint64, error) {
return err
}

locator := ActionLocator{
SessionID: action.SessionID,
ActionID: nextActionIndex,
locator = kvdbActionLocator{
sessionID: action.SessionID,
actionID: nextActionIndex,
}

var buf bytes.Buffer
Expand All @@ -117,13 +118,25 @@ func (db *BoltDB) AddAction(_ context.Context, action *Action) (uint64, error) {
return actionsIndexBucket.Put(seqNoBytes[:], buf.Bytes())
})
if err != nil {
return 0, err
return nil, err
}

return id, nil
return &locator, nil
}

// kvdbActionLocator helps us find an action in a KVDB database.
type kvdbActionLocator struct {
sessionID session.ID
actionID uint64
}

func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error {
// A compile-time check to ensure kvdbActionLocator implements the ActionLocator
// interface.
var _ ActionLocator = (*kvdbActionLocator)(nil)

func (al *kvdbActionLocator) isActionLocator() {}

func putAction(tx *bbolt.Tx, al *kvdbActionLocator, a *Action) error {
var buf bytes.Buffer
if err := SerializeAction(&buf, a); err != nil {
return err
Expand All @@ -139,42 +152,49 @@ func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error {
return ErrNoSuchKeyFound
}

sessBucket := actionsBucket.Bucket(al.SessionID[:])
sessBucket := actionsBucket.Bucket(al.sessionID[:])
if sessBucket == nil {
return fmt.Errorf("session bucket for session ID %x does not "+
"exist", al.SessionID)
"exist", al.sessionID)
}

var id [8]byte
binary.BigEndian.PutUint64(id[:], al.ActionID)
binary.BigEndian.PutUint64(id[:], al.actionID)

return sessBucket.Put(id[:], buf.Bytes())
}

func getAction(actionsBkt *bbolt.Bucket, al *ActionLocator) (*Action, error) {
sessBucket := actionsBkt.Bucket(al.SessionID[:])
func getAction(actionsBkt *bbolt.Bucket, al *kvdbActionLocator) (*Action,
error) {

sessBucket := actionsBkt.Bucket(al.sessionID[:])
if sessBucket == nil {
return nil, fmt.Errorf("session bucket for session ID "+
"%x does not exist", al.SessionID)
"%x does not exist", al.sessionID)
}

var id [8]byte
binary.BigEndian.PutUint64(id[:], al.ActionID)
binary.BigEndian.PutUint64(id[:], al.actionID)

actionBytes := sessBucket.Get(id[:])
return DeserializeAction(bytes.NewReader(actionBytes), al.SessionID)
return DeserializeAction(bytes.NewReader(actionBytes), al.sessionID)
}

// SetActionState finds the action specified by the ActionLocator and sets its
// state to the given state.
func (db *BoltDB) SetActionState(_ context.Context, al *ActionLocator,
func (db *BoltDB) SetActionState(_ context.Context, al ActionLocator,
state ActionState, errorReason string) error {

if errorReason != "" && state != ActionStateError {
return fmt.Errorf("error reason should only be set for " +
"ActionStateError")
}

locator, ok := al.(*kvdbActionLocator)
if !ok {
return fmt.Errorf("expected kvdbActionLocator, got %T", al)
}

return db.DB.Update(func(tx *bbolt.Tx) error {
mainActionsBucket, err := getBucket(tx, actionsBucketKey)
if err != nil {
Expand All @@ -186,15 +206,15 @@ func (db *BoltDB) SetActionState(_ context.Context, al *ActionLocator,
return ErrNoSuchKeyFound
}

action, err := getAction(actionsBucket, al)
action, err := getAction(actionsBucket, locator)
if err != nil {
return err
}

action.State = state
action.ErrorReason = errorReason

return putAction(tx, al, action)
return putAction(tx, locator, action)
})
}

Expand Down Expand Up @@ -540,14 +560,14 @@ func DeserializeAction(r io.Reader, sessionID session.ID) (*Action, error) {

// serializeActionLocator binary serializes the given ActionLocator to the
// writer using the tlv format.
func serializeActionLocator(w io.Writer, al *ActionLocator) error {
func serializeActionLocator(w io.Writer, al *kvdbActionLocator) error {
if al == nil {
return fmt.Errorf("action locator cannot be nil")
}

var (
sessionID = al.SessionID[:]
actionID = al.ActionID
sessionID = al.sessionID[:]
actionID = al.actionID
)

tlvRecords := []tlv.Record{
Expand All @@ -565,7 +585,7 @@ func serializeActionLocator(w io.Writer, al *ActionLocator) error {

// deserializeActionLocator deserializes an ActionLocator from the given reader,
// expecting the data to be encoded in the tlv format.
func deserializeActionLocator(r io.Reader) (*ActionLocator, error) {
func deserializeActionLocator(r io.Reader) (*kvdbActionLocator, error) {
var (
sessionID []byte
actionID uint64
Expand All @@ -588,8 +608,8 @@ func deserializeActionLocator(r io.Reader) (*ActionLocator, error) {
return nil, err
}

return &ActionLocator{
SessionID: id,
ActionID: actionID,
return &kvdbActionLocator{
sessionID: id,
actionID: actionID,
}, nil
}
30 changes: 6 additions & 24 deletions firewalldb/actions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,11 @@ func TestActionStorage(t *testing.T) {
require.NoError(t, err)
require.Len(t, actions, 0)

id, err := db.AddAction(ctx, action1)
_, err = db.AddAction(ctx, action1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps we could still check uniqueness of the locator, by replacing isActionLocator with a representation() string that we could compare? that way we would have at least some utility in the interface-definding method

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it's ok, im gonna leave it as is cause:

  • this is a pretty standard pattern in go
  • the correctness of the locator is already tested below when we use locator 2 to update the state of an action

require.NoError(t, err)
require.Equal(t, uint64(1), id)

id, err = db.AddAction(ctx, action2)
locator2, err := db.AddAction(ctx, action2)
require.NoError(t, err)
require.Equal(t, uint64(1), id)

actions, _, _, err = db.ListActions(
ctx, nil,
Expand All @@ -91,12 +89,7 @@ func TestActionStorage(t *testing.T) {
require.NoError(t, err)
require.Len(t, actions, 0)

err = db.SetActionState(
ctx, &ActionLocator{
SessionID: sessionID2,
ActionID: uint64(1),
}, ActionStateDone, "",
)
err = db.SetActionState(ctx, locator2, ActionStateDone, "")
require.NoError(t, err)

actions, _, _, err = db.ListActions(
Expand All @@ -109,9 +102,8 @@ func TestActionStorage(t *testing.T) {
action2.State = ActionStateDone
require.Equal(t, action2, actions[0])

id, err = db.AddAction(ctx, action1)
_, err = db.AddAction(ctx, action1)
require.NoError(t, err)
require.Equal(t, uint64(2), id)

// Check that providing no session id and no filter function returns
// all the actions.
Expand All @@ -124,21 +116,11 @@ func TestActionStorage(t *testing.T) {
require.Len(t, actions, 3)

// Try set an error reason for a non Errored state.
err = db.SetActionState(
ctx, &ActionLocator{
SessionID: sessionID2,
ActionID: uint64(1),
}, ActionStateDone, "hello",
)
err = db.SetActionState(ctx, locator2, ActionStateDone, "hello")
require.Error(t, err)

// Now try move the action to errored with a reason.
err = db.SetActionState(
ctx, &ActionLocator{
SessionID: sessionID2,
ActionID: uint64(1),
}, ActionStateError, "fail whale",
)
err = db.SetActionState(ctx, locator2, ActionStateError, "fail whale")
require.NoError(t, err)

actions, _, _, err = db.ListActions(
Expand Down
4 changes: 2 additions & 2 deletions firewalldb/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ type PrivacyMapper interface {
// the Action persistence and querying.
type ActionDB interface {
// AddAction persists the given action to the database.
AddAction(ctx context.Context, action *Action) (uint64, error)
AddAction(ctx context.Context, action *Action) (ActionLocator, error)

// SetActionState finds the action specified by the ActionLocator and
// sets its state to the given state.
SetActionState(ctx context.Context, al *ActionLocator,
SetActionState(ctx context.Context, al ActionLocator,
state ActionState, errReason string) error

// ListActions returns a list of Actions that pass the filterFn
Expand Down