Skip to content

Commit d2400b9

Browse files
committed
fix: strip port from host
1 parent f1b18a2 commit d2400b9

File tree

7 files changed

+217
-42
lines changed

7 files changed

+217
-42
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ There are 4 endpoints:
8080

8181
#### Allowed hosts
8282

83-
By default, the server only allows requests with the host header set to localhost:3284. If you'd like to host AgentAPI elsewhere, you can change this by using the `AGENTAPI_ALLOWED_HOSTS` environment variable or the `--allowed-hosts` flag.
83+
By default, the server only allows requests with the host header set to `localhost`. If you'd like to host AgentAPI elsewhere, you can change this by using the `AGENTAPI_ALLOWED_HOSTS` environment variable or the `--allowed-hosts` flag. Hosts must be hostnames only (no ports); the server ignores the port portion of incoming requests when authorizing.
8484

8585
To allow requests from any host, use `*` as the allowed host.
8686

cmd/server/server.go

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"log/slog"
88
"net/http"
9+
"net/url"
910
"os"
1011
"sort"
1112
"strings"
@@ -59,10 +60,63 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) {
5960
return AgentTypeCustom, nil
6061
}
6162

62-
// Validate allowed hosts or origins don't contain whitespace or commas.
63+
// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
6364
// Viper/Cobra use different separators (space for env vars, comma for flags),
6465
// so these characters likely indicate user error.
65-
func validateAllowedHostsOrOrigins(input []string) error {
66+
func validateAllowedHosts(input []string) error {
67+
if len(input) == 0 {
68+
return fmt.Errorf("the list must not be empty")
69+
}
70+
// First pass: whitespace & comma checks (surface these errors first)
71+
for _, item := range input {
72+
for _, r := range item {
73+
if unicode.IsSpace(r) {
74+
return fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item)
75+
}
76+
}
77+
if strings.Contains(item, ",") {
78+
return fmt.Errorf("'%s' contains comma characters, which are not allowed", item)
79+
}
80+
}
81+
// Second pass: scheme check
82+
for _, item := range input {
83+
if strings.Contains(item, "http://") || strings.Contains(item, "https://") {
84+
return fmt.Errorf("'%s' must not include http:// or https://", item)
85+
}
86+
}
87+
// Third pass: port check (but allow IPv6 literals without ports)
88+
for _, item := range input {
89+
trimmed := strings.TrimSpace(item)
90+
colonCount := strings.Count(trimmed, ":")
91+
// If bracketed, rely on url.Parse to detect a port in "]:<port>" form.
92+
if strings.HasPrefix(trimmed, "[") {
93+
if u, err := url.Parse("http://" + trimmed); err == nil {
94+
if u.Port() != "" {
95+
return fmt.Errorf("'%s' must not include a port", item)
96+
}
97+
}
98+
continue
99+
}
100+
// Unbracketed IPv6: multiple colons and no brackets; treat as valid (no ports allowed here)
101+
if colonCount >= 2 {
102+
continue
103+
}
104+
// IPv4 or hostname: if URL parsing finds a port or there's a single colon, it's invalid
105+
if u, err := url.Parse("http://" + trimmed); err == nil {
106+
if u.Port() != "" {
107+
return fmt.Errorf("'%s' must not include a port", item)
108+
}
109+
}
110+
if colonCount == 1 {
111+
return fmt.Errorf("'%s' must not include a port", item)
112+
}
113+
}
114+
return nil
115+
}
116+
117+
// Validate allowed origins don't contain whitespace or commas.
118+
// Origins must include a scheme, validated later by the HTTP layer.
119+
func validateAllowedOrigins(input []string) error {
66120
if len(input) == 0 {
67121
return fmt.Errorf("the list must not be empty")
68122
}
@@ -195,11 +249,11 @@ func CreateServerCmd() *cobra.Command {
195249
Args: cobra.MinimumNArgs(1),
196250
PreRunE: func(cmd *cobra.Command, args []string) error {
197251
allowedHosts := viper.GetStringSlice(FlagAllowedHosts)
198-
if err := validateAllowedHostsOrOrigins(allowedHosts); err != nil {
252+
if err := validateAllowedHosts(allowedHosts); err != nil {
199253
return xerrors.Errorf("failed to validate allowed hosts: %w", err)
200254
}
201255
allowedOrigins := viper.GetStringSlice(FlagAllowedOrigins)
202-
if err := validateAllowedHostsOrOrigins(allowedOrigins); err != nil {
256+
if err := validateAllowedOrigins(allowedOrigins); err != nil {
203257
return xerrors.Errorf("failed to validate allowed origins: %w", err)
204258
}
205259
return nil
@@ -225,8 +279,8 @@ func CreateServerCmd() *cobra.Command {
225279
{FlagChatBasePath, "c", "/chat", "Base path for assets and routes used in the static files of the chat interface", "string"},
226280
{FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"},
227281
{FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"},
228-
// localhost:3284 is the default host for the server
229-
{FlagAllowedHosts, "a", []string{"localhost:3284"}, "HTTP allowed hosts. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
282+
// localhost is the default host for the server. Port is ignored during matching.
283+
{FlagAllowedHosts, "a", []string{"localhost"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
230284
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
231285
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
232286
}

cmd/server/server_test.go

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) {
155155
{"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }},
156156
{"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }},
157157
{"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }},
158-
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost:3284"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
158+
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
159159
{"allowed-origins default", FlagAllowedOrigins, []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }},
160160
}
161161

@@ -189,7 +189,7 @@ func TestServerCmd_AllEnvVars(t *testing.T) {
189189
{"AGENTAPI_CHAT_BASE_PATH", "AGENTAPI_CHAT_BASE_PATH", "/api", "/api", func() any { return viper.GetString(FlagChatBasePath) }},
190190
{"AGENTAPI_TERM_WIDTH", "AGENTAPI_TERM_WIDTH", "120", uint16(120), func() any { return viper.GetUint16(FlagTermWidth) }},
191191
{"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }},
192-
{"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost:3284 localhost:3285", []string{"localhost:3284", "localhost:3285"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
192+
{"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost example.com", []string{"localhost", "example.com"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
193193
{"AGENTAPI_ALLOWED_ORIGINS", "AGENTAPI_ALLOWED_ORIGINS", "https://example.com http://localhost:3000", []string{"https://example.com", "http://localhost:3000"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }},
194194
}
195195

@@ -325,66 +325,110 @@ func TestServerCmd_AllowedHosts(t *testing.T) {
325325
// Environment variable scenarios (space-separated format)
326326
{
327327
name: "env: single valid host",
328-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"},
328+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"},
329329
args: []string{},
330-
expected: []string{"localhost:3284"},
330+
expected: []string{"localhost"},
331331
},
332332
{
333333
name: "env: multiple valid hosts space-separated",
334-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284 example.com 192.168.1.1:8080"},
334+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost example.com 192.168.1.1"},
335335
args: []string{},
336-
expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"},
336+
expected: []string{"localhost", "example.com", "192.168.1.1"},
337337
},
338338
{
339339
name: "env: host with tab",
340-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284\texample.com"},
340+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost\texample.com"},
341341
args: []string{},
342-
expected: []string{"localhost:3284", "example.com"},
342+
expected: []string{"localhost", "example.com"},
343343
},
344344
{
345345
name: "env: host with comma (invalid)",
346346
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284,example.com"},
347347
args: []string{},
348348
expectedErr: "contains comma characters",
349349
},
350+
{
351+
name: "env: host with port (invalid)",
352+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"},
353+
args: []string{},
354+
expectedErr: "must not include a port",
355+
},
356+
{
357+
name: "env: ipv6 literal",
358+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "2001:db8::1"},
359+
args: []string{},
360+
expected: []string{"2001:db8::1"},
361+
},
362+
{
363+
name: "env: ipv6 bracketed literal",
364+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]"},
365+
args: []string{},
366+
expected: []string{"[2001:db8::1]"},
367+
},
368+
{
369+
name: "env: ipv6 with port (invalid)",
370+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]:443"},
371+
args: []string{},
372+
expectedErr: "must not include a port",
373+
},
350374

351375
// CLI flag scenarios (comma-separated format)
352376
{
353377
name: "flag: single valid host",
354-
args: []string{"--allowed-hosts", "localhost:3284"},
355-
expected: []string{"localhost:3284"},
378+
args: []string{"--allowed-hosts", "localhost"},
379+
expected: []string{"localhost"},
356380
},
357381
{
358382
name: "flag: multiple valid hosts comma-separated",
359-
args: []string{"--allowed-hosts", "localhost:3284,example.com,192.168.1.1:8080"},
360-
expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"},
383+
args: []string{"--allowed-hosts", "localhost,example.com,192.168.1.1"},
384+
expected: []string{"localhost", "example.com", "192.168.1.1"},
361385
},
362386
{
363387
name: "flag: multiple valid hosts with multiple flags",
364-
args: []string{"--allowed-hosts", "localhost:3284", "--allowed-hosts", "example.com"},
365-
expected: []string{"localhost:3284", "example.com"},
388+
args: []string{"--allowed-hosts", "localhost", "--allowed-hosts", "example.com"},
389+
expected: []string{"localhost", "example.com"},
366390
},
367391
{
368392
name: "flag: host with newline",
369-
args: []string{"--allowed-hosts", "localhost:3284\n"},
370-
expected: []string{"localhost:3284"},
393+
args: []string{"--allowed-hosts", "localhost\n"},
394+
expected: []string{"localhost"},
371395
},
372396
{
373397
name: "flag: host with space in comma-separated list (invalid)",
374398
args: []string{"--allowed-hosts", "localhost:3284,example .com"},
375399
expectedErr: "contains whitespace characters",
376400
},
401+
{
402+
name: "flag: host with port (invalid)",
403+
args: []string{"--allowed-hosts", "localhost:3284"},
404+
expectedErr: "must not include a port",
405+
},
406+
{
407+
name: "flag: ipv6 literal",
408+
args: []string{"--allowed-hosts", "2001:db8::1"},
409+
expected: []string{"2001:db8::1"},
410+
},
411+
{
412+
name: "flag: ipv6 bracketed literal",
413+
args: []string{"--allowed-hosts", "[2001:db8::1]"},
414+
expected: []string{"[2001:db8::1]"},
415+
},
416+
{
417+
name: "flag: ipv6 with port (invalid)",
418+
args: []string{"--allowed-hosts", "[2001:db8::1]:443"},
419+
expectedErr: "must not include a port",
420+
},
377421

378422
// Mixed scenarios (env + flag precedence)
379423
{
380424
name: "mixed: flag overrides env",
381-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"},
425+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"},
382426
args: []string{"--allowed-hosts", "override.com"},
383427
expected: []string{"override.com"},
384428
},
385429
{
386430
name: "mixed: flag overrides env but flag is invalid",
387-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"},
431+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"},
388432
args: []string{"--allowed-hosts", "invalid .com"},
389433
expectedErr: "contains whitespace characters",
390434
},
@@ -400,7 +444,7 @@ func TestServerCmd_AllowedHosts(t *testing.T) {
400444
{
401445
name: "default hosts when neither env nor flag provided",
402446
args: []string{},
403-
expected: []string{"localhost:3284"},
447+
expected: []string{"localhost"},
404448
},
405449
}
406450

go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ require (
5454
github.com/rogpeppe/go-internal v1.14.1 // indirect
5555
github.com/spf13/afero v1.14.0
5656
github.com/spf13/pflag v1.0.6 // indirect
57-
github.com/unrolled/secure v1.17.0
5857
golang.org/x/sync v0.12.0 // indirect
5958
golang.org/x/sys v0.31.0 // indirect
6059
golang.org/x/text v0.23.0 // indirect

go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8
113113
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
114114
github.com/tmaxmax/go-sse v0.10.0 h1:j9F93WB4Hxt8wUf6oGffMm4dutALvUPoDDxfuDQOSqA=
115115
github.com/tmaxmax/go-sse v0.10.0/go.mod h1:u/2kZQR1tyngo1lKaNCj1mJmhXGZWS1Zs5yiSOD+Eg8=
116-
github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU=
117-
github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40=
118116
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
119117
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
120118
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=

lib/httpapi/server.go

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"log/slog"
8+
"net"
89
"net/http"
910
"net/url"
1011
"slices"
@@ -21,7 +22,6 @@ import (
2122
"github.com/danielgtaylor/huma/v2/sse"
2223
"github.com/go-chi/chi/v5"
2324
"github.com/go-chi/cors"
24-
"github.com/unrolled/secure"
2525
"golang.org/x/xerrors"
2626
)
2727

@@ -82,7 +82,32 @@ func parseAllowedHosts(hosts []string) ([]string, error) {
8282
return nil, xerrors.Errorf("host must not contain http:// or https://: %q", host)
8383
}
8484
}
85-
return hosts, nil
85+
// Normalize hosts to bare hostnames/IPs by stripping any port and brackets.
86+
// This ensures allowed entries match the Host header hostname only.
87+
normalized := make([]string, 0, len(hosts))
88+
for _, raw := range hosts {
89+
h := strings.TrimSpace(raw)
90+
// If it's an IPv6 literal (possibly bracketed) without an obvious port, keep the literal.
91+
unbracketed := strings.Trim(h, "[]")
92+
if ip := net.ParseIP(unbracketed); ip != nil {
93+
// It's an IP literal; use the bare form without brackets.
94+
normalized = append(normalized, unbracketed)
95+
continue
96+
}
97+
// If likely host:port (single colon) or bracketed host, use url.Parse to extract hostname.
98+
if strings.Count(h, ":") == 1 || (strings.HasPrefix(h, "[") && strings.Contains(h, "]")) {
99+
if u, err := url.Parse("http://" + h); err == nil {
100+
hn := u.Hostname()
101+
if hn != "" {
102+
normalized = append(normalized, hn)
103+
continue
104+
}
105+
}
106+
}
107+
// Fallback: use as-is (e.g., hostname without port)
108+
normalized = append(normalized, h)
109+
}
110+
return normalized, nil
86111
}
87112

88113
func parseAllowedOrigins(origins []string) ([]string, error) {
@@ -116,14 +141,11 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
116141
}
117142
logger.Info(fmt.Sprintf("Allowed origins: %s", strings.Join(allowedOrigins, ", ")))
118143

119-
secureMiddleware := secure.New(secure.Options{
120-
AllowedHosts: allowedHosts,
121-
})
144+
// Enforce allowed hosts in a custom middleware that ignores the port during matching.
122145
badHostHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
123146
http.Error(w, "Invalid host header. Allowed hosts: "+strings.Join(allowedHosts, ", "), http.StatusBadRequest)
124147
})
125-
secureMiddleware.SetBadHostHandler(badHostHandler)
126-
router.Use(secureMiddleware.Handler)
148+
router.Use(hostAuthorizationMiddleware(allowedHosts, badHostHandler))
127149

128150
corsMiddleware := cors.New(cors.Options{
129151
AllowedOrigins: allowedOrigins,
@@ -174,6 +196,39 @@ func (s *Server) Handler() http.Handler {
174196
return s.router
175197
}
176198

199+
// hostAuthorizationMiddleware enforces that the request Host header matches one of the allowed
200+
// hosts, ignoring any port in the comparison. If allowedHosts is empty, all hosts are allowed.
201+
// Always uses url.Parse("http://" + r.Host) to robustly extract the hostname (handles IPv6).
202+
func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Handler) func(next http.Handler) http.Handler {
203+
// Copy for safety; also build a map for O(1) lookups with case-insensitive keys.
204+
allowed := make(map[string]struct{}, len(allowedHosts))
205+
for _, h := range allowedHosts {
206+
allowed[strings.ToLower(h)] = struct{}{}
207+
}
208+
return func(next http.Handler) http.Handler {
209+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
210+
if len(allowedHosts) == 0 { // wildcard semantics: allow all
211+
next.ServeHTTP(w, r)
212+
return
213+
}
214+
// Extract hostname from the Host header using url.Parse; ignore any port.
215+
hostHeader := r.Host
216+
if hostHeader == "" {
217+
badHostHandler.ServeHTTP(w, r)
218+
return
219+
}
220+
if u, err := url.Parse("http://" + hostHeader); err == nil {
221+
hostname := u.Hostname()
222+
if _, ok := allowed[strings.ToLower(hostname)]; ok {
223+
next.ServeHTTP(w, r)
224+
return
225+
}
226+
}
227+
badHostHandler.ServeHTTP(w, r)
228+
})
229+
}
230+
}
231+
177232
func (s *Server) StartSnapshotLoop(ctx context.Context) {
178233
s.conversation.StartSnapshotLoop(ctx)
179234
go func() {

0 commit comments

Comments
 (0)