diff --git a/cmd/root.go b/cmd/root.go index 3a2fffb96439..dba9f9cc2b7a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -124,6 +124,7 @@ func NewCommand(opts *internal.ToolboxOptions) *cobra.Command { flags.BoolVar(&opts.Cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.") flags.BoolVar(&opts.Cfg.UI, "ui", false, "Launches the Toolbox UI web server.") flags.StringVar(&opts.Cfg.ToolboxUrl, "toolbox-url", "", "Specifies the Toolbox URL. Used as the resource field in the MCP PRM file when MCP Auth is enabled. Falls back to TOOLBOX_URL environment variable.") + flags.StringVar(&opts.Cfg.McpPrmFile, "mcp-prm-file", "", "Path to a manual Protected Resource Metadata (PRM) JSON file. If provided, overrides auto-generation.") // TODO: Insecure by default. Might consider updating this for v1.0.0 flags.StringSliceVar(&opts.Cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.") flags.StringSliceVar(&opts.Cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.") diff --git a/docs/en/reference/cli.md b/docs/en/reference/cli.md index 41c179b49517..c0ae5c0427db 100644 --- a/docs/en/reference/cli.md +++ b/docs/en/reference/cli.md @@ -15,6 +15,7 @@ description: > | `-h` | `--help` | help for toolbox | | | | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` | | | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` | +| | `--mcp-prm-file` | Path to a manual Protected Resource Metadata (PRM) JSON file. If provided, overrides auto-generation for MCP Server-Wide Authentication. | | | `-p` | `--port` | Port the server will listen on. | `5000` | | | `--prebuilt` | Use one or more prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | | | | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | | diff --git a/internal/auth/generic/generic_test.go b/internal/auth/generic/generic_test.go index 8c7e2ff57317..d84570976d30 100644 --- a/internal/auth/generic/generic_test.go +++ b/internal/auth/generic/generic_test.go @@ -163,7 +163,7 @@ func TestGetClaimsFromHeader(t *testing.T) { return header }, wantError: true, - errContains: "authorization header format must be Bearer {token}", + errContains: "Authorization header format must be Bearer {token}", }, { name: "wrong audience", diff --git a/internal/server/config.go b/internal/server/config.go index a9031f1bac56..60c6b8c1e043 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -72,6 +72,8 @@ type ServerConfig struct { UI bool // ToolboxUrl specifies the URL to advertise in the MCP PRM file as the resource field. ToolboxUrl string + // McpPrmFile specifies the path to a manual Protected Resource Metadata (PRM) JSON file. If provided, overrides auto-generation. + McpPrmFile string // Specifies a list of origins permitted to access this server. AllowedOrigins []string // Specifies a list of hosts permitted to access this server. diff --git a/internal/server/mcp.go b/internal/server/mcp.go index 69f7f203d663..5160a5c45803 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "net/http" + "os" "sync" "time" @@ -341,7 +342,7 @@ func mcpRouter(s *Server) (chi.Router, error) { } } - if mcpAuthEnabled { + if mcpAuthEnabled || s.mcpPrmFile != "" { r.Get("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) { prmHandler(s, w, r) }) } @@ -773,6 +774,27 @@ type prmResponse struct { // prmHandler generates the Protected Resource Metadata (PRM) file for MCP Authorization. func prmHandler(s *Server, w http.ResponseWriter, r *http.Request) { + if s.mcpPrmFile != "" { + prmBytes, err := os.ReadFile(s.mcpPrmFile) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to read manual PRM file", "error", err, "path", s.mcpPrmFile) + // Returning 500 when it explicitly fails to read a configured file + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if !json.Valid(prmBytes) { + s.logger.ErrorContext(r.Context(), "manual PRM file is not valid JSON", "path", s.mcpPrmFile) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write(prmBytes); err != nil { + s.logger.ErrorContext(r.Context(), "failed to write manual PRM file response", "error", err) + } + return + } + var servers []string var scopes []string for _, authSvc := range s.ResourceMgr.GetAuthServiceMap() { diff --git a/internal/server/mcp_test.go b/internal/server/mcp_test.go index 833ad0e18c21..43a47d13cf9b 100644 --- a/internal/server/mcp_test.go +++ b/internal/server/mcp_test.go @@ -1247,3 +1247,77 @@ func TestPRMEndpoint(t *testing.T) { t.Errorf("unexpected PRM response: got %+v, want %+v", got, want) } } + +func TestPRMEndpoint_ManualFile(t *testing.T) { + // Create a temporary manual PRM file + tmpFile, err := os.CreateTemp("", "manual_prm_*.json") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + + manualPRMContent := []byte(`{ + "resource": "https://manual.example.com/mcp", + "authorization_servers": ["https://manual-auth.example.com"], + "scopes_supported": ["manual:scope"], + "bearer_methods_supported": ["header"] + }`) + + if _, err := tmpFile.Write(manualPRMContent); err != nil { + t.Fatalf("failed to write to temp file: %v", err) + } + tmpFile.Close() + + // Initialize the server with the manual PRM file path + resourceManager := resources.NewResourceManager(nil, nil, nil, nil, nil, nil, nil) + testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info") + if err != nil { + t.Fatalf("unable to initialize logger: %s", err) + } + + s := &Server{ + logger: testLogger, + ResourceMgr: resourceManager, + mcpPrmFile: tmpFile.Name(), // Inject manual config path + } + + r, err := mcpRouter(s) + if err != nil { + t.Fatalf("unexpected error creating router: %v", err) + } + + ts := httptest.NewServer(r) + defer ts.Close() + + // Make the request + resp, body, err := runRequest(ts, http.MethodGet, "/.well-known/oauth-protected-resource", nil, nil) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + if contentType := resp.Header.Get("Content-Type"); contentType != "application/json" { + t.Fatalf("expected content-type application/json, got %s", contentType) + } + + // Verify the response body matches the exact contents of the manual file + var got map[string]any + if err := json.Unmarshal(body, &got); err != nil { + t.Fatalf("unexpected error unmarshalling body: %s", err) + } + + want := map[string]any{ + "resource": "https://manual.example.com/mcp", + "authorization_servers": []any{ + "https://manual-auth.example.com", + }, + "scopes_supported": []any{"manual:scope"}, + "bearer_methods_supported": []any{"header"}, + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("unexpected manual PRM response: got %+v, want %+v", got, want) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 1952902f7c3f..d795564a3926 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -53,6 +53,7 @@ type Server struct { sseManager *sseManager ResourceMgr *resources.ResourceManager toolboxUrl string + mcpPrmFile string } func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( @@ -380,6 +381,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { sseManager: sseManager, ResourceMgr: resourceManager, toolboxUrl: cfg.ToolboxUrl, + mcpPrmFile: cfg.McpPrmFile, } // cors