Skip to content

Commit 52b93d5

Browse files
committed
fix: resolve multiple bugs and improve test coverage
Bug fixes: - Fix nil server panic in auditServer cmd when New() fails - Add nil checks for secret.Auth in AppRole/Cert/JWT auth methods - Fix TLS config order in vault.go (configure before client creation) - Add Close() method to UDPForwarder to prevent resource leaks - Fix GetStringSlice bug in root.go rule_groups check - Log write errors in server.go instead of silently ignoring Improvements: - Convert cmd Run functions to RunE for better error handling - Add cmd/ test coverage (was 0%, now 89.8%) - Add tests for nil auth response scenarios - Add UDPForwarder Close tests - Overall coverage improved from 75.7% to 91.6%
1 parent eed4f07 commit 52b93d5

File tree

13 files changed

+569
-123
lines changed

13 files changed

+569
-123
lines changed

cmd/auditServer.go

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,25 @@ package cmd
1717

1818
import (
1919
"fmt"
20-
"log"
2120

2221
"github.com/ncode/vault-audit-filter/pkg/auditserver"
2322
"github.com/panjf2000/gnet"
24-
"github.com/spf13/viper"
25-
2623
"github.com/spf13/cobra"
24+
"github.com/spf13/viper"
2725
)
2826

2927
// auditServerCmd represents the auditServer command
3028
var auditServerCmd = &cobra.Command{
3129
Use: "auditServer",
32-
Short: "A brief description of your command",
33-
Long: `A longer description that spans multiple lines and likely contains examples
34-
and usage of using your command. For example:
35-
36-
Cobra is a CLI library for Go that empowers applications.
37-
This application is a tool to generate the needed files
38-
to quickly create a Cobra application.`,
39-
Run: func(cmd *cobra.Command, args []string) {
30+
Short: "Start the audit server to receive and filter Vault audit logs",
31+
Long: `Starts a UDP server that receives Vault audit logs and filters them based on configured rules.`,
32+
RunE: func(cmd *cobra.Command, args []string) error {
4033
addr := fmt.Sprintf("udp://%s", viper.GetString("vault.audit_address"))
41-
server, err := auditserver.New(nil)
34+
server, err := auditserver.New(logger)
4235
if err != nil {
43-
logger.Error(err.Error())
36+
return fmt.Errorf("failed to create audit server: %w", err)
4437
}
45-
log.Fatal(gnet.Serve(server, addr, gnet.WithMulticore(true)))
38+
return gnet.Serve(server, addr, gnet.WithMulticore(true))
4639
},
4740
}
4841

cmd/auditServer_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
Copyright © 2024 Juliano Martinez <[email protected]>
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
package cmd
17+
18+
import (
19+
"testing"
20+
21+
"github.com/spf13/viper"
22+
"github.com/stretchr/testify/assert"
23+
)
24+
25+
func TestAuditServerCmd_InvalidRuleGroups(t *testing.T) {
26+
viper.Reset()
27+
viper.Set("vault.audit_address", "127.0.0.1:1269")
28+
viper.Set("rule_groups", "invalid_value")
29+
30+
err := auditServerCmd.RunE(auditServerCmd, []string{})
31+
assert.Error(t, err)
32+
assert.Contains(t, err.Error(), "failed to create audit server")
33+
}
34+
35+
func TestAuditServerCmd_ValidConfig(t *testing.T) {
36+
viper.Reset()
37+
viper.Set("vault.audit_address", "127.0.0.1:1269")
38+
viper.Set("rule_groups", []map[string]interface{}{
39+
{
40+
"name": "test_group",
41+
"rules": []string{
42+
"Request.Operation == 'read'",
43+
},
44+
"log_file": map[string]interface{}{
45+
"file_path": "/tmp/test-audit.log",
46+
"max_size": 10,
47+
},
48+
},
49+
})
50+
51+
// We can't actually test gnet.Serve without starting a server,
52+
// but we can at least verify the command is properly configured
53+
assert.NotNil(t, auditServerCmd)
54+
assert.Equal(t, "auditServer", auditServerCmd.Use)
55+
assert.NotNil(t, auditServerCmd.RunE)
56+
}

cmd/root.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ func initConfig() {
104104
fmt.Fprintln(os.Stderr, "Using config file:", viper.ConfigFileUsed())
105105
}
106106

107-
if !viper.IsSet("rule_groups") || len(viper.GetStringSlice("rule_groups")) == 0 {
107+
ruleGroups := viper.Get("rule_groups")
108+
if ruleGroups == nil {
109+
logger.Info("No rules defined in configuration; all audit logs will be printed")
110+
} else if slice, ok := ruleGroups.([]interface{}); ok && len(slice) == 0 {
108111
logger.Info("No rules defined in configuration; all audit logs will be printed")
109112
}
110113
}

cmd/root_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
Copyright © 2024 Juliano Martinez <[email protected]>
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
package cmd
17+
18+
import (
19+
"bytes"
20+
"os"
21+
"path/filepath"
22+
"testing"
23+
24+
"github.com/spf13/viper"
25+
"github.com/stretchr/testify/assert"
26+
)
27+
28+
func TestRootCmd_Exists(t *testing.T) {
29+
assert.NotNil(t, rootCmd)
30+
assert.Equal(t, "vault-audit-filter", rootCmd.Use)
31+
}
32+
33+
func TestRootCmd_HasSubcommands(t *testing.T) {
34+
commands := rootCmd.Commands()
35+
assert.GreaterOrEqual(t, len(commands), 2)
36+
37+
var hasSetup, hasAuditServer bool
38+
for _, cmd := range commands {
39+
if cmd.Use == "setup" {
40+
hasSetup = true
41+
}
42+
if cmd.Use == "auditServer" {
43+
hasAuditServer = true
44+
}
45+
}
46+
assert.True(t, hasSetup, "setup command should be registered")
47+
assert.True(t, hasAuditServer, "auditServer command should be registered")
48+
}
49+
50+
func TestRootCmd_PersistentFlags(t *testing.T) {
51+
flags := rootCmd.PersistentFlags()
52+
53+
configFlag := flags.Lookup("config")
54+
assert.NotNil(t, configFlag)
55+
56+
vaultAddressFlag := flags.Lookup("vault.address")
57+
assert.NotNil(t, vaultAddressFlag)
58+
assert.Equal(t, "http://127.0.0.1:8200", vaultAddressFlag.DefValue)
59+
60+
vaultTokenFlag := flags.Lookup("vault.token")
61+
assert.NotNil(t, vaultTokenFlag)
62+
63+
vaultAuditPathFlag := flags.Lookup("vault.audit_path")
64+
assert.NotNil(t, vaultAuditPathFlag)
65+
66+
vaultAuditAddressFlag := flags.Lookup("vault.audit_address")
67+
assert.NotNil(t, vaultAuditAddressFlag)
68+
assert.Equal(t, "127.0.0.1:1269", vaultAuditAddressFlag.DefValue)
69+
}
70+
71+
func TestInitConfig_WithConfigFile(t *testing.T) {
72+
// Create a temporary config file
73+
tmpDir := t.TempDir()
74+
configPath := filepath.Join(tmpDir, "test-config.yaml")
75+
76+
configContent := `
77+
vault:
78+
address: "http://test-vault:8200"
79+
token: "test-token"
80+
rule_groups:
81+
- name: test
82+
rules:
83+
- "Request.Operation == 'read'"
84+
`
85+
err := os.WriteFile(configPath, []byte(configContent), 0644)
86+
assert.NoError(t, err)
87+
88+
// Reset viper and set config file
89+
viper.Reset()
90+
cfgFile = configPath
91+
initConfig()
92+
93+
assert.Equal(t, "http://test-vault:8200", viper.GetString("vault.address"))
94+
assert.Equal(t, "test-token", viper.GetString("vault.token"))
95+
96+
ruleGroups := viper.Get("rule_groups")
97+
assert.NotNil(t, ruleGroups)
98+
99+
// Reset for other tests
100+
cfgFile = ""
101+
viper.Reset()
102+
}
103+
104+
func TestInitConfig_NoRuleGroups(t *testing.T) {
105+
viper.Reset()
106+
cfgFile = ""
107+
108+
// initConfig should log that no rules are defined
109+
// We just verify it doesn't panic
110+
initConfig()
111+
112+
// Reset for other tests
113+
viper.Reset()
114+
}
115+
116+
func TestExecute_Help(t *testing.T) {
117+
// Test that Execute doesn't panic with help flag
118+
oldArgs := os.Args
119+
defer func() { os.Args = oldArgs }()
120+
121+
os.Args = []string{"vault-audit-filter", "--help"}
122+
123+
// Capture output
124+
rootCmd.SetOut(&bytes.Buffer{})
125+
rootCmd.SetErr(&bytes.Buffer{})
126+
127+
// This should not panic
128+
// We don't call Execute() directly as it calls os.Exit
129+
err := rootCmd.Help()
130+
assert.NoError(t, err)
131+
}

cmd/setup.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ limitations under the License.
1616
package cmd
1717

1818
import (
19-
"os"
19+
"fmt"
2020

2121
"github.com/ncode/vault-audit-filter/pkg/vault"
2222
"github.com/spf13/cobra"
@@ -27,18 +27,17 @@ import (
2727
var setupCmd = &cobra.Command{
2828
Use: "setup",
2929
Short: "Setup vault audit device",
30-
Long: ``,
31-
Run: func(cmd *cobra.Command, args []string) {
30+
Long: `Configures Vault to send audit logs to this service via UDP socket.`,
31+
RunE: func(cmd *cobra.Command, args []string) error {
3232
if viper.GetString("vault.token") == "" {
33-
logger.Error("vault.token is required")
34-
os.Exit(1)
33+
return fmt.Errorf("vault.token is required")
3534
}
3635

3736
client, err := vault.NewVaultClient(viper.GetString("vault.address"), vault.TokenAuth{Token: viper.GetString("vault.token")})
3837
if err != nil {
39-
logger.Error("setup", "unable to setup vault client", err.Error())
40-
os.Exit(1)
38+
return fmt.Errorf("unable to setup vault client: %w", err)
4139
}
40+
4241
err = client.EnableAuditDevice(
4342
viper.GetString("vault.audit_path"),
4443
"socket",
@@ -51,9 +50,13 @@ var setupCmd = &cobra.Command{
5150
},
5251
)
5352
if err != nil {
54-
logger.Error("setup", "unable to enable audit device", err.Error())
55-
os.Exit(1)
53+
return fmt.Errorf("unable to enable audit device: %w", err)
5654
}
55+
56+
logger.Info("Vault audit device configured successfully",
57+
"path", viper.GetString("vault.audit_path"),
58+
"address", viper.GetString("vault.audit_address"))
59+
return nil
5760
},
5861
}
5962

0 commit comments

Comments
 (0)