Skip to content

Commit 079369b

Browse files
committed
chore: refactor server cmd initialization for tests, add tests
1 parent d9066a9 commit 079369b

File tree

3 files changed

+224
-22
lines changed

3 files changed

+224
-22
lines changed

cmd/root.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@ func Execute() {
2525
}
2626

2727
func init() {
28-
rootCmd.AddCommand(server.ServerCmd)
28+
rootCmd.AddCommand(server.CreateServerCmd())
2929
rootCmd.AddCommand(attach.AttachCmd)
3030
}

cmd/server/server.go

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -141,21 +141,6 @@ var agentNames = (func() []string {
141141
return names
142142
})()
143143

144-
var ServerCmd = &cobra.Command{
145-
Use: "server [agent]",
146-
Short: "Run the server",
147-
Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")),
148-
Args: cobra.MinimumNArgs(1),
149-
Run: func(cmd *cobra.Command, args []string) {
150-
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
151-
ctx := logctx.WithLogger(context.Background(), logger)
152-
if err := runServer(ctx, logger, cmd.Flags().Args()); err != nil {
153-
fmt.Fprintf(os.Stderr, "%+v\n", err)
154-
os.Exit(1)
155-
}
156-
},
157-
}
158-
159144
type flagSpec struct {
160145
name string
161146
shorthand string
@@ -173,7 +158,22 @@ const (
173158
FlagTermHeight = "term-height"
174159
)
175160

176-
func init() {
161+
func CreateServerCmd() *cobra.Command {
162+
serverCmd := &cobra.Command{
163+
Use: "server [agent]",
164+
Short: "Run the server",
165+
Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")),
166+
Args: cobra.MinimumNArgs(1),
167+
Run: func(cmd *cobra.Command, args []string) {
168+
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
169+
ctx := logctx.WithLogger(context.Background(), logger)
170+
if err := runServer(ctx, logger, cmd.Flags().Args()); err != nil {
171+
fmt.Fprintf(os.Stderr, "%+v\n", err)
172+
os.Exit(1)
173+
}
174+
},
175+
}
176+
177177
flagSpecs := []flagSpec{
178178
{FlagType, "t", "", fmt.Sprintf("Override the agent type (one of: %s, custom)", strings.Join(agentNames, ", ")), "string"},
179179
{FlagPort, "p", 3284, "Port to run the server on", "int"},
@@ -186,22 +186,24 @@ func init() {
186186
for _, spec := range flagSpecs {
187187
switch spec.flagType {
188188
case "string":
189-
ServerCmd.Flags().StringP(spec.name, spec.shorthand, spec.defaultValue.(string), spec.usage)
189+
serverCmd.Flags().StringP(spec.name, spec.shorthand, spec.defaultValue.(string), spec.usage)
190190
case "int":
191-
ServerCmd.Flags().IntP(spec.name, spec.shorthand, spec.defaultValue.(int), spec.usage)
191+
serverCmd.Flags().IntP(spec.name, spec.shorthand, spec.defaultValue.(int), spec.usage)
192192
case "bool":
193-
ServerCmd.Flags().BoolP(spec.name, spec.shorthand, spec.defaultValue.(bool), spec.usage)
193+
serverCmd.Flags().BoolP(spec.name, spec.shorthand, spec.defaultValue.(bool), spec.usage)
194194
case "uint16":
195-
ServerCmd.Flags().Uint16P(spec.name, spec.shorthand, spec.defaultValue.(uint16), spec.usage)
195+
serverCmd.Flags().Uint16P(spec.name, spec.shorthand, spec.defaultValue.(uint16), spec.usage)
196196
default:
197197
panic(fmt.Sprintf("unknown flag type: %s", spec.flagType))
198198
}
199-
if err := viper.BindPFlag(spec.name, ServerCmd.Flags().Lookup(spec.name)); err != nil {
199+
if err := viper.BindPFlag(spec.name, serverCmd.Flags().Lookup(spec.name)); err != nil {
200200
panic(fmt.Sprintf("failed to bind flag %s: %v", spec.name, err))
201201
}
202202
}
203203

204204
viper.SetEnvPrefix("AGENTAPI")
205205
viper.AutomaticEnv()
206206
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
207+
208+
return serverCmd
207209
}

cmd/server/server_test.go

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@ package server
22

33
import (
44
"fmt"
5+
"os"
6+
"strings"
57
"testing"
68

9+
"github.com/spf13/cobra"
10+
"github.com/spf13/viper"
11+
"github.com/stretchr/testify/assert"
712
"github.com/stretchr/testify/require"
813
)
914

@@ -83,3 +88,198 @@ func TestParseAgentType(t *testing.T) {
8388
require.Error(t, err)
8489
})
8590
}
91+
92+
// Test helper to isolate viper config between tests
93+
func isolateViper(t *testing.T) {
94+
// Save current state
95+
oldConfig := viper.AllSettings()
96+
97+
// Reset viper
98+
viper.Reset()
99+
100+
// Clear AGENTAPI_ env vars
101+
var agentapiEnvs []string
102+
for _, env := range os.Environ() {
103+
if strings.HasPrefix(env, "AGENTAPI_") {
104+
parts := strings.SplitN(env, "=", 2)
105+
agentapiEnvs = append(agentapiEnvs, parts[0])
106+
os.Unsetenv(parts[0])
107+
}
108+
}
109+
110+
t.Cleanup(func() {
111+
// Restore state
112+
viper.Reset()
113+
for key, value := range oldConfig {
114+
viper.Set(key, value)
115+
}
116+
117+
// Restore env vars
118+
for _, key := range agentapiEnvs {
119+
if val := os.Getenv(key); val != "" {
120+
os.Setenv(key, val)
121+
}
122+
}
123+
})
124+
}
125+
126+
// Test configuration values via ServerCmd execution
127+
func TestServerCmd_AllArgs_Defaults(t *testing.T) {
128+
tests := []struct {
129+
name string
130+
flag string
131+
expected any
132+
getter func() any
133+
}{
134+
{"type default", FlagType, "", func() any { return viper.GetString(FlagType) }},
135+
{"port default", FlagPort, 3284, func() any { return viper.GetInt(FlagPort) }},
136+
{"print-openapi default", FlagPrintOpenAPI, false, func() any { return viper.GetBool(FlagPrintOpenAPI) }},
137+
{"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }},
138+
{"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }},
139+
{"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }},
140+
}
141+
142+
for _, tt := range tests {
143+
t.Run(tt.name, func(t *testing.T) {
144+
isolateViper(t)
145+
serverCmd := CreateServerCmd()
146+
cmd := &cobra.Command{}
147+
cmd.AddCommand(serverCmd)
148+
149+
// Execute with no args to get defaults
150+
serverCmd.SetArgs([]string{"--help"}) // Use help to avoid actual execution
151+
serverCmd.Execute()
152+
153+
assert.Equal(t, tt.expected, tt.getter())
154+
})
155+
}
156+
}
157+
158+
func TestServerCmd_AllEnvVars(t *testing.T) {
159+
tests := []struct {
160+
name string
161+
envVar string
162+
envValue string
163+
expected any
164+
getter func() any
165+
}{
166+
{"AGENTAPI_TYPE", "AGENTAPI_TYPE", "claude", "claude", func() any { return viper.GetString(FlagType) }},
167+
{"AGENTAPI_PORT", "AGENTAPI_PORT", "8080", 8080, func() any { return viper.GetInt(FlagPort) }},
168+
{"AGENTAPI_PRINT_OPENAPI", "AGENTAPI_PRINT_OPENAPI", "true", true, func() any { return viper.GetBool(FlagPrintOpenAPI) }},
169+
{"AGENTAPI_CHAT_BASE_PATH", "AGENTAPI_CHAT_BASE_PATH", "/api", "/api", func() any { return viper.GetString(FlagChatBasePath) }},
170+
{"AGENTAPI_TERM_WIDTH", "AGENTAPI_TERM_WIDTH", "120", uint16(120), func() any { return viper.GetUint16(FlagTermWidth) }},
171+
{"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }},
172+
}
173+
174+
for _, tt := range tests {
175+
t.Run(tt.name, func(t *testing.T) {
176+
isolateViper(t)
177+
os.Setenv(tt.envVar, tt.envValue)
178+
defer os.Unsetenv(tt.envVar)
179+
180+
serverCmd := CreateServerCmd()
181+
cmd := &cobra.Command{}
182+
cmd.AddCommand(serverCmd)
183+
184+
serverCmd.SetArgs([]string{"--help"})
185+
serverCmd.Execute()
186+
187+
assert.Equal(t, tt.expected, tt.getter())
188+
})
189+
}
190+
}
191+
192+
func TestServerCmd_ArgsPrecedenceOverEnv(t *testing.T) {
193+
tests := []struct {
194+
name string
195+
envVar string
196+
envValue string
197+
args []string
198+
expected any
199+
getter func() any
200+
}{
201+
{
202+
"type: CLI overrides env",
203+
"AGENTAPI_TYPE", "goose",
204+
[]string{"--type", "claude"},
205+
"claude",
206+
func() any { return viper.GetString(FlagType) },
207+
},
208+
{
209+
"port: CLI overrides env",
210+
"AGENTAPI_PORT", "8080",
211+
[]string{"--port", "9090"},
212+
9090,
213+
func() any { return viper.GetInt(FlagPort) },
214+
},
215+
{
216+
"print-openapi: CLI overrides env",
217+
"AGENTAPI_PRINT_OPENAPI", "false",
218+
[]string{"--print-openapi"},
219+
true,
220+
func() any { return viper.GetBool(FlagPrintOpenAPI) },
221+
},
222+
{
223+
"chat-base-path: CLI overrides env",
224+
"AGENTAPI_CHAT_BASE_PATH", "/env-path",
225+
[]string{"--chat-base-path", "/cli-path"},
226+
"/cli-path",
227+
func() any { return viper.GetString(FlagChatBasePath) },
228+
},
229+
{
230+
"term-width: CLI overrides env",
231+
"AGENTAPI_TERM_WIDTH", "100",
232+
[]string{"--term-width", "150"},
233+
uint16(150),
234+
func() any { return viper.GetUint16(FlagTermWidth) },
235+
},
236+
{
237+
"term-height: CLI overrides env",
238+
"AGENTAPI_TERM_HEIGHT", "500",
239+
[]string{"--term-height", "600"},
240+
uint16(600),
241+
func() any { return viper.GetUint16(FlagTermHeight) },
242+
},
243+
}
244+
245+
for _, tt := range tests {
246+
t.Run(tt.name, func(t *testing.T) {
247+
isolateViper(t)
248+
os.Setenv(tt.envVar, tt.envValue)
249+
defer os.Unsetenv(tt.envVar)
250+
251+
// Mock execution to test arg parsing without running server
252+
args := append(tt.args, "--help")
253+
serverCmd := CreateServerCmd()
254+
serverCmd.SetArgs(args)
255+
serverCmd.Execute()
256+
257+
assert.Equal(t, tt.expected, tt.getter())
258+
})
259+
}
260+
}
261+
262+
func TestMixed_ConfigurationScenarios(t *testing.T) {
263+
t.Run("some env, some cli, some defaults", func(t *testing.T) {
264+
isolateViper(t)
265+
266+
// Set some env vars
267+
os.Setenv("AGENTAPI_TYPE", "goose")
268+
os.Setenv("AGENTAPI_TERM_WIDTH", "120")
269+
defer os.Unsetenv("AGENTAPI_TYPE")
270+
defer os.Unsetenv("AGENTAPI_TERM_WIDTH")
271+
272+
// Set some CLI args
273+
serverCmd := CreateServerCmd()
274+
serverCmd.SetArgs([]string{"--port", "9999", "--print-openapi", "--help"})
275+
serverCmd.Execute()
276+
277+
// Verify mixed configuration
278+
assert.Equal(t, "goose", viper.GetString(FlagType)) // from env
279+
assert.Equal(t, 9999, viper.GetInt(FlagPort)) // from CLI
280+
assert.Equal(t, true, viper.GetBool(FlagPrintOpenAPI)) // from CLI
281+
assert.Equal(t, "/chat", viper.GetString(FlagChatBasePath)) // default
282+
assert.Equal(t, uint16(120), viper.GetUint16(FlagTermWidth)) // from env
283+
assert.Equal(t, uint16(1000), viper.GetUint16(FlagTermHeight)) // default
284+
})
285+
}

0 commit comments

Comments
 (0)