@@ -3,7 +3,11 @@ package auditserver
33import (
44 "bytes"
55 "encoding/json"
6+ "errors"
67 "fmt"
8+ "github.com/ncode/vault-audit-filter/pkg/forwarder"
9+ "github.com/ncode/vault-audit-filter/pkg/messaging"
10+ "github.com/stretchr/testify/require"
711 "io"
812 "io/ioutil"
913 "net"
@@ -799,3 +803,102 @@ func TestAuditServer_React_WithForwarding(t *testing.T) {
799803 })
800804 }
801805}
806+
807+ type dummyMessenger struct {
808+ sendErr error
809+ calls int
810+ }
811+
812+ func (d * dummyMessenger ) Send (_ string ) error {
813+ d .calls ++
814+ return d .sendErr
815+ }
816+
817+ type dummyForwarder struct {
818+ forwardErr error
819+ calls int
820+ }
821+
822+ func (d * dummyForwarder ) Forward (_ []byte ) error {
823+ d .calls ++
824+ return d .forwardErr
825+ }
826+
827+ func mustAuditLogBytes (t * testing.T ) []byte {
828+ t .Helper ()
829+ json := `{"type":"request","time":"2000-01-01T00:00:00Z","auth":{},"request":{},"response":{}}`
830+ return []byte (json )
831+ }
832+
833+ func newRuleGroup (match bool , msgr messaging.Messenger , fwd forwarder.Forwarder , wr bytes.Buffer ) RuleGroup {
834+ rule := CompiledRule {}
835+ if ! match {
836+ // compiledRules ≠ 0 + rule that always returns false
837+ rule = CompiledRule {Program : nil }
838+ }
839+ return RuleGroup {
840+ Name : "grp" ,
841+ CompiledRules : []CompiledRule {rule },
842+ Messenger : msgr ,
843+ Forwarder : fwd ,
844+ Writer : & wr ,
845+ }
846+ }
847+
848+ func TestReact_Branches (t * testing.T ) {
849+ frame := mustAuditLogBytes (t )
850+
851+ tests := []struct {
852+ name string
853+ group RuleGroup
854+ wantAction gnet.Action
855+ wantMsgCalls int
856+ wantFwdCalls int
857+ }{
858+ {
859+ name : "match_no_side_effects_returns_None" ,
860+ group : newRuleGroup (true , nil , nil , bytes.Buffer {}),
861+ wantAction : gnet .None ,
862+ },
863+ {
864+ name : "forwarder_ok_triggers_Close" ,
865+ group : newRuleGroup (true , nil , & dummyForwarder {}, bytes.Buffer {}),
866+ wantAction : gnet .Close ,
867+ wantFwdCalls : 1 ,
868+ },
869+ {
870+ name : "forwarder_error_triggers_Close" ,
871+ group : newRuleGroup (true , nil , & dummyForwarder {forwardErr : errors .New ("x" )}, bytes.Buffer {}),
872+ wantAction : gnet .Close ,
873+ wantFwdCalls : 1 ,
874+ },
875+ {
876+ name : "messenger_error_triggers_Close" ,
877+ group : newRuleGroup (true , & dummyMessenger {sendErr : errors .New ("x" )}, nil , bytes.Buffer {}),
878+ wantAction : gnet .Close ,
879+ wantMsgCalls : 1 ,
880+ },
881+ {
882+ name : "no_match_triggers_Close" ,
883+ group : newRuleGroup (false , nil , nil , bytes.Buffer {}),
884+ wantAction : gnet .Close ,
885+ },
886+ }
887+
888+ for _ , tc := range tests {
889+ tc := tc // capture
890+ t .Run (tc .name , func (t * testing.T ) {
891+ srv := & AuditServer {ruleGroups : []RuleGroup {tc .group }}
892+ _ , act := srv .React (frame , nil )
893+
894+ require .Equal (t , tc .wantAction , act )
895+
896+ if dm , ok := tc .group .Messenger .(* dummyMessenger ); ok {
897+ require .Equal (t , tc .wantMsgCalls , dm .calls )
898+ }
899+ if df , ok := tc .group .Forwarder .(* dummyForwarder ); ok {
900+ require .Equal (t , tc .wantFwdCalls , df .calls )
901+ }
902+ })
903+ }
904+ }
0 commit comments