Skip to content

Commit 6fa56b6

Browse files
committed
feat(http): optional X-Forwarded-Host support for host authorization\n\nAdds --use-x-forwarded-host CLI flag and ServerConfig.UseXForwardedHost.\nWhen enabled, the middleware prefers X-Forwarded-Host (first value, before comma),\nparses and matches hostname ignoring port. Includes tests for enabled/disabled,\nports, IPv6, and comma-separated cases.\n\nCo-authored-by: hugodutka <[email protected]>
1 parent d2400b9 commit 6fa56b6

File tree

3 files changed

+135
-28
lines changed

3 files changed

+135
-28
lines changed

cmd/server/server.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er
171171
}
172172
port := viper.GetInt(FlagPort)
173173
srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
174-
AgentType: agentType,
175-
Process: process,
176-
Port: port,
177-
ChatBasePath: viper.GetString(FlagChatBasePath),
178-
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
179-
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
174+
AgentType: agentType,
175+
Process: process,
176+
Port: port,
177+
ChatBasePath: viper.GetString(FlagChatBasePath),
178+
AllowedHosts: viper.GetStringSlice(FlagAllowedHosts),
179+
AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins),
180+
UseXForwardedHost: viper.GetBool(FlagUseXForwardedHost),
180181
})
181182
if err != nil {
182183
return xerrors.Errorf("failed to create server: %w", err)
@@ -230,15 +231,16 @@ type flagSpec struct {
230231
}
231232

232233
const (
233-
FlagType = "type"
234-
FlagPort = "port"
235-
FlagPrintOpenAPI = "print-openapi"
236-
FlagChatBasePath = "chat-base-path"
237-
FlagTermWidth = "term-width"
238-
FlagTermHeight = "term-height"
239-
FlagAllowedHosts = "allowed-hosts"
240-
FlagAllowedOrigins = "allowed-origins"
241-
FlagExit = "exit"
234+
FlagType = "type"
235+
FlagPort = "port"
236+
FlagPrintOpenAPI = "print-openapi"
237+
FlagChatBasePath = "chat-base-path"
238+
FlagTermWidth = "term-width"
239+
FlagTermHeight = "term-height"
240+
FlagAllowedHosts = "allowed-hosts"
241+
FlagAllowedOrigins = "allowed-origins"
242+
FlagUseXForwardedHost = "use-x-forwarded-host"
243+
FlagExit = "exit"
242244
)
243245

244246
func CreateServerCmd() *cobra.Command {
@@ -283,6 +285,7 @@ func CreateServerCmd() *cobra.Command {
283285
{FlagAllowedHosts, "a", []string{"localhost"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
284286
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
285287
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
288+
{FlagUseXForwardedHost, "", false, "Use X-Forwarded-Host header for host authorization (behind trusted proxies)", "bool"},
286289
}
287290

288291
for _, spec := range flagSpecs {

lib/httpapi/server.go

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ func (s *Server) GetOpenAPI() string {
6262
const snapshotInterval = 25 * time.Millisecond
6363

6464
type ServerConfig struct {
65-
AgentType mf.AgentType
66-
Process *termexec.Process
67-
Port int
68-
ChatBasePath string
69-
AllowedHosts []string
70-
AllowedOrigins []string
65+
AgentType mf.AgentType
66+
Process *termexec.Process
67+
Port int
68+
ChatBasePath string
69+
AllowedHosts []string
70+
AllowedOrigins []string
71+
UseXForwardedHost bool
7172
}
7273

7374
func parseAllowedHosts(hosts []string) ([]string, error) {
@@ -145,7 +146,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
145146
badHostHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
146147
http.Error(w, "Invalid host header. Allowed hosts: "+strings.Join(allowedHosts, ", "), http.StatusBadRequest)
147148
})
148-
router.Use(hostAuthorizationMiddleware(allowedHosts, badHostHandler))
149+
router.Use(hostAuthorizationMiddleware(allowedHosts, config.UseXForwardedHost, badHostHandler))
149150

150151
corsMiddleware := cors.New(cors.Options{
151152
AllowedOrigins: allowedOrigins,
@@ -198,8 +199,9 @@ func (s *Server) Handler() http.Handler {
198199

199200
// hostAuthorizationMiddleware enforces that the request Host header matches one of the allowed
200201
// hosts, ignoring any port in the comparison. If allowedHosts is empty, all hosts are allowed.
201-
// Always uses url.Parse("http://" + r.Host) to robustly extract the hostname (handles IPv6).
202-
func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Handler) func(next http.Handler) http.Handler {
202+
// If useXForwardedHost is true and the X-Forwarded-Host header is present, that header is used
203+
// as the source of host. Hostname is extracted via url.Parse to handle IPv6 and strip ports.
204+
func hostAuthorizationMiddleware(allowedHosts []string, useXForwardedHost bool, badHostHandler http.Handler) func(next http.Handler) http.Handler {
203205
// Copy for safety; also build a map for O(1) lookups with case-insensitive keys.
204206
allowed := make(map[string]struct{}, len(allowedHosts))
205207
for _, h := range allowedHosts {
@@ -211,13 +213,24 @@ func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Hand
211213
next.ServeHTTP(w, r)
212214
return
213215
}
214-
// Extract hostname from the Host header using url.Parse; ignore any port.
215-
hostHeader := r.Host
216-
if hostHeader == "" {
216+
// Choose header source
217+
rawHost := r.Host
218+
if useXForwardedHost {
219+
if xfhs := r.Header.Values("X-Forwarded-Host"); len(xfhs) > 0 {
220+
// Use the first value and trim anything after a comma
221+
h := xfhs[0]
222+
if idx := strings.IndexByte(h, ','); idx >= 0 {
223+
h = h[:idx]
224+
}
225+
rawHost = strings.TrimSpace(h)
226+
}
227+
}
228+
if rawHost == "" {
217229
badHostHandler.ServeHTTP(w, r)
218230
return
219231
}
220-
if u, err := url.Parse("http://" + hostHeader); err == nil {
232+
// Extract hostname via url.Parse; ignore any port.
233+
if u, err := url.Parse("http://" + rawHost); err == nil {
221234
hostname := u.Hostname()
222235
if _, ok := allowed[strings.ToLower(hostname)]; ok {
223236
next.ServeHTTP(w, r)

lib/httpapi/server_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,97 @@ func TestServer_AllowedHosts(t *testing.T) {
241241
}
242242
}
243243

244+
func TestServer_UseXForwardedHost(t *testing.T) {
245+
cases := []struct {
246+
name string
247+
allowedHosts []string
248+
useXForwardedHost bool
249+
hostHeader string
250+
xForwardedHostHeader string
251+
expectedStatusCode int
252+
expectedErrorMsg string
253+
}{
254+
{
255+
name: "disabled flag ignores X-Forwarded-Host",
256+
allowedHosts: []string{"app.example.com"},
257+
useXForwardedHost: false,
258+
hostHeader: "malicious.com",
259+
xForwardedHostHeader: "app.example.com",
260+
expectedStatusCode: http.StatusBadRequest,
261+
expectedErrorMsg: "Invalid host header. Allowed hosts: app.example.com",
262+
},
263+
{
264+
name: "enabled flag uses X-Forwarded-Host",
265+
allowedHosts: []string{"app.example.com"},
266+
useXForwardedHost: true,
267+
hostHeader: "malicious.com",
268+
xForwardedHostHeader: "app.example.com",
269+
expectedStatusCode: http.StatusOK,
270+
},
271+
{
272+
name: "enabled with port in X-Forwarded-Host",
273+
allowedHosts: []string{"app.example.com"},
274+
useXForwardedHost: true,
275+
hostHeader: "malicious.com",
276+
xForwardedHostHeader: "app.example.com:443",
277+
expectedStatusCode: http.StatusOK,
278+
},
279+
{
280+
name: "enabled with IPv6 literal in X-Forwarded-Host",
281+
allowedHosts: []string{"2001:db8::1"},
282+
useXForwardedHost: true,
283+
hostHeader: "malicious.com",
284+
xForwardedHostHeader: "[2001:db8::1]:8443",
285+
expectedStatusCode: http.StatusOK,
286+
},
287+
{
288+
name: "enabled with comma-separated X-Forwarded-Host takes first",
289+
allowedHosts: []string{"first.example.com"},
290+
useXForwardedHost: true,
291+
hostHeader: "malicious.com",
292+
xForwardedHostHeader: "first.example.com, other.example.com",
293+
expectedStatusCode: http.StatusOK,
294+
},
295+
}
296+
297+
for _, tc := range cases {
298+
t.Run(tc.name, func(t *testing.T) {
299+
t.Parallel()
300+
ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil)))
301+
s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{
302+
AgentType: msgfmt.AgentTypeClaude,
303+
Process: nil,
304+
Port: 0,
305+
ChatBasePath: "/chat",
306+
AllowedHosts: tc.allowedHosts,
307+
AllowedOrigins: []string{"https://example.com"}, // isolate
308+
UseXForwardedHost: tc.useXForwardedHost,
309+
})
310+
require.NoError(t, err)
311+
tsServer := httptest.NewServer(s.Handler())
312+
t.Cleanup(tsServer.Close)
313+
314+
req, err := http.NewRequest("GET", tsServer.URL+"/status", nil)
315+
require.NoError(t, err)
316+
if tc.hostHeader != "" {
317+
req.Host = tc.hostHeader
318+
}
319+
if tc.xForwardedHostHeader != "" {
320+
req.Header.Set("X-Forwarded-Host", tc.xForwardedHostHeader)
321+
}
322+
323+
resp, err := (&http.Client{}).Do(req)
324+
require.NoError(t, err)
325+
t.Cleanup(func() { _ = resp.Body.Close() })
326+
require.Equal(t, tc.expectedStatusCode, resp.StatusCode)
327+
if tc.expectedErrorMsg != "" {
328+
b, _ := io.ReadAll(resp.Body)
329+
require.Contains(t, string(b), tc.expectedErrorMsg)
330+
}
331+
})
332+
}
333+
}
334+
244335
func TestServer_CORSPreflightWithHosts(t *testing.T) {
245336
cases := []struct {
246337
name string

0 commit comments

Comments
 (0)