diff --git a/README.md b/README.md index 77c571f..c4c39bc 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ Control [Claude Code](https://github.com/anthropics/claude-code), [Goose](https: ![agentapi-chat](https://github.com/user-attachments/assets/57032c9f-4146-4b66-b219-09e38ab7690d) - You can use AgentAPI: - to build a unified chat interface for coding agents @@ -54,9 +53,6 @@ You can use AgentAPI: Run an HTTP server that lets you control an agent. If you'd like to start an agent with additional arguments, pass the full agent command after the `--` flag. -> [!NOTE] -> When using Codex, always specify the agent type explicitly (`agentapi server --type=codex -- codex`), or message formatting may break. - ```bash agentapi server -- claude --allowedTools "Bash(git*) Edit Replace" ``` @@ -68,6 +64,9 @@ agentapi server -- aider --model sonnet --api-key anthropic=sk-ant-apio3-XXX agentapi server -- goose ``` +> [!NOTE] +> When using Codex, always specify the agent type explicitly (`agentapi server --type=codex -- codex`), or message formatting may break. + An OpenAPI schema is available in [openapi.json](openapi.json). By default, the server runs on port 3284. Additionally, the server exposes the same OpenAPI schema at http://localhost:3284/openapi.json and the available endpoints in a documentation UI at http://localhost:3284/docs. @@ -79,6 +78,54 @@ There are 4 endpoints: - GET `/status` - returns the current status of the agent, either "stable" or "running" - GET `/events` - an SSE stream of events from the agent: message and status updates +#### Allowed hosts + +By default, the server only allows requests with the host header set to `localhost`. If you'd like to host AgentAPI elsewhere, you can change this by using the `AGENTAPI_ALLOWED_HOSTS` environment variable or the `--allowed-hosts` flag. Hosts must be hostnames only (no ports); the server ignores the port portion of incoming requests when authorizing. + +To allow requests from any host, use `*` as the allowed host. + +```bash +agentapi server --allowed-hosts '*' -- claude +``` + +To allow a specific host, use: + +```bash +agentapi server --allowed-hosts 'example.com' -- claude +``` + +To specify multiple hosts, use a comma-separated list when using the `--allowed-hosts` flag, or a space-separated list when using the `AGENTAPI_ALLOWED_HOSTS` environment variable. + +```bash +agentapi server --allowed-hosts 'example.com,example.org' -- claude +# or +AGENTAPI_ALLOWED_HOSTS='example.com example.org' agentapi server -- claude +``` + +#### Allowed origins + +By default, the server allows CORS requests from `http://localhost:3284`, `http://localhost:3000`, and `http://localhost:3001`. If you'd like to change which origins can make cross-origin requests to AgentAPI, you can change this by using the `AGENTAPI_ALLOWED_ORIGINS` environment variable or the `--allowed-origins` flag. + +To allow requests from any origin, use `*` as the allowed origin: + +```bash +agentapi server --allowed-origins '*' -- claude +``` + +To allow a specific origin, use: + +```bash +agentapi server --allowed-origins 'https://example.com' -- claude +``` + +To specify multiple origins, use a comma-separated list when using the `--allowed-origins` flag, or a space-separated list when using the `AGENTAPI_ALLOWED_ORIGINS` environment variable. Origins must include the protocol (`http://` or `https://`) and support wildcards (e.g., `https://*.example.com`): + +```bash +agentapi server --allowed-origins 'https://example.com,http://localhost:3000' -- claude +# or +AGENTAPI_ALLOWED_ORIGINS='https://example.com http://localhost:3000' agentapi server -- claude +``` + ### `agentapi attach` Attach to a running agent's terminal session. diff --git a/cmd/server/server.go b/cmd/server/server.go index 3313b51..b236532 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -95,12 +95,17 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } } port := viper.GetInt(FlagPort) - srv := httpapi.NewServer(ctx, httpapi.ServerConfig{ - AgentType: agentType, - Process: process, - Port: port, - ChatBasePath: viper.GetString(FlagChatBasePath), + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: agentType, + Process: process, + Port: port, + ChatBasePath: viper.GetString(FlagChatBasePath), + AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), + AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), }) + if err != nil { + return xerrors.Errorf("failed to create server: %w", err) + } if printOpenAPI { fmt.Println(srv.GetOpenAPI()) return nil @@ -150,12 +155,15 @@ type flagSpec struct { } const ( - FlagType = "type" - FlagPort = "port" - FlagPrintOpenAPI = "print-openapi" - FlagChatBasePath = "chat-base-path" - FlagTermWidth = "term-width" - FlagTermHeight = "term-height" + FlagType = "type" + FlagPort = "port" + FlagPrintOpenAPI = "print-openapi" + FlagChatBasePath = "chat-base-path" + FlagTermWidth = "term-width" + FlagTermHeight = "term-height" + FlagAllowedHosts = "allowed-hosts" + FlagAllowedOrigins = "allowed-origins" + FlagExit = "exit" ) func CreateServerCmd() *cobra.Command { @@ -165,6 +173,10 @@ func CreateServerCmd() *cobra.Command { Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")), Args: cobra.MinimumNArgs(1), Run: func(cmd *cobra.Command, args []string) { + // The --exit flag is used for testing validation of flags in the test suite + if viper.GetBool(FlagExit) { + return + } logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) ctx := logctx.WithLogger(context.Background(), logger) if err := runServer(ctx, logger, cmd.Flags().Args()); err != nil { @@ -181,6 +193,10 @@ func CreateServerCmd() *cobra.Command { {FlagChatBasePath, "c", "/chat", "Base path for assets and routes used in the static files of the chat interface", "string"}, {FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"}, {FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"}, + // localhost is the default host for the server. Port is ignored during matching. + {FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"}, + // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. + {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, } for _, spec := range flagSpecs { @@ -193,6 +209,8 @@ func CreateServerCmd() *cobra.Command { serverCmd.Flags().BoolP(spec.name, spec.shorthand, spec.defaultValue.(bool), spec.usage) case "uint16": serverCmd.Flags().Uint16P(spec.name, spec.shorthand, spec.defaultValue.(uint16), spec.usage) + case "stringSlice": + serverCmd.Flags().StringSliceP(spec.name, spec.shorthand, spec.defaultValue.([]string), spec.usage) default: panic(fmt.Sprintf("unknown flag type: %s", spec.flagType)) } @@ -201,6 +219,14 @@ func CreateServerCmd() *cobra.Command { } } + serverCmd.Flags().Bool(FlagExit, false, "Exit immediately after parsing arguments") + if err := serverCmd.Flags().MarkHidden(FlagExit); err != nil { + panic(fmt.Sprintf("failed to mark flag %s as hidden: %v", FlagExit, err)) + } + if err := viper.BindPFlag(FlagExit, serverCmd.Flags().Lookup(FlagExit)); err != nil { + panic(fmt.Sprintf("failed to bind flag %s: %v", FlagExit, err)) + } + viper.SetEnvPrefix("AGENTAPI") viper.AutomaticEnv() viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index 59b1ccc..ed88fce 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -12,6 +12,20 @@ import ( "github.com/stretchr/testify/require" ) +type nullWriter struct{} + +func (w *nullWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +// setupCommandOutput configures a cobra command to use a null writer for output capture. +func setupCommandOutput(t *testing.T, cmd *cobra.Command) { + t.Helper() + + cmd.SetOut(&nullWriter{}) + cmd.SetErr(&nullWriter{}) +} + func TestParseAgentType(t *testing.T) { tests := []struct { firstArg string @@ -141,17 +155,17 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) { {"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }}, {"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }}, {"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }}, + {"allowed-hosts default", FlagAllowedHosts, []string{"localhost", "127.0.0.1", "[::1]"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, + {"allowed-origins default", FlagAllowedOrigins, []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { isolateViper(t) serverCmd := CreateServerCmd() - cmd := &cobra.Command{} - cmd.AddCommand(serverCmd) - - // Execute with no args to get defaults - serverCmd.SetArgs([]string{"--help"}) // Use help to avoid actual execution + setupCommandOutput(t, serverCmd) + // Execute with --exit to get defaults + serverCmd.SetArgs([]string{"--exit", "dummy-command"}) if err := serverCmd.Execute(); err != nil { t.Fatalf("Failed to execute server command: %v", err) } @@ -175,6 +189,8 @@ func TestServerCmd_AllEnvVars(t *testing.T) { {"AGENTAPI_CHAT_BASE_PATH", "AGENTAPI_CHAT_BASE_PATH", "/api", "/api", func() any { return viper.GetString(FlagChatBasePath) }}, {"AGENTAPI_TERM_WIDTH", "AGENTAPI_TERM_WIDTH", "120", uint16(120), func() any { return viper.GetUint16(FlagTermWidth) }}, {"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }}, + {"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost example.com", []string{"localhost", "example.com"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, + {"AGENTAPI_ALLOWED_ORIGINS", "AGENTAPI_ALLOWED_ORIGINS", "https://example.com http://localhost:3000", []string{"https://example.com", "http://localhost:3000"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }}, } for _, tt := range tests { @@ -183,10 +199,8 @@ func TestServerCmd_AllEnvVars(t *testing.T) { t.Setenv(tt.envVar, tt.envValue) serverCmd := CreateServerCmd() - cmd := &cobra.Command{} - cmd.AddCommand(serverCmd) - - serverCmd.SetArgs([]string{"--help"}) + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--exit", "dummy-command"}) if err := serverCmd.Execute(); err != nil { t.Fatalf("Failed to execute server command: %v", err) } @@ -247,6 +261,13 @@ func TestServerCmd_ArgsPrecedenceOverEnv(t *testing.T) { uint16(600), func() any { return viper.GetUint16(FlagTermHeight) }, }, + { + "allowed-origins: CLI overrides env", + "AGENTAPI_ALLOWED_ORIGINS", "https://env-example.com http://localhost:3000", + []string{"--allowed-origins", "https://cli-example.com"}, + []string{"https://cli-example.com"}, + func() any { return viper.GetStringSlice(FlagAllowedOrigins) }, + }, } for _, tt := range tests { @@ -254,9 +275,9 @@ func TestServerCmd_ArgsPrecedenceOverEnv(t *testing.T) { isolateViper(t) t.Setenv(tt.envVar, tt.envValue) - // Mock execution to test arg parsing without running server - args := append(tt.args, "--help") + args := append(tt.args, "--exit", "dummy-command") serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) serverCmd.SetArgs(args) if err := serverCmd.Execute(); err != nil { t.Fatalf("Failed to execute server command: %v", err) @@ -277,7 +298,8 @@ func TestMixed_ConfigurationScenarios(t *testing.T) { // Set some CLI args serverCmd := CreateServerCmd() - serverCmd.SetArgs([]string{"--port", "9999", "--print-openapi", "--help"}) + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--port", "9999", "--print-openapi", "--exit", "dummy-command"}) if err := serverCmd.Execute(); err != nil { t.Fatalf("Failed to execute server command: %v", err) } @@ -291,3 +313,186 @@ func TestMixed_ConfigurationScenarios(t *testing.T) { assert.Equal(t, uint16(1000), viper.GetUint16(FlagTermHeight)) // default }) } + +func TestServerCmd_AllowedHosts(t *testing.T) { + tests := []struct { + name string + env map[string]string + args []string + expectedErr string + expected []string // only checked if expectedErr is empty + }{ + // Environment variable scenarios (space-separated format) + { + name: "env: single valid host", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"}, + args: []string{}, + expected: []string{"localhost"}, + }, + { + name: "env: multiple valid hosts space-separated", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost example.com 192.168.1.1"}, + args: []string{}, + expected: []string{"localhost", "example.com", "192.168.1.1"}, + }, + { + name: "env: host with tab", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost\texample.com"}, + args: []string{}, + expected: []string{"localhost", "example.com"}, + }, + // CLI flag scenarios (comma-separated format) + { + name: "flag: single valid host", + args: []string{"--allowed-hosts", "localhost"}, + expected: []string{"localhost"}, + }, + { + name: "flag: multiple valid hosts comma-separated", + args: []string{"--allowed-hosts", "localhost,example.com,192.168.1.1"}, + expected: []string{"localhost", "example.com", "192.168.1.1"}, + }, + { + name: "flag: multiple valid hosts with multiple flags", + args: []string{"--allowed-hosts", "localhost", "--allowed-hosts", "example.com"}, + expected: []string{"localhost", "example.com"}, + }, + { + name: "flag: host with newline", + args: []string{"--allowed-hosts", "localhost\n"}, + expected: []string{"localhost"}, + }, + { + name: "flag: ipv6 bracketed literal", + args: []string{"--allowed-hosts", "[2001:db8::1]"}, + expected: []string{"[2001:db8::1]"}, + }, + + // Mixed scenarios (env + flag precedence) + { + name: "mixed: flag overrides env", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"}, + args: []string{"--allowed-hosts", "override.com"}, + expected: []string{"override.com"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isolateViper(t) + + // Set environment variables if provided + for key, value := range tt.env { + t.Setenv(key, value) + } + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs(append(tt.args, "--exit", "dummy-command")) + err := serverCmd.Execute() + + if tt.expectedErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, viper.GetStringSlice(FlagAllowedHosts)) + } + }) + } +} + +func TestServerCmd_AllowedOrigins(t *testing.T) { + tests := []struct { + name string + env map[string]string + args []string + expectedErr string + expected []string // only checked if expectedErr is empty + }{ + // Environment variable scenarios (space-separated format) + { + name: "env: single valid origin", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://example.com"}, + args: []string{}, + expected: []string{"https://example.com"}, + }, + { + name: "env: multiple valid origins space-separated", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://example.com http://localhost:3000 https://app.example.com"}, + args: []string{}, + expected: []string{"https://example.com", "http://localhost:3000", "https://app.example.com"}, + }, + { + name: "env: wildcard origin", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "*"}, + args: []string{}, + expected: []string{"*"}, + }, + { + name: "env: origin with tab", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://example.com\thttp://localhost:3000"}, + args: []string{}, + expected: []string{"https://example.com", "http://localhost:3000"}, + }, + + // CLI flag scenarios (comma-separated format) + { + name: "flag: single valid origin", + args: []string{"--allowed-origins", "https://example.com"}, + expected: []string{"https://example.com"}, + }, + { + name: "flag: multiple valid origins comma-separated", + args: []string{"--allowed-origins", "https://example.com,http://localhost:3000,https://app.example.com"}, + expected: []string{"https://example.com", "http://localhost:3000", "https://app.example.com"}, + }, + { + name: "flag: multiple valid origins with multiple flags", + args: []string{"--allowed-origins", "https://example.com", "--allowed-origins", "http://localhost:3000"}, + expected: []string{"https://example.com", "http://localhost:3000"}, + }, + { + name: "flag: wildcard origin", + args: []string{"--allowed-origins", "*"}, + expected: []string{"*"}, + }, + { + name: "flag: origin with newline", + args: []string{"--allowed-origins", "https://example.com\n"}, + expected: []string{"https://example.com"}, + }, + + // Mixed scenarios (env + flag precedence) + { + name: "mixed: flag overrides env", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://env-example.com"}, + args: []string{"--allowed-origins", "https://override.com"}, + expected: []string{"https://override.com"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isolateViper(t) + + // Set environment variables if provided + for key, value := range tt.env { + t.Setenv(key, value) + } + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs(append(tt.args, "--exit", "dummy-command")) + err := serverCmd.Execute() + + if tt.expectedErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, viper.GetStringSlice(FlagAllowedOrigins)) + } + }) + } +} diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index e9e71cb..4f1bb14 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -7,9 +7,11 @@ import ( "log/slog" "net/http" "net/url" + "slices" "strings" "sync" "time" + "unicode" "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" @@ -60,18 +62,124 @@ func (s *Server) GetOpenAPI() string { const snapshotInterval = 25 * time.Millisecond type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string +} + +// Validate allowed hosts don't contain whitespace, commas, schemes, or ports. +// Viper/Cobra use different separators (space for env vars, comma for flags), +// so these characters likely indicate user error. +func parseAllowedHosts(input []string) ([]string, error) { + if len(input) == 0 { + return nil, fmt.Errorf("the list must not be empty") + } + if slices.Contains(input, "*") { + return []string{"*"}, nil + } + // First pass: whitespace & comma checks (surface these errors first) + // Viper/Cobra use different separators (space for env vars, comma for flags), + // so these characters likely indicate user error. + for _, item := range input { + for _, r := range item { + if unicode.IsSpace(r) { + return nil, fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item) + } + } + if strings.Contains(item, ",") { + return nil, fmt.Errorf("'%s' contains comma characters, which are not allowed", item) + } + } + // Second pass: scheme check + for _, item := range input { + if strings.Contains(item, "http://") || strings.Contains(item, "https://") { + return nil, fmt.Errorf("'%s' must not include http:// or https://", item) + } + } + hosts := make([]*url.URL, 0, len(input)) + // Third pass: url parse + for _, item := range input { + trimmed := strings.TrimSpace(item) + u, err := url.Parse("http://" + trimmed) + if err != nil { + return nil, fmt.Errorf("'%s' is not a valid host: %w", item, err) + } + hosts = append(hosts, u) + } + // Fourth pass: port check + for _, u := range hosts { + if u.Port() != "" { + return nil, fmt.Errorf("'%s' must not include a port", u.Host) + } + } + hostStrings := make([]string, 0, len(hosts)) + for _, u := range hosts { + hostStrings = append(hostStrings, u.Hostname()) + } + return hostStrings, nil +} + +// Validate allowed origins +func parseAllowedOrigins(input []string) ([]string, error) { + if len(input) == 0 { + return nil, fmt.Errorf("the list must not be empty") + } + if slices.Contains(input, "*") { + return []string{"*"}, nil + } + // Viper/Cobra use different separators (space for env vars, comma for flags), + // so these characters likely indicate user error. + for _, item := range input { + for _, r := range item { + if unicode.IsSpace(r) { + return nil, fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item) + } + } + if strings.Contains(item, ",") { + return nil, fmt.Errorf("'%s' contains comma characters, which are not allowed", item) + } + } + origins := make([]string, 0, len(input)) + for _, item := range input { + trimmed := strings.TrimSpace(item) + u, err := url.Parse(trimmed) + if err != nil { + return nil, fmt.Errorf("'%s' is not a valid origin: %w", item, err) + } + origins = append(origins, fmt.Sprintf("%s://%s", u.Scheme, u.Host)) + } + return origins, nil } // NewServer creates a new server instance -func NewServer(ctx context.Context, config ServerConfig) *Server { +func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { router := chi.NewMux() + logger := logctx.From(ctx) + + allowedHosts, err := parseAllowedHosts(config.AllowedHosts) + if err != nil { + return nil, xerrors.Errorf("failed to parse allowed hosts: %w", err) + } + allowedOrigins, err := parseAllowedOrigins(config.AllowedOrigins) + if err != nil { + return nil, xerrors.Errorf("failed to parse allowed origins: %w", err) + } + + logger.Info(fmt.Sprintf("Allowed hosts: %s", strings.Join(allowedHosts, ", "))) + logger.Info(fmt.Sprintf("Allowed origins: %s", strings.Join(allowedOrigins, ", "))) + + // Enforce allowed hosts in a custom middleware that ignores the port during matching. + badHostHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Invalid host header. Allowed hosts: "+strings.Join(allowedHosts, ", "), http.StatusBadRequest) + }) + router.Use(hostAuthorizationMiddleware(allowedHosts, badHostHandler)) + corsMiddleware := cors.New(cors.Options{ - AllowedOrigins: []string{"*"}, + AllowedOrigins: allowedOrigins, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, ExposedHeaders: []string{"Link"}, @@ -101,7 +209,7 @@ func NewServer(ctx context.Context, config ServerConfig) *Server { api: api, port: config.Port, conversation: conversation, - logger: logctx.From(ctx), + logger: logger, agentio: config.Process, agentType: config.AgentType, emitter: emitter, @@ -111,7 +219,7 @@ func NewServer(ctx context.Context, config ServerConfig) *Server { // Register API routes s.registerRoutes() - return s + return s, nil } // Handler returns the underlying chi.Router for testing purposes. @@ -119,6 +227,40 @@ func (s *Server) Handler() http.Handler { return s.router } +// hostAuthorizationMiddleware enforces that the request Host header matches one of the allowed +// hosts, ignoring any port in the comparison. If allowedHosts is empty, all hosts are allowed. +// Always uses url.Parse("http://" + r.Host) to robustly extract the hostname (handles IPv6). +func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Handler) func(next http.Handler) http.Handler { + // Copy for safety; also build a map for O(1) lookups with case-insensitive keys. + allowed := make(map[string]struct{}, len(allowedHosts)) + for _, h := range allowedHosts { + allowed[strings.ToLower(h)] = struct{}{} + } + wildcard := slices.Contains(allowedHosts, "*") + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if wildcard { // wildcard semantics: allow all + next.ServeHTTP(w, r) + return + } + // Extract hostname from the Host header using url.Parse; ignore any port. + hostHeader := r.Host + if hostHeader == "" { + badHostHandler.ServeHTTP(w, r) + return + } + if u, err := url.Parse("http://" + hostHeader); err == nil { + hostname := u.Hostname() + if _, ok := allowed[strings.ToLower(hostname)]; ok { + next.ServeHTTP(w, r) + return + } + } + badHostHandler.ServeHTTP(w, r) + }) + } +} + func (s *Server) StartSnapshotLoop(ctx context.Context) { s.conversation.StartSnapshotLoop(ctx) go func() { diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index badc974..bc50d3e 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -46,12 +46,15 @@ func TestOpenAPISchema(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - srv := httpapi.NewServer(ctx, httpapi.ServerConfig{ - AgentType: msgfmt.AgentTypeClaude, - Process: nil, - Port: 0, - ChatBasePath: "/chat", + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, }) + require.NoError(t, err) currentSchemaStr := srv.GetOpenAPI() var currentSchema any if err := json.Unmarshal([]byte(currentSchemaStr), ¤tSchema); err != nil { @@ -95,12 +98,15 @@ func TestServer_redirectToChat(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() tCtx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - s := httpapi.NewServer(tCtx, httpapi.ServerConfig{ - AgentType: msgfmt.AgentTypeClaude, - Process: nil, - Port: 0, - ChatBasePath: tc.chatBasePath, + s, err := httpapi.NewServer(tCtx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: tc.chatBasePath, + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, }) + require.NoError(t, err) tsServer := httptest.NewServer(s.Handler()) t.Cleanup(tsServer.Close) @@ -120,3 +126,508 @@ func TestServer_redirectToChat(t *testing.T) { }) } } + +func TestServer_AllowedHosts(t *testing.T) { + cases := []struct { + name string + allowedHosts []string + hostHeader string + expectedStatusCode int + expectedErrorMsg string + validationErrorMsg string + }{ + { + name: "wildcard hosts - any host allowed", + allowedHosts: []string{"*"}, + hostHeader: "example.com", + expectedStatusCode: http.StatusOK, + }, + { + name: "wildcard hosts - another host allowed", + allowedHosts: []string{"*"}, + hostHeader: "malicious.com", + expectedStatusCode: http.StatusOK, + }, + { + name: "specific hosts - valid host allowed", + allowedHosts: []string{"localhost", "app.example.com"}, + hostHeader: "localhost:3000", + expectedStatusCode: http.StatusOK, + }, + { + name: "specific hosts - another valid host allowed", + allowedHosts: []string{"localhost", "app.example.com"}, + hostHeader: "app.example.com", + expectedStatusCode: http.StatusOK, + }, + { + name: "specific hosts - invalid host rejected", + allowedHosts: []string{"localhost", "app.example.com"}, + hostHeader: "malicious.com", + expectedStatusCode: http.StatusBadRequest, + expectedErrorMsg: "Invalid host header. Allowed hosts: localhost, app.example.com", + }, + { + name: "ipv6 bracketed configured allowed - with port", + allowedHosts: []string{"[2001:db8::1]"}, + hostHeader: "[2001:db8::1]:80", + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 literal invalid host rejected", + allowedHosts: []string{"[2001:db8::1]"}, + hostHeader: "[2001:db8::2]", + expectedStatusCode: http.StatusBadRequest, + expectedErrorMsg: "Invalid host header. Allowed hosts: 2001:db8::1", + }, + { + name: "allowed hosts must not be empty", + allowedHosts: []string{}, + validationErrorMsg: "the list must not be empty", + }, + { + name: "ipv6 literal without square brackets is invalid", + allowedHosts: []string{"2001:db8::1"}, + validationErrorMsg: "must not include a port", + }, + { + name: "host with port in config is invalid", + allowedHosts: []string{"example.com:8080"}, + validationErrorMsg: "must not include a port", + }, + { + name: "bracketed ipv6 with port in config is invalid", + allowedHosts: []string{"[2001:db8::1]:443"}, + validationErrorMsg: "must not include a port", + }, + { + name: "hostname with http scheme is invalid", + allowedHosts: []string{"http://example.com"}, + validationErrorMsg: "must not include http:// or https://", + }, + { + name: "hostname with https scheme is invalid", + allowedHosts: []string{"https://example.com"}, + validationErrorMsg: "must not include http:// or https://", + }, + { + name: "hostname containing comma is invalid", + allowedHosts: []string{"example.com,malicious.com"}, + validationErrorMsg: "contains comma characters, which are not allowed", + }, + { + name: "hostname with leading whitespace is invalid", + allowedHosts: []string{" example.com"}, + validationErrorMsg: "contains whitespace characters, which are not allowed", + }, + { + name: "hostname with internal whitespace is invalid", + allowedHosts: []string{"exa mple.com"}, + validationErrorMsg: "contains whitespace characters, which are not allowed", + }, + { + name: "uppercase allowed host matches lowercase request", + allowedHosts: []string{"EXAMPLE.COM"}, + hostHeader: "example.com:80", + expectedStatusCode: http.StatusOK, + }, + { + name: "wildcard with extra invalid entries still allows all", + allowedHosts: []string{"*", "https://bad.com", "example.com:8080", " space.com"}, + hostHeader: "malicious.com", + expectedStatusCode: http.StatusOK, + }, + { + name: "trailing dot in allowed host requires trailing dot in request (no match)", + allowedHosts: []string{"example.com."}, + hostHeader: "example.com", + expectedStatusCode: http.StatusBadRequest, + expectedErrorMsg: "Invalid host header. Allowed hosts: example.com.", + }, + { + name: "trailing dot in allowed host matches trailing dot in request", + allowedHosts: []string{"example.com."}, + hostHeader: "example.com.:80", + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 bracketed configured allowed - without port header", + allowedHosts: []string{"[2001:db8::1]"}, + hostHeader: "[2001:db8::1]", + expectedStatusCode: http.StatusOK, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: tc.allowedHosts, + AllowedOrigins: []string{"https://example.com"}, // Set a default to isolate host testing + }) + if tc.validationErrorMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.validationErrorMsg) + return + } else { + require.NoError(t, err) + } + tsServer := httptest.NewServer(s.Handler()) + t.Cleanup(tsServer.Close) + + req, err := http.NewRequest("GET", tsServer.URL+"/status", nil) + require.NoError(t, err) + + if tc.hostHeader != "" { + req.Host = tc.hostHeader + } + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + require.Equal(t, tc.expectedStatusCode, resp.StatusCode, + "expected status code %d, got %d", tc.expectedStatusCode, resp.StatusCode) + + if tc.expectedErrorMsg != "" { + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), tc.expectedErrorMsg) + } + }) + } +} + +func TestServer_CORSPreflightWithHosts(t *testing.T) { + cases := []struct { + name string + allowedHosts []string + hostHeader string + originHeader string + expectedStatusCode int + expectCORSHeaders bool + }{ + { + name: "preflight with wildcard hosts", + allowedHosts: []string{"*"}, + hostHeader: "example.com", + originHeader: "https://example.com", + expectedStatusCode: http.StatusOK, + expectCORSHeaders: true, + }, + { + name: "preflight with specific valid host", + allowedHosts: []string{"localhost"}, + hostHeader: "localhost:3000", + originHeader: "https://localhost:3000", + expectedStatusCode: http.StatusOK, + expectCORSHeaders: true, + }, + { + name: "preflight with invalid host", + allowedHosts: []string{"localhost"}, + hostHeader: "malicious.com", + originHeader: "https://malicious.com", + expectedStatusCode: http.StatusBadRequest, + expectCORSHeaders: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: tc.allowedHosts, + AllowedOrigins: []string{"*"}, // Set wildcard origins to isolate host testing + }) + require.NoError(t, err) + tsServer := httptest.NewServer(s.Handler()) + t.Cleanup(tsServer.Close) + + // Test CORS preflight request + req, err := http.NewRequest("OPTIONS", tsServer.URL+"/status", nil) + require.NoError(t, err) + + if tc.hostHeader != "" { + req.Host = tc.hostHeader + } + if tc.originHeader != "" { + req.Header.Set("Origin", tc.originHeader) + } + req.Header.Set("Access-Control-Request-Method", "GET") + req.Header.Set("Access-Control-Request-Headers", "Content-Type") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + require.Equal(t, tc.expectedStatusCode, resp.StatusCode, + "expected status code %d, got %d", tc.expectedStatusCode, resp.StatusCode) + + if tc.expectCORSHeaders { + allowMethods := resp.Header.Get("Access-Control-Allow-Methods") + require.Contains(t, allowMethods, "GET", "expected GET in allowed methods") + + allowHeaders := resp.Header.Get("Access-Control-Allow-Headers") + require.Contains(t, allowHeaders, "Content-Type", "expected Content-Type in allowed headers") + } + }) + } +} + +func TestServer_CORSOrigins(t *testing.T) { + cases := []struct { + name string + allowedOrigins []string + originHeader string + expectedStatusCode int + expectedCORSOrigin string + expectCORSOriginHeader bool + validationErrorMsg string + }{ + { + name: "wildcard origins - any origin allowed", + allowedOrigins: []string{"*"}, + originHeader: "https://example.com", + expectedStatusCode: http.StatusOK, + expectedCORSOrigin: "*", + expectCORSOriginHeader: true, + }, + { + name: "wildcard origins - malicious origin allowed", + allowedOrigins: []string{"*"}, + originHeader: "http://malicious.com", + expectedStatusCode: http.StatusOK, + expectedCORSOrigin: "*", + expectCORSOriginHeader: true, + }, + { + name: "specific origins - valid origin allowed https", + allowedOrigins: []string{"https://localhost:3000", "http://app.example.com"}, + originHeader: "https://localhost:3000", + expectedStatusCode: http.StatusOK, + expectedCORSOrigin: "https://localhost:3000", + expectCORSOriginHeader: true, + }, + { + name: "specific origins - valid origin allowed http", + allowedOrigins: []string{"https://localhost:3000", "http://app.example.com"}, + originHeader: "http://app.example.com", + expectedStatusCode: http.StatusOK, + expectedCORSOrigin: "http://app.example.com", + expectCORSOriginHeader: true, + }, + { + name: "specific origins - invalid origin rejected", + allowedOrigins: []string{"https://localhost:3000", "http://app.example.com"}, + originHeader: "https://malicious.com", + expectedStatusCode: http.StatusOK, // Server allows request - CORS is enforced by browser + expectCORSOriginHeader: false, + }, + { + name: "no origin header - request not coming from a browser", + allowedOrigins: []string{"https://example.com"}, + originHeader: "", + expectedStatusCode: http.StatusOK, + }, + { + name: "allowed origins must not be empty", + allowedOrigins: []string{}, + validationErrorMsg: "the list must not be empty", + }, + { + name: "origin containing comma is invalid", + allowedOrigins: []string{"https://example.com,http://localhost:3000"}, + validationErrorMsg: "contains comma characters, which are not allowed", + }, + { + name: "origin with internal whitespace is invalid", + allowedOrigins: []string{"https://exa mple.com"}, + validationErrorMsg: "contains whitespace characters, which are not allowed", + }, + { + name: "origin with leading whitespace is invalid", + allowedOrigins: []string{" https://example.com"}, + validationErrorMsg: "contains whitespace characters, which are not allowed", + }, + { + name: "wildcard with extra invalid entries still allows all", + allowedOrigins: []string{"*", "https://bad.com,too", "http://bad host"}, + originHeader: "http://malicious.com", + expectedCORSOrigin: "*", + expectCORSOriginHeader: true, + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 origin allowed", + allowedOrigins: []string{"http://[2001:db8::1]:8080"}, + originHeader: "http://[2001:db8::1]:8080", + expectedCORSOrigin: "http://[2001:db8::1]:8080", + expectCORSOriginHeader: true, + expectedStatusCode: http.StatusOK, + }, + { + name: "origin with path, query, and fragment normalizes to scheme+host", + allowedOrigins: []string{"https://example.com/path?x=1#frag"}, + originHeader: "https://example.com", + expectedCORSOrigin: "https://example.com", + expectCORSOriginHeader: true, + expectedStatusCode: http.StatusOK, + }, + { + name: "trailing slash is ignored for matching", + allowedOrigins: []string{"https://example.com/"}, + originHeader: "https://example.com", + expectedCORSOrigin: "https://example.com", + expectCORSOriginHeader: true, + expectedStatusCode: http.StatusOK, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, // Set wildcard to isolate CORS testing + AllowedOrigins: tc.allowedOrigins, + }) + if tc.validationErrorMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.validationErrorMsg) + return + } + tsServer := httptest.NewServer(s.Handler()) + t.Cleanup(tsServer.Close) + + req, err := http.NewRequest("GET", tsServer.URL+"/status", nil) + require.NoError(t, err) + + if tc.originHeader != "" { + req.Header.Set("Origin", tc.originHeader) + } + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + require.Equal(t, tc.expectedStatusCode, resp.StatusCode, + "expected status code %d, got %d", tc.expectedStatusCode, resp.StatusCode) + + if tc.expectCORSOriginHeader { + corsOrigin := resp.Header.Get("Access-Control-Allow-Origin") + require.Equal(t, tc.expectedCORSOrigin, corsOrigin, + "expected CORS origin %q, got %q", tc.expectedCORSOrigin, corsOrigin) + } else if tc.expectedStatusCode == http.StatusOK && tc.originHeader != "" { + corsOrigin := resp.Header.Get("Access-Control-Allow-Origin") + require.Empty(t, corsOrigin, "expected no CORS origin header, got %q", corsOrigin) + } + }) + } +} + +func TestServer_CORSPreflightOrigins(t *testing.T) { + cases := []struct { + name string + allowedOrigins []string + originHeader string + expectedStatusCode int + expectCORSHeaders bool + }{ + { + name: "preflight with wildcard origins", + allowedOrigins: []string{"*"}, + originHeader: "https://example.com", + expectedStatusCode: http.StatusOK, + expectCORSHeaders: true, + }, + { + name: "preflight with specific valid origin", + allowedOrigins: []string{"https://localhost:3000"}, + originHeader: "https://localhost:3000", + expectedStatusCode: http.StatusOK, + expectCORSHeaders: true, + }, + { + name: "preflight with invalid origin", + allowedOrigins: []string{"https://localhost:3000"}, + originHeader: "https://malicious.com", + expectedStatusCode: http.StatusOK, // Request succeeds but no CORS headers + expectCORSHeaders: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, // Set wildcard to isolate CORS testing + AllowedOrigins: tc.allowedOrigins, + }) + require.NoError(t, err) + tsServer := httptest.NewServer(s.Handler()) + t.Cleanup(tsServer.Close) + + req, err := http.NewRequest("OPTIONS", tsServer.URL+"/status", nil) + require.NoError(t, err) + + if tc.originHeader != "" { + req.Header.Set("Origin", tc.originHeader) + } + req.Header.Set("Access-Control-Request-Method", "GET") + req.Header.Set("Access-Control-Request-Headers", "Content-Type") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + require.Equal(t, tc.expectedStatusCode, resp.StatusCode, + "expected status code %d, got %d", tc.expectedStatusCode, resp.StatusCode) + + if tc.expectCORSHeaders { + allowMethods := resp.Header.Get("Access-Control-Allow-Methods") + require.Contains(t, allowMethods, "GET", "expected GET in allowed methods") + + allowHeaders := resp.Header.Get("Access-Control-Allow-Headers") + require.Contains(t, allowHeaders, "Content-Type", "expected Content-Type in allowed headers") + + corsOrigin := resp.Header.Get("Access-Control-Allow-Origin") + require.NotEmpty(t, corsOrigin, "expected CORS origin header for valid preflight") + } else if tc.originHeader != "" { + corsOrigin := resp.Header.Get("Access-Control-Allow-Origin") + require.Empty(t, corsOrigin, "expected no CORS origin header for invalid origin") + } + }) + } +}