Skip to content

feat(http): optional X-Forwarded-Host support for host authorization #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: hugodutka/allowed-hosts
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ agentapi server --allowed-hosts 'example.com,example.org' -- claude
AGENTAPI_ALLOWED_HOSTS='example.com example.org' agentapi server -- claude
```

If you're running behind a trusted reverse proxy that sets the `X-Forwarded-Host` header, you can opt in to using that header for host authorization with `--use-x-forwarded-host` (or `AGENTAPI_USE_X_FORWARDED_HOST=true`). When enabled, the server prefers the first `X-Forwarded-Host` value, and matches it against the allowed host list. Leave this disabled unless your deployment terminates at a trusted proxy.

#### Allowed origins

By default, the server allows CORS requests from `http://localhost:3284`, `http://localhost:3000`, and `http://localhost:3001`. If you'd like to change which origins can make cross-origin requests to AgentAPI, you can change this by using the `AGENTAPI_ALLOWED_ORIGINS` environment variable or the `--allowed-origins` flag.
Expand Down
33 changes: 18 additions & 15 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
}
port := viper.GetInt(FlagPort)
srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
AgentType: agentType,
Process: process,
Port: port,
ChatBasePath: viper.GetString(FlagChatBasePath),
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
AgentType: agentType,
Process: process,
Port: port,
ChatBasePath: viper.GetString(FlagChatBasePath),
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
UseXForwardedHost: viper.GetBool(FlagUseXForwardedHost),
})
if err != nil {
return xerrors.Errorf("failed to create server: %w", err)
Expand Down Expand Up @@ -230,15 +231,16 @@ type flagSpec struct {
}

const (
FlagType = "type"
FlagPort = "port"
FlagPrintOpenAPI = "print-openapi"
FlagChatBasePath = "chat-base-path"
FlagTermWidth = "term-width"
FlagTermHeight = "term-height"
FlagAllowedHosts = "allowed-hosts"
FlagAllowedOrigins = "allowed-origins"
FlagExit = "exit"
FlagType = "type"
FlagPort = "port"
FlagPrintOpenAPI = "print-openapi"
FlagChatBasePath = "chat-base-path"
FlagTermWidth = "term-width"
FlagTermHeight = "term-height"
FlagAllowedHosts = "allowed-hosts"
FlagAllowedOrigins = "allowed-origins"
FlagUseXForwardedHost = "use-x-forwarded-host"
FlagExit = "exit"
)

func CreateServerCmd() *cobra.Command {
Expand Down Expand Up @@ -283,6 +285,7 @@ func CreateServerCmd() *cobra.Command {
{FlagAllowedHosts, "a", []string{"localhost"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
{FlagUseXForwardedHost, "", false, "Use X-Forwarded-Host header for host authorization (behind trusted proxies)", "bool"},
}

for _, spec := range flagSpecs {
Expand Down
9 changes: 9 additions & 0 deletions cmd/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) {
{"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }},
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
{"allowed-origins default", FlagAllowedOrigins, []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }},
{"use-x-forwarded-host default", FlagUseXForwardedHost, false, func() any { return viper.GetBool(FlagUseXForwardedHost) }},
}

for _, tt := range tests {
Expand Down Expand Up @@ -191,6 +192,7 @@ func TestServerCmd_AllEnvVars(t *testing.T) {
{"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }},
{"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost example.com", []string{"localhost", "example.com"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
{"AGENTAPI_ALLOWED_ORIGINS", "AGENTAPI_ALLOWED_ORIGINS", "https://example.com http://localhost:3000", []string{"https://example.com", "http://localhost:3000"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }},
{"AGENTAPI_USE_X_FORWARDED_HOST", "AGENTAPI_USE_X_FORWARDED_HOST", "true", true, func() any { return viper.GetBool(FlagUseXForwardedHost) }},
}

for _, tt := range tests {
Expand Down Expand Up @@ -268,6 +270,13 @@ func TestServerCmd_ArgsPrecedenceOverEnv(t *testing.T) {
[]string{"https://cli-example.com"},
func() any { return viper.GetStringSlice(FlagAllowedOrigins) },
},
{
"use-x-forwarded-host: CLI overrides env",
"AGENTAPI_USE_X_FORWARDED_HOST", "false",
[]string{"--use-x-forwarded-host"},
true,
func() any { return viper.GetBool(FlagUseXForwardedHost) },
},
}

for _, tt := range tests {
Expand Down
39 changes: 26 additions & 13 deletions lib/httpapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ func (s *Server) GetOpenAPI() string {
const snapshotInterval = 25 * time.Millisecond

type ServerConfig struct {
AgentType mf.AgentType
Process *termexec.Process
Port int
ChatBasePath string
AllowedHosts []string
AllowedOrigins []string
AgentType mf.AgentType
Process *termexec.Process
Port int
ChatBasePath string
AllowedHosts []string
AllowedOrigins []string
UseXForwardedHost bool
}

func parseAllowedHosts(hosts []string) ([]string, error) {
Expand Down Expand Up @@ -145,7 +146,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
badHostHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Invalid host header. Allowed hosts: "+strings.Join(allowedHosts, ", "), http.StatusBadRequest)
})
router.Use(hostAuthorizationMiddleware(allowedHosts, badHostHandler))
router.Use(hostAuthorizationMiddleware(allowedHosts, config.UseXForwardedHost, badHostHandler))

corsMiddleware := cors.New(cors.Options{
AllowedOrigins: allowedOrigins,
Expand Down Expand Up @@ -198,8 +199,9 @@ func (s *Server) Handler() http.Handler {

// hostAuthorizationMiddleware enforces that the request Host header matches one of the allowed
// hosts, ignoring any port in the comparison. If allowedHosts is empty, all hosts are allowed.
// Always uses url.Parse("http://" + r.Host) to robustly extract the hostname (handles IPv6).
func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Handler) func(next http.Handler) http.Handler {
// If useXForwardedHost is true and the X-Forwarded-Host header is present, that header is used
// as the source of host. Hostname is extracted via url.Parse to handle IPv6 and strip ports.
func hostAuthorizationMiddleware(allowedHosts []string, useXForwardedHost bool, badHostHandler http.Handler) func(next http.Handler) http.Handler {
// Copy for safety; also build a map for O(1) lookups with case-insensitive keys.
allowed := make(map[string]struct{}, len(allowedHosts))
for _, h := range allowedHosts {
Expand All @@ -211,13 +213,24 @@ func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Hand
next.ServeHTTP(w, r)
return
}
// Extract hostname from the Host header using url.Parse; ignore any port.
hostHeader := r.Host
if hostHeader == "" {
// Choose header source
rawHost := r.Host
if useXForwardedHost {
if xfhs := r.Header.Values("X-Forwarded-Host"); len(xfhs) > 0 {
// Use the first value and trim anything after a comma
h := xfhs[0]
if idx := strings.IndexByte(h, ','); idx >= 0 {
h = h[:idx]
}
rawHost = strings.TrimSpace(h)
}
}
if rawHost == "" {
badHostHandler.ServeHTTP(w, r)
return
}
if u, err := url.Parse("http://" + hostHeader); err == nil {
// Extract hostname via url.Parse; ignore any port.
if u, err := url.Parse("http://" + rawHost); err == nil {
hostname := u.Hostname()
if _, ok := allowed[strings.ToLower(hostname)]; ok {
next.ServeHTTP(w, r)
Expand Down
91 changes: 91 additions & 0 deletions lib/httpapi/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,97 @@ func TestServer_AllowedHosts(t *testing.T) {
}
}

func TestServer_UseXForwardedHost(t *testing.T) {
cases := []struct {
name string
allowedHosts []string
useXForwardedHost bool
hostHeader string
xForwardedHostHeader string
expectedStatusCode int
expectedErrorMsg string
}{
{
name: "disabled flag ignores X-Forwarded-Host",
allowedHosts: []string{"app.example.com"},
useXForwardedHost: false,
hostHeader: "malicious.com",
xForwardedHostHeader: "app.example.com",
expectedStatusCode: http.StatusBadRequest,
expectedErrorMsg: "Invalid host header. Allowed hosts: app.example.com",
},
{
name: "enabled flag uses X-Forwarded-Host",
allowedHosts: []string{"app.example.com"},
useXForwardedHost: true,
hostHeader: "malicious.com",
xForwardedHostHeader: "app.example.com",
expectedStatusCode: http.StatusOK,
},
{
name: "enabled with port in X-Forwarded-Host",
allowedHosts: []string{"app.example.com"},
useXForwardedHost: true,
hostHeader: "malicious.com",
xForwardedHostHeader: "app.example.com:443",
expectedStatusCode: http.StatusOK,
},
{
name: "enabled with IPv6 literal in X-Forwarded-Host",
allowedHosts: []string{"2001:db8::1"},
useXForwardedHost: true,
hostHeader: "malicious.com",
xForwardedHostHeader: "[2001:db8::1]:8443",
expectedStatusCode: http.StatusOK,
},
{
name: "enabled with comma-separated X-Forwarded-Host takes first",
allowedHosts: []string{"first.example.com"},
useXForwardedHost: true,
hostHeader: "malicious.com",
xForwardedHostHeader: "first.example.com, other.example.com",
expectedStatusCode: http.StatusOK,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil)))
s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
AgentType: msgfmt.AgentTypeClaude,
Process: nil,
Port: 0,
ChatBasePath: "/chat",
AllowedHosts: tc.allowedHosts,
AllowedOrigins: []string{"https://example.com"}, // isolate
UseXForwardedHost: tc.useXForwardedHost,
})
require.NoError(t, err)
tsServer := httptest.NewServer(s.Handler())
t.Cleanup(tsServer.Close)

req, err := http.NewRequest("GET", tsServer.URL+"/status", nil)
require.NoError(t, err)
if tc.hostHeader != "" {
req.Host = tc.hostHeader
}
if tc.xForwardedHostHeader != "" {
req.Header.Set("X-Forwarded-Host", tc.xForwardedHostHeader)
}

resp, err := (&http.Client{}).Do(req)
require.NoError(t, err)
t.Cleanup(func() { _ = resp.Body.Close() })
require.Equal(t, tc.expectedStatusCode, resp.StatusCode)
if tc.expectedErrorMsg != "" {
b, _ := io.ReadAll(resp.Body)
require.Contains(t, string(b), tc.expectedErrorMsg)
}
})
}
}

func TestServer_CORSPreflightWithHosts(t *testing.T) {
cases := []struct {
name string
Expand Down