Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions firewall/request_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type RequestLogger struct {
// be used to find the corresponding action. This is used so that
// requests and responses can be easily linked. The mu mutex must be
// used when accessing this map.
reqIDToAction map[uint64]*firewalldb.ActionLocator
reqIDToAction map[uint64]firewalldb.ActionLocator
mu sync.Mutex
}

Expand Down Expand Up @@ -105,7 +105,7 @@ func NewRequestLogger(cfg *RequestLoggerConfig,
return &RequestLogger{
shouldLogAction: shouldLogAction,
actionsDB: actionsDB,
reqIDToAction: make(map[uint64]*firewalldb.ActionLocator),
reqIDToAction: make(map[uint64]firewalldb.ActionLocator),
}, nil
}

Expand All @@ -128,7 +128,7 @@ func (r *RequestLogger) CustomCaveatName() string {

// Intercept processes an RPC middleware interception request and returns the
// interception result which either accepts or rejects the intercepted message.
func (r *RequestLogger) Intercept(_ context.Context,
func (r *RequestLogger) Intercept(ctx context.Context,
req *lnrpc.RPCMiddlewareRequest) (*lnrpc.RPCMiddlewareResponse, error) {

ri, err := NewInfoFromRequest(req)
Expand Down Expand Up @@ -156,7 +156,7 @@ func (r *RequestLogger) Intercept(_ context.Context,

// Parse incoming requests and act on them.
case MWRequestTypeRequest:
return mid.RPCErr(req, r.addNewAction(ri, withPayloadData))
return mid.RPCErr(req, r.addNewAction(ctx, ri, withPayloadData))

// Parse and possibly manipulate outgoing responses.
case MWRequestTypeResponse:
Expand All @@ -170,7 +170,7 @@ func (r *RequestLogger) Intercept(_ context.Context,
}

return mid.RPCErr(
req, r.MarkAction(ri.RequestID, state, errReason),
req, r.MarkAction(ctx, ri.RequestID, state, errReason),
)

default:
Expand All @@ -179,7 +179,7 @@ func (r *RequestLogger) Intercept(_ context.Context,
}

// addNewAction persists the new action to the db.
func (r *RequestLogger) addNewAction(ri *RequestInfo,
func (r *RequestLogger) addNewAction(ctx context.Context, ri *RequestInfo,
withPayloadData bool) error {

// If no macaroon is provided, then an empty 4-byte array is used as the
Expand Down Expand Up @@ -223,24 +223,21 @@ func (r *RequestLogger) addNewAction(ri *RequestInfo,
}
}

id, err := r.actionsDB.AddAction(action)
locator, err := r.actionsDB.AddAction(ctx, action)
if err != nil {
return err
}

r.mu.Lock()
r.reqIDToAction[ri.RequestID] = &firewalldb.ActionLocator{
SessionID: sessionID,
ActionID: id,
}
r.reqIDToAction[ri.RequestID] = locator
r.mu.Unlock()

return nil
}

// MarkAction can be used to set the state of an action identified by the given
// requestID.
func (r *RequestLogger) MarkAction(reqID uint64,
func (r *RequestLogger) MarkAction(ctx context.Context, reqID uint64,
state firewalldb.ActionState, errReason string) error {

r.mu.Lock()
Expand All @@ -252,5 +249,5 @@ func (r *RequestLogger) MarkAction(reqID uint64,
}
delete(r.reqIDToAction, reqID)

return r.actionsDB.SetActionState(actionLocator, state, errReason)
return r.actionsDB.SetActionState(ctx, actionLocator, state, errReason)
}
12 changes: 8 additions & 4 deletions firewall/rule_enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ type RuleEnforcer struct {
ruleDB firewalldb.RulesDB
actionsDB firewalldb.ActionReadDBGetter
sessionDB firewalldb.SessionDB
markActionErrored func(reqID uint64, reason string) error
privMapDB firewalldb.PrivacyMapper
markActionErrored func(ctx context.Context, reqID uint64,
reason string) error
privMapDB firewalldb.PrivacyMapper

permsMgr *perms.Manager
getFeaturePerms featurePerms
Expand Down Expand Up @@ -63,7 +64,8 @@ func NewRuleEnforcer(ruleDB firewalldb.RulesDB,
routerClient lndclient.RouterClient,
lndClient lndclient.LightningClient, lndConnID string,
ruleMgrs rules.ManagerSet,
markActionErrored func(reqID uint64, reason string) error,
markActionErrored func(ctx context.Context, reqID uint64,
reason string) error,
privMap firewalldb.PrivacyMapper) *RuleEnforcer {

return &RuleEnforcer{
Expand Down Expand Up @@ -164,7 +166,9 @@ func (r *RuleEnforcer) Intercept(ctx context.Context,

replacement, err := r.handleRequest(ctx, ri)
if err != nil {
dbErr := r.markActionErrored(ri.RequestID, err.Error())
dbErr := r.markActionErrored(
ctx, ri.RequestID, err.Error(),
)
if dbErr != nil {
log.Error("could not mark action for "+
"request ID %d as Errored: %v",
Expand Down
13 changes: 6 additions & 7 deletions firewalldb/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ func WithActionState(state ActionState) ListActionOption {
// ActionsWriteDB is an abstraction over the Actions DB that will allow a
// caller to add new actions as well as change the values of an existing action.
type ActionsWriteDB interface {
AddAction(action *Action) (uint64, error)
SetActionState(al *ActionLocator, state ActionState,
errReason string) error
AddAction(ctx context.Context, action *Action) (ActionLocator, error)
SetActionState(ctx context.Context, al ActionLocator,
state ActionState, errReason string) error
}

// RuleAction represents a method call that was performed at a certain time at
Expand Down Expand Up @@ -230,7 +230,7 @@ func (db *BoltDB) GetActionsReadDB(groupID session.ID,

// allActionsReadDb is an implementation of the ActionsReadDB.
type allActionsReadDB struct {
db *BoltDB
db ActionDB
groupID session.ID
featureName string
}
Expand Down Expand Up @@ -318,7 +318,6 @@ func actionToRulesAction(a *Action) *RuleAction {
}

// ActionLocator helps us find an action in the database.
type ActionLocator struct {
SessionID session.ID
ActionID uint64
type ActionLocator interface {
isActionLocator()
}
78 changes: 49 additions & 29 deletions firewalldb/actions_kvdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ var (
)

// AddAction serialises and adds an Action to the DB under the given sessionID.
func (db *BoltDB) AddAction(action *Action) (uint64, error) {
func (db *BoltDB) AddAction(_ context.Context, action *Action) (ActionLocator,
error) {

var buf bytes.Buffer
if err := SerializeAction(&buf, action); err != nil {
return 0, err
return nil, err
}

var id uint64
var locator kvdbActionLocator
err := db.DB.Update(func(tx *bbolt.Tx) error {
mainActionsBucket, err := getBucket(tx, actionsBucketKey)
if err != nil {
Expand All @@ -82,7 +84,6 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
if err != nil {
return err
}
id = nextActionIndex

var actionIndex [8]byte
byteOrder.PutUint64(actionIndex[:], nextActionIndex)
Expand All @@ -101,9 +102,9 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
return err
}

locator := ActionLocator{
SessionID: action.SessionID,
ActionID: nextActionIndex,
locator = kvdbActionLocator{
sessionID: action.SessionID,
actionID: nextActionIndex,
}

var buf bytes.Buffer
Expand All @@ -117,13 +118,25 @@ func (db *BoltDB) AddAction(action *Action) (uint64, error) {
return actionsIndexBucket.Put(seqNoBytes[:], buf.Bytes())
})
if err != nil {
return 0, err
return nil, err
}

return id, nil
return &locator, nil
}

// kvdbActionLocator helps us find an action in a KVDB database.
type kvdbActionLocator struct {
sessionID session.ID
actionID uint64
}

func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error {
// A compile-time check to ensure kvdbActionLocator implements the ActionLocator
// interface.
var _ ActionLocator = (*kvdbActionLocator)(nil)

func (al *kvdbActionLocator) isActionLocator() {}

func putAction(tx *bbolt.Tx, al *kvdbActionLocator, a *Action) error {
var buf bytes.Buffer
if err := SerializeAction(&buf, a); err != nil {
return err
Expand All @@ -139,42 +152,49 @@ func putAction(tx *bbolt.Tx, al *ActionLocator, a *Action) error {
return ErrNoSuchKeyFound
}

sessBucket := actionsBucket.Bucket(al.SessionID[:])
sessBucket := actionsBucket.Bucket(al.sessionID[:])
if sessBucket == nil {
return fmt.Errorf("session bucket for session ID %x does not "+
"exist", al.SessionID)
"exist", al.sessionID)
}

var id [8]byte
binary.BigEndian.PutUint64(id[:], al.ActionID)
binary.BigEndian.PutUint64(id[:], al.actionID)

return sessBucket.Put(id[:], buf.Bytes())
}

func getAction(actionsBkt *bbolt.Bucket, al *ActionLocator) (*Action, error) {
sessBucket := actionsBkt.Bucket(al.SessionID[:])
func getAction(actionsBkt *bbolt.Bucket, al *kvdbActionLocator) (*Action,
error) {

sessBucket := actionsBkt.Bucket(al.sessionID[:])
if sessBucket == nil {
return nil, fmt.Errorf("session bucket for session ID "+
"%x does not exist", al.SessionID)
"%x does not exist", al.sessionID)
}

var id [8]byte
binary.BigEndian.PutUint64(id[:], al.ActionID)
binary.BigEndian.PutUint64(id[:], al.actionID)

actionBytes := sessBucket.Get(id[:])
return DeserializeAction(bytes.NewReader(actionBytes), al.SessionID)
return DeserializeAction(bytes.NewReader(actionBytes), al.sessionID)
}

// SetActionState finds the action specified by the ActionLocator and sets its
// state to the given state.
func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState,
errorReason string) error {
func (db *BoltDB) SetActionState(_ context.Context, al ActionLocator,
state ActionState, errorReason string) error {

if errorReason != "" && state != ActionStateError {
return fmt.Errorf("error reason should only be set for " +
"ActionStateError")
}

locator, ok := al.(*kvdbActionLocator)
if !ok {
return fmt.Errorf("expected kvdbActionLocator, got %T", al)
}

return db.DB.Update(func(tx *bbolt.Tx) error {
mainActionsBucket, err := getBucket(tx, actionsBucketKey)
if err != nil {
Expand All @@ -186,15 +206,15 @@ func (db *BoltDB) SetActionState(al *ActionLocator, state ActionState,
return ErrNoSuchKeyFound
}

action, err := getAction(actionsBucket, al)
action, err := getAction(actionsBucket, locator)
if err != nil {
return err
}

action.State = state
action.ErrorReason = errorReason

return putAction(tx, al, action)
return putAction(tx, locator, action)
})
}

Expand Down Expand Up @@ -540,14 +560,14 @@ func DeserializeAction(r io.Reader, sessionID session.ID) (*Action, error) {

// serializeActionLocator binary serializes the given ActionLocator to the
// writer using the tlv format.
func serializeActionLocator(w io.Writer, al *ActionLocator) error {
func serializeActionLocator(w io.Writer, al *kvdbActionLocator) error {
if al == nil {
return fmt.Errorf("action locator cannot be nil")
}

var (
sessionID = al.SessionID[:]
actionID = al.ActionID
sessionID = al.sessionID[:]
actionID = al.actionID
)

tlvRecords := []tlv.Record{
Expand All @@ -565,7 +585,7 @@ func serializeActionLocator(w io.Writer, al *ActionLocator) error {

// deserializeActionLocator deserializes an ActionLocator from the given reader,
// expecting the data to be encoded in the tlv format.
func deserializeActionLocator(r io.Reader) (*ActionLocator, error) {
func deserializeActionLocator(r io.Reader) (*kvdbActionLocator, error) {
var (
sessionID []byte
actionID uint64
Expand All @@ -588,8 +608,8 @@ func deserializeActionLocator(r io.Reader) (*ActionLocator, error) {
return nil, err
}

return &ActionLocator{
SessionID: id,
ActionID: actionID,
return &kvdbActionLocator{
sessionID: id,
actionID: actionID,
}, nil
}
Loading
Loading