Skip to content

Commit 37d0fd1

Browse files
committed
feat: allowed hosts
1 parent e783ff1 commit 37d0fd1

File tree

6 files changed

+422
-6
lines changed

6 files changed

+422
-6
lines changed

cmd/server/server.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"os"
1010
"sort"
1111
"strings"
12+
"unicode"
1213

1314
"github.com/spf13/cobra"
1415
"github.com/spf13/viper"
@@ -58,6 +59,26 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) {
5859
return AgentTypeCustom, nil
5960
}
6061

62+
// Validate allowed hosts don't contain whitespace or commas.
63+
// Viper/Cobra use different separators (space for env vars, comma for flags),
64+
// so these characters in origins likely indicate user error.
65+
func validateAllowedHosts(hosts []string) error {
66+
if len(hosts) == 0 {
67+
return fmt.Errorf("allowed hosts must not be empty")
68+
}
69+
for _, host := range hosts {
70+
for _, r := range host {
71+
if unicode.IsSpace(r) {
72+
return fmt.Errorf("host '%s' contains whitespace characters, which are not allowed", host)
73+
}
74+
if strings.Contains(host, ",") {
75+
return fmt.Errorf("host '%s' contains comma characters, which are not allowed", host)
76+
}
77+
}
78+
}
79+
return nil
80+
}
81+
6182
func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) error {
6283
agent := argsToPass[0]
6384
agentTypeValue := viper.GetString(FlagType)
@@ -95,12 +116,16 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
95116
}
96117
}
97118
port := viper.GetInt(FlagPort)
98-
srv := httpapi.NewServer(ctx, httpapi.ServerConfig{
119+
srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
99120
AgentType: agentType,
100121
Process: process,
101122
Port: port,
102123
ChatBasePath: viper.GetString(FlagChatBasePath),
124+
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
103125
})
126+
if err != nil {
127+
return xerrors.Errorf("failed to create server: %w", err)
128+
}
104129
if printOpenAPI {
105130
fmt.Println(srv.GetOpenAPI())
106131
return nil
@@ -156,6 +181,7 @@ const (
156181
FlagChatBasePath = "chat-base-path"
157182
FlagTermWidth = "term-width"
158183
FlagTermHeight = "term-height"
184+
FlagAllowedHosts = "allowed-hosts"
159185
)
160186

161187
func CreateServerCmd() *cobra.Command {
@@ -164,6 +190,13 @@ func CreateServerCmd() *cobra.Command {
164190
Short: "Run the server",
165191
Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")),
166192
Args: cobra.MinimumNArgs(1),
193+
PreRunE: func(cmd *cobra.Command, args []string) error {
194+
allowedOrigins := viper.GetStringSlice(FlagAllowedHosts)
195+
if err := validateAllowedHosts(allowedOrigins); err != nil {
196+
return err
197+
}
198+
return nil
199+
},
167200
Run: func(cmd *cobra.Command, args []string) {
168201
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
169202
ctx := logctx.WithLogger(context.Background(), logger)
@@ -181,6 +214,10 @@ func CreateServerCmd() *cobra.Command {
181214
{FlagChatBasePath, "c", "/chat", "Base path for assets and routes used in the static files of the chat interface", "string"},
182215
{FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"},
183216
{FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"},
217+
// localhost:3284 is the default port for the server
218+
// localhost:3000 is the default port for the chat interface during development
219+
// localhost:3001 is used during development for the chat interface if 3000 is already in use
220+
{FlagAllowedHosts, "a", []string{"localhost:3284", "localhost:3000", "localhost:3001"}, "HTTP allowed hosts. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
184221
}
185222

186223
for _, spec := range flagSpecs {
@@ -193,6 +230,8 @@ func CreateServerCmd() *cobra.Command {
193230
serverCmd.Flags().BoolP(spec.name, spec.shorthand, spec.defaultValue.(bool), spec.usage)
194231
case "uint16":
195232
serverCmd.Flags().Uint16P(spec.name, spec.shorthand, spec.defaultValue.(uint16), spec.usage)
233+
case "stringSlice":
234+
serverCmd.Flags().StringSliceP(spec.name, spec.shorthand, spec.defaultValue.([]string), spec.usage)
196235
default:
197236
panic(fmt.Sprintf("unknown flag type: %s", spec.flagType))
198237
}

cmd/server/server_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) {
141141
{"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }},
142142
{"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }},
143143
{"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }},
144+
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost:3284"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
144145
}
145146

146147
for _, tt := range tests {
@@ -175,6 +176,7 @@ func TestServerCmd_AllEnvVars(t *testing.T) {
175176
{"AGENTAPI_CHAT_BASE_PATH", "AGENTAPI_CHAT_BASE_PATH", "/api", "/api", func() any { return viper.GetString(FlagChatBasePath) }},
176177
{"AGENTAPI_TERM_WIDTH", "AGENTAPI_TERM_WIDTH", "120", uint16(120), func() any { return viper.GetUint16(FlagTermWidth) }},
177178
{"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }},
179+
{"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost:3284 localhost:3285", []string{"localhost:3284", "localhost:3285"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
178180
}
179181

180182
for _, tt := range tests {
@@ -291,3 +293,119 @@ func TestMixed_ConfigurationScenarios(t *testing.T) {
291293
assert.Equal(t, uint16(1000), viper.GetUint16(FlagTermHeight)) // default
292294
})
293295
}
296+
297+
func TestServerCmd_AllowedHosts(t *testing.T) {
298+
tests := []struct {
299+
name string
300+
env map[string]string
301+
args []string
302+
expectedErr string
303+
expected []string // only checked if expectedErr is empty
304+
}{
305+
// Environment variable scenarios (space-separated format)
306+
{
307+
name: "env: single valid host",
308+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"},
309+
args: []string{},
310+
expected: []string{"localhost:3284"},
311+
},
312+
{
313+
name: "env: multiple valid hosts space-separated",
314+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284 example.com 192.168.1.1:8080"},
315+
args: []string{},
316+
expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"},
317+
},
318+
{
319+
name: "env: host with tab",
320+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284\texample.com"},
321+
args: []string{},
322+
expected: []string{"localhost:3284", "example.com"},
323+
},
324+
{
325+
name: "env: host with comma (invalid)",
326+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284,example.com"},
327+
args: []string{},
328+
expectedErr: "contains comma characters",
329+
},
330+
331+
// CLI flag scenarios (comma-separated format)
332+
{
333+
name: "flag: single valid host",
334+
args: []string{"--allowed-hosts", "localhost:3284"},
335+
expected: []string{"localhost:3284"},
336+
},
337+
{
338+
name: "flag: multiple valid hosts comma-separated",
339+
args: []string{"--allowed-hosts", "localhost:3284,example.com,192.168.1.1:8080"},
340+
expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"},
341+
},
342+
{
343+
name: "flag: multiple valid hosts with multiple flags",
344+
args: []string{"--allowed-hosts", "localhost:3284", "--allowed-hosts", "example.com"},
345+
expected: []string{"localhost:3284", "example.com"},
346+
},
347+
{
348+
name: "flag: host with newline",
349+
args: []string{"--allowed-hosts", "localhost:3284\n"},
350+
expected: []string{"localhost:3284"},
351+
},
352+
{
353+
name: "flag: host with space in comma-separated list (invalid)",
354+
args: []string{"--allowed-hosts", "localhost:3284,example .com"},
355+
expectedErr: "contains whitespace characters",
356+
},
357+
358+
// Mixed scenarios (env + flag precedence)
359+
{
360+
name: "mixed: flag overrides env",
361+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"},
362+
args: []string{"--allowed-hosts", "override.com"},
363+
expected: []string{"override.com"},
364+
},
365+
{
366+
name: "mixed: flag overrides env but flag is invalid",
367+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"},
368+
args: []string{"--allowed-hosts", "invalid .com"},
369+
expectedErr: "contains whitespace characters",
370+
},
371+
372+
// Empty hosts are not allowed
373+
{
374+
name: "empty host",
375+
args: []string{"--allowed-hosts", ""},
376+
expectedErr: "allowed hosts must not be empty",
377+
},
378+
379+
// Default behavior
380+
{
381+
name: "default hosts when neither env nor flag provided",
382+
args: []string{},
383+
expected: []string{"localhost:3284", "localhost:3000", "localhost:3001"},
384+
},
385+
}
386+
387+
for _, tt := range tests {
388+
t.Run(tt.name, func(t *testing.T) {
389+
isolateViper(t)
390+
391+
// Set environment variables if provided
392+
for key, value := range tt.env {
393+
t.Setenv(key, value)
394+
}
395+
396+
serverCmd := CreateServerCmd()
397+
// --print-openapi acts as an agent command that immediately exits
398+
// use a 0 port to pick a random free port
399+
serverCmd.SetArgs(append(tt.args, "--port", "0", "--", "sh", "-c", "echo ok"))
400+
err := serverCmd.Execute()
401+
402+
if tt.expectedErr != "" {
403+
require.Error(t, err)
404+
assert.Contains(t, err.Error(), tt.expectedErr)
405+
} else {
406+
require.NoError(t, err)
407+
assert.Equal(t, tt.expected, viper.GetStringSlice(FlagAllowedHosts))
408+
}
409+
})
410+
}
411+
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ require (
5454
github.com/rogpeppe/go-internal v1.14.1 // indirect
5555
github.com/spf13/afero v1.14.0
5656
github.com/spf13/pflag v1.0.6 // indirect
57+
github.com/unrolled/secure v1.17.0
5758
golang.org/x/sync v0.12.0 // indirect
5859
golang.org/x/sys v0.31.0 // indirect
5960
golang.org/x/text v0.23.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8
113113
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
114114
github.com/tmaxmax/go-sse v0.10.0 h1:j9F93WB4Hxt8wUf6oGffMm4dutALvUPoDDxfuDQOSqA=
115115
github.com/tmaxmax/go-sse v0.10.0/go.mod h1:u/2kZQR1tyngo1lKaNCj1mJmhXGZWS1Zs5yiSOD+Eg8=
116+
github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU=
117+
github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40=
116118
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
117119
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
118120
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=

lib/httpapi/server.go

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"log/slog"
88
"net/http"
99
"net/url"
10+
"slices"
1011
"strings"
1112
"sync"
1213
"time"
@@ -20,6 +21,7 @@ import (
2021
"github.com/danielgtaylor/huma/v2/sse"
2122
"github.com/go-chi/chi/v5"
2223
"github.com/go-chi/cors"
24+
"github.com/unrolled/secure"
2325
"golang.org/x/xerrors"
2426
)
2527

@@ -64,14 +66,55 @@ type ServerConfig struct {
6466
Process *termexec.Process
6567
Port int
6668
ChatBasePath string
69+
AllowedHosts []string
70+
}
71+
72+
func parseAllowedHosts(hosts []string) ([]string, error) {
73+
if slices.Contains(hosts, "*") {
74+
return []string{}, nil
75+
}
76+
for _, host := range hosts {
77+
if strings.Contains(host, "*") {
78+
return nil, xerrors.Errorf("wildcard characters are not supported: %q", host)
79+
}
80+
if strings.Contains(host, "http://") || strings.Contains(host, "https://") {
81+
return nil, xerrors.Errorf("host must not contain http:// or https://: %q", host)
82+
}
83+
}
84+
return hosts, nil
85+
}
86+
87+
func hostsToOrigins(hosts []string) []string {
88+
if len(hosts) == 0 {
89+
return []string{"*"}
90+
}
91+
origins := []string{}
92+
for _, host := range hosts {
93+
origins = append(origins, "http://"+host)
94+
origins = append(origins, "https://"+host)
95+
}
96+
return origins
6797
}
6898

6999
// NewServer creates a new server instance
70-
func NewServer(ctx context.Context, config ServerConfig) *Server {
100+
func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
71101
router := chi.NewMux()
72102

103+
allowedHosts, err := parseAllowedHosts(config.AllowedHosts)
104+
if err != nil {
105+
return nil, xerrors.Errorf("failed to validate allowed hosts: %w", err)
106+
}
107+
secureMiddleware := secure.New(secure.Options{
108+
AllowedHosts: allowedHosts,
109+
})
110+
badHostHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
111+
http.Error(w, "Invalid host header. Allowed hosts: "+strings.Join(allowedHosts, ", "), http.StatusBadRequest)
112+
})
113+
secureMiddleware.SetBadHostHandler(badHostHandler)
114+
router.Use(secureMiddleware.Handler)
115+
73116
corsMiddleware := cors.New(cors.Options{
74-
AllowedOrigins: []string{"*"},
117+
AllowedOrigins: hostsToOrigins(allowedHosts),
75118
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
76119
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
77120
ExposedHeaders: []string{"Link"},
@@ -111,7 +154,7 @@ func NewServer(ctx context.Context, config ServerConfig) *Server {
111154
// Register API routes
112155
s.registerRoutes()
113156

114-
return s
157+
return s, nil
115158
}
116159

117160
// Handler returns the underlying chi.Router for testing purposes.

0 commit comments

Comments
 (0)