Skip to content

Commit 87cefa1

Browse files
committed
feat: allowed hosts
1 parent e783ff1 commit 87cefa1

File tree

6 files changed

+467
-18
lines changed

6 files changed

+467
-18
lines changed

cmd/server/server.go

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"os"
1010
"sort"
1111
"strings"
12+
"unicode"
1213

1314
"github.com/spf13/cobra"
1415
"github.com/spf13/viper"
@@ -58,6 +59,26 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) {
5859
return AgentTypeCustom, nil
5960
}
6061

62+
// Validate allowed hosts don't contain whitespace or commas.
63+
// Viper/Cobra use different separators (space for env vars, comma for flags),
64+
// so these characters in origins likely indicate user error.
65+
func validateAllowedHosts(hosts []string) error {
66+
if len(hosts) == 0 {
67+
return fmt.Errorf("allowed hosts must not be empty")
68+
}
69+
for _, host := range hosts {
70+
for _, r := range host {
71+
if unicode.IsSpace(r) {
72+
return fmt.Errorf("host '%s' contains whitespace characters, which are not allowed", host)
73+
}
74+
if strings.Contains(host, ",") {
75+
return fmt.Errorf("host '%s' contains comma characters, which are not allowed", host)
76+
}
77+
}
78+
}
79+
return nil
80+
}
81+
6182
func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) error {
6283
agent := argsToPass[0]
6384
agentTypeValue := viper.GetString(FlagType)
@@ -95,12 +116,16 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
95116
}
96117
}
97118
port := viper.GetInt(FlagPort)
98-
srv := httpapi.NewServer(ctx, httpapi.ServerConfig{
119+
srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
99120
AgentType: agentType,
100121
Process: process,
101122
Port: port,
102123
ChatBasePath: viper.GetString(FlagChatBasePath),
124+
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
103125
})
126+
if err != nil {
127+
return xerrors.Errorf("failed to create server: %w", err)
128+
}
104129
if printOpenAPI {
105130
fmt.Println(srv.GetOpenAPI())
106131
return nil
@@ -156,6 +181,8 @@ const (
156181
FlagChatBasePath = "chat-base-path"
157182
FlagTermWidth = "term-width"
158183
FlagTermHeight = "term-height"
184+
FlagAllowedHosts = "allowed-hosts"
185+
FlagExit = "exit"
159186
)
160187

161188
func CreateServerCmd() *cobra.Command {
@@ -164,7 +191,18 @@ func CreateServerCmd() *cobra.Command {
164191
Short: "Run the server",
165192
Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")),
166193
Args: cobra.MinimumNArgs(1),
194+
PreRunE: func(cmd *cobra.Command, args []string) error {
195+
allowedOrigins := viper.GetStringSlice(FlagAllowedHosts)
196+
if err := validateAllowedHosts(allowedOrigins); err != nil {
197+
return err
198+
}
199+
return nil
200+
},
167201
Run: func(cmd *cobra.Command, args []string) {
202+
// The --exit flag is used for testing validation of flags in the test suite
203+
if viper.GetBool(FlagExit) {
204+
return
205+
}
168206
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
169207
ctx := logctx.WithLogger(context.Background(), logger)
170208
if err := runServer(ctx, logger, cmd.Flags().Args()); err != nil {
@@ -181,6 +219,10 @@ func CreateServerCmd() *cobra.Command {
181219
{FlagChatBasePath, "c", "/chat", "Base path for assets and routes used in the static files of the chat interface", "string"},
182220
{FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"},
183221
{FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"},
222+
// localhost:3284 is the default port for the server
223+
// localhost:3000 is the default port for the chat interface during development
224+
// localhost:3001 is used during development for the chat interface if 3000 is already in use
225+
{FlagAllowedHosts, "a", []string{"localhost:3284", "localhost:3000", "localhost:3001"}, "HTTP allowed hosts. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
184226
}
185227

186228
for _, spec := range flagSpecs {
@@ -193,6 +235,8 @@ func CreateServerCmd() *cobra.Command {
193235
serverCmd.Flags().BoolP(spec.name, spec.shorthand, spec.defaultValue.(bool), spec.usage)
194236
case "uint16":
195237
serverCmd.Flags().Uint16P(spec.name, spec.shorthand, spec.defaultValue.(uint16), spec.usage)
238+
case "stringSlice":
239+
serverCmd.Flags().StringSliceP(spec.name, spec.shorthand, spec.defaultValue.([]string), spec.usage)
196240
default:
197241
panic(fmt.Sprintf("unknown flag type: %s", spec.flagType))
198242
}
@@ -201,6 +245,12 @@ func CreateServerCmd() *cobra.Command {
201245
}
202246
}
203247

248+
serverCmd.Flags().Bool(FlagExit, false, "Exit immediately after parsing arguments")
249+
serverCmd.Flags().MarkHidden(FlagExit)
250+
if err := viper.BindPFlag(FlagExit, serverCmd.Flags().Lookup(FlagExit)); err != nil {
251+
panic(fmt.Sprintf("failed to bind flag %s: %v", FlagExit, err))
252+
}
253+
204254
viper.SetEnvPrefix("AGENTAPI")
205255
viper.AutomaticEnv()
206256
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))

cmd/server/server_test.go

Lines changed: 152 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,38 @@ import (
66
"strings"
77
"testing"
88

9+
"github.com/ActiveState/termtest/xpty"
910
"github.com/spf13/cobra"
1011
"github.com/spf13/viper"
1112
"github.com/stretchr/testify/assert"
1213
"github.com/stretchr/testify/require"
1314
)
1415

16+
// setupCommandWithPTY configures a cobra command to use xpty for output capture
17+
// and returns the xpty instance for reading captured output.
18+
// The caller is responsible for closing the returned xpty.
19+
func setupCommandWithPTY(t *testing.T, cmd *cobra.Command) *xpty.Xpty {
20+
t.Helper()
21+
22+
// Create virtual PTY with 100x100 dimensions
23+
xp, err := xpty.New(100, 100, false)
24+
require.NoError(t, err, "failed to create xpty")
25+
26+
// Configure command to write to the PTY
27+
ptyWriter := xp.TerminalInPipe()
28+
cmd.SetOut(ptyWriter)
29+
cmd.SetErr(ptyWriter)
30+
31+
// Setup cleanup to close PTY when test completes
32+
t.Cleanup(func() {
33+
if err := xp.Close(); err != nil {
34+
t.Logf("Warning: failed to close xpty: %v", err)
35+
}
36+
})
37+
38+
return xp
39+
}
40+
1541
func TestParseAgentType(t *testing.T) {
1642
tests := []struct {
1743
firstArg string
@@ -141,17 +167,16 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) {
141167
{"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }},
142168
{"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }},
143169
{"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }},
170+
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost:3284", "localhost:3000", "localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
144171
}
145172

146173
for _, tt := range tests {
147174
t.Run(tt.name, func(t *testing.T) {
148175
isolateViper(t)
149176
serverCmd := CreateServerCmd()
150-
cmd := &cobra.Command{}
151-
cmd.AddCommand(serverCmd)
152-
153-
// Execute with no args to get defaults
154-
serverCmd.SetArgs([]string{"--help"}) // Use help to avoid actual execution
177+
setupCommandWithPTY(t, serverCmd)
178+
// Execute with --exit to get defaults
179+
serverCmd.SetArgs([]string{"--exit", "dummy-command"})
155180
if err := serverCmd.Execute(); err != nil {
156181
t.Fatalf("Failed to execute server command: %v", err)
157182
}
@@ -175,6 +200,7 @@ func TestServerCmd_AllEnvVars(t *testing.T) {
175200
{"AGENTAPI_CHAT_BASE_PATH", "AGENTAPI_CHAT_BASE_PATH", "/api", "/api", func() any { return viper.GetString(FlagChatBasePath) }},
176201
{"AGENTAPI_TERM_WIDTH", "AGENTAPI_TERM_WIDTH", "120", uint16(120), func() any { return viper.GetUint16(FlagTermWidth) }},
177202
{"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }},
203+
{"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost:3284 localhost:3285", []string{"localhost:3284", "localhost:3285"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
178204
}
179205

180206
for _, tt := range tests {
@@ -183,10 +209,8 @@ func TestServerCmd_AllEnvVars(t *testing.T) {
183209
t.Setenv(tt.envVar, tt.envValue)
184210

185211
serverCmd := CreateServerCmd()
186-
cmd := &cobra.Command{}
187-
cmd.AddCommand(serverCmd)
188-
189-
serverCmd.SetArgs([]string{"--help"})
212+
setupCommandWithPTY(t, serverCmd)
213+
serverCmd.SetArgs([]string{"--exit", "dummy-command"})
190214
if err := serverCmd.Execute(); err != nil {
191215
t.Fatalf("Failed to execute server command: %v", err)
192216
}
@@ -254,9 +278,9 @@ func TestServerCmd_ArgsPrecedenceOverEnv(t *testing.T) {
254278
isolateViper(t)
255279
t.Setenv(tt.envVar, tt.envValue)
256280

257-
// Mock execution to test arg parsing without running server
258-
args := append(tt.args, "--help")
281+
args := append(tt.args, "--exit", "dummy-command")
259282
serverCmd := CreateServerCmd()
283+
setupCommandWithPTY(t, serverCmd)
260284
serverCmd.SetArgs(args)
261285
if err := serverCmd.Execute(); err != nil {
262286
t.Fatalf("Failed to execute server command: %v", err)
@@ -277,7 +301,8 @@ func TestMixed_ConfigurationScenarios(t *testing.T) {
277301

278302
// Set some CLI args
279303
serverCmd := CreateServerCmd()
280-
serverCmd.SetArgs([]string{"--port", "9999", "--print-openapi", "--help"})
304+
setupCommandWithPTY(t, serverCmd)
305+
serverCmd.SetArgs([]string{"--port", "9999", "--print-openapi", "--exit", "dummy-command"})
281306
if err := serverCmd.Execute(); err != nil {
282307
t.Fatalf("Failed to execute server command: %v", err)
283308
}
@@ -291,3 +316,118 @@ func TestMixed_ConfigurationScenarios(t *testing.T) {
291316
assert.Equal(t, uint16(1000), viper.GetUint16(FlagTermHeight)) // default
292317
})
293318
}
319+
320+
func TestServerCmd_AllowedHosts(t *testing.T) {
321+
tests := []struct {
322+
name string
323+
env map[string]string
324+
args []string
325+
expectedErr string
326+
expected []string // only checked if expectedErr is empty
327+
}{
328+
// Environment variable scenarios (space-separated format)
329+
{
330+
name: "env: single valid host",
331+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"},
332+
args: []string{},
333+
expected: []string{"localhost:3284"},
334+
},
335+
{
336+
name: "env: multiple valid hosts space-separated",
337+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284 example.com 192.168.1.1:8080"},
338+
args: []string{},
339+
expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"},
340+
},
341+
{
342+
name: "env: host with tab",
343+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284\texample.com"},
344+
args: []string{},
345+
expected: []string{"localhost:3284", "example.com"},
346+
},
347+
{
348+
name: "env: host with comma (invalid)",
349+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284,example.com"},
350+
args: []string{},
351+
expectedErr: "contains comma characters",
352+
},
353+
354+
// CLI flag scenarios (comma-separated format)
355+
{
356+
name: "flag: single valid host",
357+
args: []string{"--allowed-hosts", "localhost:3284"},
358+
expected: []string{"localhost:3284"},
359+
},
360+
{
361+
name: "flag: multiple valid hosts comma-separated",
362+
args: []string{"--allowed-hosts", "localhost:3284,example.com,192.168.1.1:8080"},
363+
expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"},
364+
},
365+
{
366+
name: "flag: multiple valid hosts with multiple flags",
367+
args: []string{"--allowed-hosts", "localhost:3284", "--allowed-hosts", "example.com"},
368+
expected: []string{"localhost:3284", "example.com"},
369+
},
370+
{
371+
name: "flag: host with newline",
372+
args: []string{"--allowed-hosts", "localhost:3284\n"},
373+
expected: []string{"localhost:3284"},
374+
},
375+
{
376+
name: "flag: host with space in comma-separated list (invalid)",
377+
args: []string{"--allowed-hosts", "localhost:3284,example .com"},
378+
expectedErr: "contains whitespace characters",
379+
},
380+
381+
// Mixed scenarios (env + flag precedence)
382+
{
383+
name: "mixed: flag overrides env",
384+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"},
385+
args: []string{"--allowed-hosts", "override.com"},
386+
expected: []string{"override.com"},
387+
},
388+
{
389+
name: "mixed: flag overrides env but flag is invalid",
390+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"},
391+
args: []string{"--allowed-hosts", "invalid .com"},
392+
expectedErr: "contains whitespace characters",
393+
},
394+
395+
// Empty hosts are not allowed
396+
{
397+
name: "empty host",
398+
args: []string{"--allowed-hosts", ""},
399+
expectedErr: "allowed hosts must not be empty",
400+
},
401+
402+
// Default behavior
403+
{
404+
name: "default hosts when neither env nor flag provided",
405+
args: []string{},
406+
expected: []string{"localhost:3284", "localhost:3000", "localhost:3001"},
407+
},
408+
}
409+
410+
for _, tt := range tests {
411+
t.Run(tt.name, func(t *testing.T) {
412+
isolateViper(t)
413+
414+
// Set environment variables if provided
415+
for key, value := range tt.env {
416+
t.Setenv(key, value)
417+
}
418+
419+
serverCmd := CreateServerCmd()
420+
setupCommandWithPTY(t, serverCmd)
421+
serverCmd.SetArgs(append(tt.args, "--exit", "dummy-command"))
422+
err := serverCmd.Execute()
423+
424+
if tt.expectedErr != "" {
425+
require.Error(t, err)
426+
assert.Contains(t, err.Error(), tt.expectedErr)
427+
} else {
428+
require.NoError(t, err)
429+
assert.Equal(t, tt.expected, viper.GetStringSlice(FlagAllowedHosts))
430+
}
431+
})
432+
}
433+
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ require (
5454
github.com/rogpeppe/go-internal v1.14.1 // indirect
5555
github.com/spf13/afero v1.14.0
5656
github.com/spf13/pflag v1.0.6 // indirect
57+
github.com/unrolled/secure v1.17.0
5758
golang.org/x/sync v0.12.0 // indirect
5859
golang.org/x/sys v0.31.0 // indirect
5960
golang.org/x/text v0.23.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8
113113
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
114114
github.com/tmaxmax/go-sse v0.10.0 h1:j9F93WB4Hxt8wUf6oGffMm4dutALvUPoDDxfuDQOSqA=
115115
github.com/tmaxmax/go-sse v0.10.0/go.mod h1:u/2kZQR1tyngo1lKaNCj1mJmhXGZWS1Zs5yiSOD+Eg8=
116+
github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU=
117+
github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40=
116118
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
117119
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
118120
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=

0 commit comments

Comments
 (0)