diff --git a/CLAUDE.md b/CLAUDE.md index 4c2da00..72a25fc 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -56,6 +56,8 @@ mcp-front is a Go-based OAuth 2.1 proxy server for MCP (Model Context Protocol) - **Define interfaces where they are used** - Not in the package that implements them - **Avoid circular imports** - Use interface segregation in separate packages when needed - **Dependency injection over getter methods** - Pass dependencies to constructors +- **Functional core, imperative shell** - Prefer pure functions for business logic, keep side effects (I/O, state mutations) at the boundaries. Makes code more testable and reasoning easier. +- **Upstream lifecycle control** - Manage goroutines, servers, and background processes from the application root. Library code should expose Start/Stop methods, not start things autonomously. ### 🎯 Core Development Principles (from Zig Zen) @@ -153,6 +155,7 @@ cmd/mcp-front/ # Main application entry point 3. Don't create new auth patterns - use existing OAuth or bearer token auth 4. Don't modify git configuration 5. Don't create README files proactively +6. **Variable shadowing package names** - `config.MCPClientConfig is not a type` means a variable named `config` is shadowing the package. Always check for variables that shadow imported package names ### When Working on Features diff --git a/Makefile b/Makefile index 7a44aba..fab58ac 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,8 @@ doc: format: go fmt ./... + # go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test ./... + modernize -fix -test ./... cd docs-site && npm run format lint: diff --git a/README.md b/README.md index 5590ec2..33d06fb 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,8 @@ Some MCP servers are better to use with each users having their own integration "notion": { "transportType": "stdio", "requiresUserToken": true, - "tokenSetup": { + "userAuthentication": { + "type": "manual", "displayName": "Notion Integration Token", "instructions": "Create an integration and copy the token", "helpUrl": "https://www.notion.so/my-integrations" diff --git a/cmd/mcp-front/main.go b/cmd/mcp-front/main.go index 799f624..7fd26a3 100644 --- a/cmd/mcp-front/main.go +++ b/cmd/mcp-front/main.go @@ -1,14 +1,15 @@ package main import ( + "context" "encoding/json" "flag" "fmt" "os" + "github.com/dgellow/mcp-front/internal" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/log" - "github.com/dgellow/mcp-front/internal/server" ) var BuildVersion = "dev" @@ -151,12 +152,19 @@ func main() { os.Exit(1) } - log.LogInfoWithFields("main", "Starting mcp-front", map[string]interface{}{ + log.LogInfoWithFields("main", "Starting mcp-front", map[string]any{ "version": BuildVersion, "config": *conf, }) - err = server.Run(cfg) + ctx := context.Background() + mcpFront, err := internal.NewMCPFront(ctx, cfg) + if err != nil { + log.LogError("Failed to create MCP proxy: %v", err) + os.Exit(1) + } + + err = mcpFront.Run() if err != nil { log.LogError("Failed to start server: %v", err) os.Exit(1) diff --git a/config-oauth.json b/config-oauth.json index df4f529..c02bf83 100644 --- a/config-oauth.json +++ b/config-oauth.json @@ -32,7 +32,8 @@ "notion": { "transportType": "stdio", "requiresUserToken": true, - "tokenSetup": { + "userAuthentication": { + "type": "manual", "displayName": "Notion Integration Token", "instructions": "Create an integration at https://www.notion.so/my-integrations and copy the token", "helpUrl": "https://developers.notion.com/docs/create-a-notion-integration" diff --git a/config-user-tokens-example.json b/config-user-tokens-example.json index 53b4a29..bd93f36 100644 --- a/config-user-tokens-example.json +++ b/config-user-tokens-example.json @@ -28,7 +28,8 @@ "OPENAPI_MCP_HEADERS": {"$userToken": "{\"Authorization\": \"Bearer {{token}}\"}"} }, "requiresUserToken": true, - "tokenSetup": { + "userAuthentication": { + "type": "manual", "displayName": "Notion API Token", "instructions": "Enter your Notion API token. You can find this at https://www.notion.so/my-integrations", "helpUrl": "https://developers.notion.com/docs/authorization", @@ -43,7 +44,8 @@ "GITHUB_TOKEN": {"$userToken": "{{token}}"} }, "requiresUserToken": true, - "tokenSetup": { + "userAuthentication": { + "type": "manual", "displayName": "GitHub Personal Access Token", "instructions": "Create a personal access token at https://github.com/settings/tokens", "helpUrl": "https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token", diff --git a/docs-site/src/content/docs/configuration.md b/docs-site/src/content/docs/configuration.md index 86c81dd..cbc6b6d 100644 --- a/docs-site/src/content/docs/configuration.md +++ b/docs-site/src/content/docs/configuration.md @@ -185,7 +185,8 @@ Use the `options` field for additional configuration: "X-API-Key": { "$env": "DB_API_KEY" } }, "requiresUserToken": true, - "tokenSetup": { + "userAuthentication": { + "type": "manual", "displayName": "Database Token", "instructions": "Get your token from the admin panel", "helpUrl": "https://db.company.com/tokens" diff --git a/go.mod b/go.mod index 8c1575d..0507e3c 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/mark3labs/mcp-go v0.28.0 github.com/ory/fosite v0.42.0 github.com/stretchr/testify v1.10.0 + golang.org/x/crypto v0.38.0 golang.org/x/oauth2 v0.30.0 google.golang.org/api v0.214.0 google.golang.org/grpc v1.67.3 @@ -61,7 +62,6 @@ require ( go.opentelemetry.io/otel v1.29.0 // indirect go.opentelemetry.io/otel/metric v1.29.0 // indirect go.opentelemetry.io/otel/trace v1.29.0 // indirect - golang.org/x/crypto v0.38.0 // indirect golang.org/x/net v0.40.0 // indirect golang.org/x/sync v0.14.0 // indirect golang.org/x/sys v0.33.0 // indirect diff --git a/integration/cli_test.go b/integration/cli_test.go new file mode 100644 index 0000000..660678b --- /dev/null +++ b/integration/cli_test.go @@ -0,0 +1,42 @@ +package integration + +import ( + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCLIConfigInitGeneratesValidConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "generated-config.json") + + t.Run("generate and validate config", func(t *testing.T) { + // Step 1: Generate config with -config-init + cmd := exec.Command("../cmd/mcp-front/mcp-front", "-config-init", configPath) + output, err := cmd.CombinedOutput() + + t.Logf("config-init output: %s", output) + + require.NoError(t, err, "config-init should succeed") + assert.Contains(t, string(output), "Generated default config at:", "should report generation") + + // Verify file was created + fi, err := os.Stat(configPath) + require.NoError(t, err, "config file should exist") + require.Greater(t, fi.Size(), int64(0), "config file should not be empty") + + // Step 2: Validate the generated config + cmd = exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath, "-validate") + output, err = cmd.CombinedOutput() + + t.Logf("validate output: %s", output) + + // The generated config should be valid + require.NoError(t, err, "validate should succeed for config-init generated file") + assert.Contains(t, string(output), "Result: PASS", "validation should pass") + }) +} diff --git a/integration/config/config.oauth-token-test.json b/integration/config/config.oauth-token-test.json index ee4ebba..5d39ce4 100644 --- a/integration/config/config.oauth-token-test.json +++ b/integration/config/config.oauth-token-test.json @@ -24,18 +24,20 @@ "transportType": "sse", "url": "https://notion-mcp.example.com", "requiresUserToken": true, - "tokenSetup": { + "userAuthentication": { + "type": "manual", "displayName": "Notion", "instructions": "Create a Notion integration token", "helpUrl": "https://developers.notion.com", - "tokenFormat": "^secret_[a-zA-Z0-9]{43}$" + "validation": "^secret_[a-zA-Z0-9]{43}$" } }, "github": { "transportType": "sse", "url": "https://github-mcp.example.com", "requiresUserToken": true, - "tokenSetup": { + "userAuthentication": { + "type": "manual", "displayName": "GitHub", "instructions": "Create a GitHub personal access token", "helpUrl": "https://github.com/settings/tokens" diff --git a/integration/config/config.oauth-usertoken-tools-test.json b/integration/config/config.oauth-usertoken-tools-test.json index c1c912a..9d2f6d4 100644 --- a/integration/config/config.oauth-usertoken-tools-test.json +++ b/integration/config/config.oauth-usertoken-tools-test.json @@ -36,7 +36,8 @@ "USER_TOKEN": {"$userToken": "{{token}}"} }, "requiresUserToken": true, - "tokenSetup": { + "userAuthentication": { + "type": "manual", "displayName": "Test Service", "instructions": "Enter your test token", "helpUrl": "https://example.com/help" diff --git a/integration/inline_test.go b/integration/inline_test.go index 5a3df11..c3b34ca 100644 --- a/integration/inline_test.go +++ b/integration/inline_test.go @@ -36,9 +36,9 @@ func TestInlineMCPServer(t *testing.T) { // Test 1: Basic echo tool (static args) t.Run("echo tool", func(t *testing.T) { - params := map[string]interface{}{ + params := map[string]any{ "name": "echo", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "message": "Hello, inline MCP!", }, } @@ -47,18 +47,18 @@ func TestInlineMCPServer(t *testing.T) { require.NoError(t, err, "Failed to call echo tool") // Check for error in response - errorMap, hasError := result["error"].(map[string]interface{}) + errorMap, hasError := result["error"].(map[string]any) assert.False(t, hasError, "Echo tool returned error: %v", errorMap) // Verify result - resultMap, ok := result["result"].(map[string]interface{}) + resultMap, ok := result["result"].(map[string]any) require.True(t, ok, "Expected result in response") - content, ok := resultMap["content"].([]interface{}) + content, ok := resultMap["content"].([]any) require.True(t, ok, "Expected content in result") require.NotEmpty(t, content, "Expected content array") - firstContent, ok := content[0].(map[string]interface{}) + firstContent, ok := content[0].(map[string]any) require.True(t, ok, "Expected content item to be map") text, ok := firstContent["text"].(string) @@ -69,18 +69,18 @@ func TestInlineMCPServer(t *testing.T) { // Test 2: Environment variables t.Run("environment variables", func(t *testing.T) { - params := map[string]interface{}{ + params := map[string]any{ "name": "env_test", - "arguments": map[string]interface{}{}, + "arguments": map[string]any{}, } result, err := client.SendMCPRequest("tools/call", params) require.NoError(t, err, "Failed to call env_test tool") // Check result - resultMap, _ := result["result"].(map[string]interface{}) - content, _ := resultMap["content"].([]interface{}) - firstContent, _ := content[0].(map[string]interface{}) + resultMap, _ := result["result"].(map[string]any) + content, _ := resultMap["content"].([]any) + firstContent, _ := content[0].(map[string]any) text, _ := firstContent["text"].(string) // printenv outputs all environment variables @@ -90,18 +90,18 @@ func TestInlineMCPServer(t *testing.T) { // Test 3: Static output test t.Run("static output", func(t *testing.T) { - params := map[string]interface{}{ + params := map[string]any{ "name": "static_test", - "arguments": map[string]interface{}{}, + "arguments": map[string]any{}, } result, err := client.SendMCPRequest("tools/call", params) require.NoError(t, err, "Failed to call static_test tool") // Check result - resultMap, _ := result["result"].(map[string]interface{}) - content, _ := resultMap["content"].([]interface{}) - firstContent, _ := content[0].(map[string]interface{}) + resultMap, _ := result["result"].(map[string]any) + content, _ := resultMap["content"].([]any) + firstContent, _ := content[0].(map[string]any) text, _ := firstContent["text"].(string) assert.Contains(t, text, "Static output: test") @@ -109,9 +109,9 @@ func TestInlineMCPServer(t *testing.T) { // Test 4: JSON output parsing t.Run("JSON output", func(t *testing.T) { - params := map[string]interface{}{ + params := map[string]any{ "name": "json_output", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "value": "test-input", }, } @@ -120,9 +120,9 @@ func TestInlineMCPServer(t *testing.T) { require.NoError(t, err, "Failed to call json_output tool") // For JSON output, the content should be parsed as JSON - resultMap, _ := result["result"].(map[string]interface{}) - content, _ := resultMap["content"].([]interface{}) - firstContent, _ := content[0].(map[string]interface{}) + resultMap, _ := result["result"].(map[string]any) + content, _ := resultMap["content"].([]any) + firstContent, _ := content[0].(map[string]any) // The JSON output should be in the text field as a string text, ok := firstContent["text"].(string) @@ -135,16 +135,16 @@ func TestInlineMCPServer(t *testing.T) { // Test 6: Error handling t.Run("failing tool", func(t *testing.T) { - params := map[string]interface{}{ + params := map[string]any{ "name": "failing_tool", - "arguments": map[string]interface{}{}, + "arguments": map[string]any{}, } result, err := client.SendMCPRequest("tools/call", params) require.NoError(t, err, "Request should succeed even if tool fails") // Check for error in response - errorMap, hasError := result["error"].(map[string]interface{}) + errorMap, hasError := result["error"].(map[string]any) assert.True(t, hasError, "Expected error for failing tool") if hasError { @@ -158,21 +158,21 @@ func TestInlineMCPServer(t *testing.T) { // Test 7: List tools t.Run("list tools", func(t *testing.T) { - result, err := client.SendMCPRequest("tools/list", map[string]interface{}{}) + result, err := client.SendMCPRequest("tools/list", map[string]any{}) require.NoError(t, err, "Failed to list tools") // Check result - resultMap, ok := result["result"].(map[string]interface{}) + resultMap, ok := result["result"].(map[string]any) require.True(t, ok, "Expected result in response") - tools, ok := resultMap["tools"].([]interface{}) + tools, ok := resultMap["tools"].([]any) require.True(t, ok, "Expected tools array") assert.Len(t, tools, 6, "Expected 6 tools") // Verify tool names toolNames := make([]string, 0) for _, tool := range tools { - toolMap, _ := tool.(map[string]interface{}) + toolMap, _ := tool.(map[string]any) name, _ := toolMap["name"].(string) toolNames = append(toolNames, name) } diff --git a/integration/integration_test.go b/integration/integration_test.go index 2ea1a6d..270e54f 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -40,9 +40,9 @@ func TestIntegration(t *testing.T) { t.Log("Connected to MCP server with session") - queryParams := map[string]interface{}{ + queryParams := map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT COUNT(*) as user_count FROM users", }, } @@ -55,33 +55,33 @@ func TestIntegration(t *testing.T) { require.NotNil(t, result, "Expected some response from MCP server") // Check for error in response - errorMap, hasError := result["error"].(map[string]interface{}) + errorMap, hasError := result["error"].(map[string]any) assert.False(t, hasError, "Query returned error: %v", errorMap) // Verify we got result content - resultMap, ok := result["result"].(map[string]interface{}) + resultMap, ok := result["result"].(map[string]any) require.True(t, ok, "Expected result in response") - content, ok := resultMap["content"].([]interface{}) + content, ok := resultMap["content"].([]any) require.True(t, ok, "Expected content in result") assert.NotEmpty(t, content, "Query result missing content") t.Log("Query executed successfully") // Test resources list - resourcesResult, err := client.SendMCPRequest("resources/list", map[string]interface{}{}) + resourcesResult, err := client.SendMCPRequest("resources/list", map[string]any{}) require.NoError(t, err, "Failed to list resources") t.Logf("Resources response: %+v", resourcesResult) // Check for error in resources response - errorMap, hasError = resourcesResult["error"].(map[string]interface{}) + errorMap, hasError = resourcesResult["error"].(map[string]any) assert.False(t, hasError, "Resources list returned error: %v", errorMap) // Verify we got resources - resultMap, ok = resourcesResult["result"].(map[string]interface{}) + resultMap, ok = resourcesResult["result"].(map[string]any) require.True(t, ok, "Expected result in resources response") - resources, ok := resultMap["resources"].([]interface{}) + resources, ok := resultMap["resources"].([]any) require.True(t, ok, "Expected resources array in result") assert.NotEmpty(t, resources, "Expected at least one resource") t.Logf("Found %d resources", len(resources)) diff --git a/integration/isolation_test.go b/integration/isolation_test.go index 7fef43f..41d5d0f 100644 --- a/integration/isolation_test.go +++ b/integration/isolation_test.go @@ -1,6 +1,7 @@ package integration import ( + "slices" "testing" "time" @@ -53,11 +54,8 @@ func TestMultiUserSessionIsolation(t *testing.T) { var client1Container string for _, container := range containersAfterClient1 { isNew := true - for _, initial := range initialContainers { - if container == initial { - isNew = false - break - } + if slices.Contains(initialContainers, container) { + isNew = false } if isNew { client1Container = container @@ -70,9 +68,9 @@ func TestMultiUserSessionIsolation(t *testing.T) { t.Error("No new container created for client1") } - query1Result, err := client1.SendMCPRequest("tools/call", map[string]interface{}{ + query1Result, err := client1.SendMCPRequest("tools/call", map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 'user1-query1' as test_id, COUNT(*) as count FROM users", }, }) @@ -101,11 +99,8 @@ func TestMultiUserSessionIsolation(t *testing.T) { var client2Container string for _, container := range containersAfterClient2 { isNew := true - for _, existing := range containersAfterClient1 { - if container == existing { - isNew = false - break - } + if slices.Contains(containersAfterClient1, container) { + isNew = false } if isNew { client2Container = container @@ -126,9 +121,9 @@ func TestMultiUserSessionIsolation(t *testing.T) { t.Logf("Confirmed different stdio processes: User1 container=%s, User2 container=%s", client1Container, client2Container) } - query2Result, err := client2.SendMCPRequest("tools/call", map[string]interface{}{ + query2Result, err := client2.SendMCPRequest("tools/call", map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 'user2-query1' as test_id, COUNT(*) as count FROM orders", }, }) @@ -139,9 +134,9 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Step 3: First user sends another query t.Log("\nStep 3: First user sends another query") - query3Result, err := client1.SendMCPRequest("tools/call", map[string]interface{}{ + query3Result, err := client1.SendMCPRequest("tools/call", map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 'user1-query2' as test_id, current_timestamp as ts", }, }) @@ -152,9 +147,9 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Step 4: First user sends another query t.Log("\nStep 4: First user sends another query") - query4Result, err := client1.SendMCPRequest("tools/call", map[string]interface{}{ + query4Result, err := client1.SendMCPRequest("tools/call", map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 'user1-query3' as test_id, version() as db_version", }, }) @@ -165,9 +160,9 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Step 5: Second user sends a query t.Log("\nStep 5: Second user sends a query") - query5Result, err := client2.SendMCPRequest("tools/call", map[string]interface{}{ + query5Result, err := client2.SendMCPRequest("tools/call", map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 'user2-query2' as test_id, current_database() as db_name", }, }) @@ -241,9 +236,9 @@ func TestSessionCleanupAfterTimeout(t *testing.T) { assert.Greater(t, len(containersAfterConnect), len(initialContainers), "No new container created for client") // Send a query to ensure session is active - _, err = client.SendMCPRequest("tools/call", map[string]interface{}{ + _, err = client.SendMCPRequest("tools/call", map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 'test' as test_id", }, }) @@ -310,11 +305,11 @@ func TestSessionTimerReset(t *testing.T) { // Keep session alive by sending queries every 5 seconds // With 8s timeout, this should keep it alive - for i := 0; i < 3; i++ { + for i := range 3 { t.Logf("Sending keepalive query %d/3...", i+1) - _, err := client.SendMCPRequest("tools/call", map[string]interface{}{ + _, err := client.SendMCPRequest("tools/call", map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 'keepalive' as status, NOW() as timestamp", }, }) @@ -413,11 +408,11 @@ func TestMultiUserTimerIndependence(t *testing.T) { go func() { defer close(done) // Run 4 queries - the last one AFTER client1 should be cleaned up - for i := 0; i < 4; i++ { + for i := range 4 { time.Sleep(4 * time.Second) - _, err := client2.SendMCPRequest("tools/call", map[string]interface{}{ + _, err := client2.SendMCPRequest("tools/call", map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 'client2-keepalive' as status", }, }) diff --git a/integration/main_test.go b/integration/main_test.go index ea22072..49eb876 100644 --- a/integration/main_test.go +++ b/integration/main_test.go @@ -53,21 +53,21 @@ func TestMain(m *testing.M) { os.Exit(exitCode) }() - // Start mock GCP server for OAuth - mockGCP := NewMockGCPServer("9090") - err := mockGCP.Start() + // Start fake GCP server for OAuth + fakeGCP := NewFakeGCPServer("9090") + err := fakeGCP.Start() if err != nil { - fmt.Printf("Failed to start mock GCP server: %v\n", err) + fmt.Printf("Failed to start fake GCP server: %v\n", err) exitCode = 1 return } defer func() { - _ = mockGCP.Stop() + _ = fakeGCP.Stop() }() // Wait for database to be ready fmt.Println("Waiting for database to be ready...") - for i := 0; i < 30; i++ { // Wait up to 30 seconds + for i := range 30 { // Wait up to 30 seconds checkCmd := exec.Command("docker", "compose", "exec", "-T", "test-postgres", "pg_isready", "-U", "testuser", "-d", "testdb") if err := checkCmd.Run(); err == nil { fmt.Println("Database is ready!") @@ -83,7 +83,7 @@ func TestMain(m *testing.M) { // Wait for SSE server to be ready fmt.Println("Waiting for SSE server to be ready...") - for i := 0; i < 30; i++ { // Wait up to 30 seconds + for i := range 30 { // Wait up to 30 seconds resp, err := http.Get("http://localhost:3001") if err == nil { resp.Body.Close() @@ -102,7 +102,7 @@ func TestMain(m *testing.M) { // Wait for Streamable server to be ready fmt.Println("Waiting for Streamable server to be ready...") - for i := 0; i < 30; i++ { // Wait up to 30 seconds + for i := range 30 { // Wait up to 30 seconds resp, err := http.Get("http://localhost:3002") if err == nil { resp.Body.Close() diff --git a/integration/oauth_test.go b/integration/oauth_test.go index e077862..f377194 100644 --- a/integration/oauth_test.go +++ b/integration/oauth_test.go @@ -48,7 +48,7 @@ func TestBasicOAuthFlow(t *testing.T) { assert.Equal(t, 200, resp.StatusCode, "OAuth discovery failed") - var discovery map[string]interface{} + var discovery map[string]any err = json.NewDecoder(resp.Body).Decode(&discovery) require.NoError(t, err, "Failed to decode discovery") @@ -66,7 +66,7 @@ func TestBasicOAuthFlow(t *testing.T) { } // Verify client_secret_post is advertised - authMethods, ok := discovery["token_endpoint_auth_methods_supported"].([]interface{}) + authMethods, ok := discovery["token_endpoint_auth_methods_supported"].([]any) assert.True(t, ok, "token_endpoint_auth_methods_supported should be present") var hasNone, hasClientSecretPost bool @@ -157,13 +157,13 @@ func TestClientRegistration(t *testing.T) { }) defer stopServer(mcpCmd) - if !waitForHealthCheck(t, 30) { + if !waitForHealthCheck(30) { t.Fatal("OAuth server failed to start") } t.Run("PublicClientRegistration", func(t *testing.T) { // Register a public client (no secret) - clientReq := map[string]interface{}{ + clientReq := map[string]any{ "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback/debug"}, "scope": "read write", } @@ -184,7 +184,7 @@ func TestClientRegistration(t *testing.T) { t.Fatalf("Client registration failed with status %d: %s", resp.StatusCode, string(body)) } - var clientResp map[string]interface{} + var clientResp map[string]any if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -205,8 +205,8 @@ func TestClientRegistration(t *testing.T) { // Register multiple clients and verify they get different IDs var clientIDs []string - for i := 0; i < 3; i++ { - clientReq := map[string]interface{}{ + for i := range 3 { + clientReq := map[string]any{ "redirect_uris": []string{fmt.Sprintf("http://example.com/callback%d", i)}, "scope": "read", } @@ -222,7 +222,7 @@ func TestClientRegistration(t *testing.T) { } defer resp.Body.Close() - var clientResp map[string]interface{} + var clientResp map[string]any _ = json.NewDecoder(resp.Body).Decode(&clientResp) clientIDs = append(clientIDs, clientResp["client_id"].(string)) } @@ -240,7 +240,7 @@ func TestClientRegistration(t *testing.T) { t.Run("ConfidentialClientRegistration", func(t *testing.T) { // Register a confidential client with client_secret_post - clientReq := map[string]interface{}{ + clientReq := map[string]any{ "redirect_uris": []string{"https://example.com/callback"}, "scope": "read write", "token_endpoint_auth_method": "client_secret_post", @@ -262,7 +262,7 @@ func TestClientRegistration(t *testing.T) { t.Fatalf("Confidential client registration failed with status %d: %s", resp.StatusCode, string(body)) } - var clientResp map[string]interface{} + var clientResp map[string]any if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -295,7 +295,7 @@ func TestClientRegistration(t *testing.T) { // Test that public clients don't get secrets and confidential ones do // First, create a public client - publicReq := map[string]interface{}{ + publicReq := map[string]any{ "redirect_uris": []string{"https://public.example.com/callback"}, "scope": "read", // No token_endpoint_auth_method specified - defaults to "none" @@ -310,7 +310,7 @@ func TestClientRegistration(t *testing.T) { require.NoError(t, err) defer resp.Body.Close() - var publicResp map[string]interface{} + var publicResp map[string]any _ = json.NewDecoder(resp.Body).Decode(&publicResp) // Verify public client has no secret @@ -322,7 +322,7 @@ func TestClientRegistration(t *testing.T) { } // Now create a confidential client - confidentialReq := map[string]interface{}{ + confidentialReq := map[string]any{ "redirect_uris": []string{"https://confidential.example.com/callback"}, "scope": "read write", "token_endpoint_auth_method": "client_secret_post", @@ -337,7 +337,7 @@ func TestClientRegistration(t *testing.T) { require.NoError(t, err) defer resp.Body.Close() - var confResp map[string]interface{} + var confResp map[string]any _ = json.NewDecoder(resp.Body).Decode(&confResp) // Verify confidential client has a secret @@ -358,7 +358,7 @@ func TestUserTokenFlow(t *testing.T) { mcpCmd := startOAuthServerWithTokenConfig(t) defer stopServer(mcpCmd) - if !waitForHealthCheck(t, 30) { + if !waitForHealthCheck(30) { t.Fatal("Server failed to start") } @@ -524,7 +524,7 @@ func TestStateParameterHandling(t *testing.T) { } for _, tt := range tests { - tt := tt // capture range variable + // capture range variable t.Run(tt.name, func(t *testing.T) { // Start server with specific environment mcpCmd := startOAuthServer(t, map[string]string{ @@ -532,7 +532,7 @@ func TestStateParameterHandling(t *testing.T) { }) defer stopServer(mcpCmd) - if !waitForHealthCheck(t, 10) { + if !waitForHealthCheck(10) { t.Fatal("Server failed to start") } @@ -602,7 +602,7 @@ func TestEnvironmentModes(t *testing.T) { }) defer stopServer(mcpCmd) - if !waitForHealthCheck(t, 30) { + if !waitForHealthCheck(30) { t.Fatal("Server failed to start") } @@ -643,7 +643,7 @@ func TestEnvironmentModes(t *testing.T) { }) defer stopServer(mcpCmd) - if !waitForHealthCheck(t, 30) { + if !waitForHealthCheck(30) { t.Fatal("Server failed to start") } @@ -693,7 +693,7 @@ func TestOAuthEndpoints(t *testing.T) { }) defer stopServer(mcpCmd) - if !waitForHealthCheck(t, 10) { + if !waitForHealthCheck(10) { t.Fatal("Server failed to start") } @@ -708,7 +708,7 @@ func TestOAuthEndpoints(t *testing.T) { t.Fatalf("Discovery failed with status %d", resp.StatusCode) } - var discovery map[string]interface{} + var discovery map[string]any if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil { t.Fatalf("Failed to decode discovery response: %v", err) } @@ -761,7 +761,7 @@ func TestCORSHeaders(t *testing.T) { }) defer stopServer(mcpCmd) - if !waitForHealthCheck(t, 10) { + if !waitForHealthCheck(10) { t.Fatal("Server failed to start") } @@ -813,7 +813,7 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { "LOG_LEVEL=debug", ) - if !waitForHealthCheck(t, 30) { + if !waitForHealthCheck(30) { t.Fatal("Server failed to start") } @@ -831,21 +831,21 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { defer mcpClient.Close() // Request tools list - toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]interface{}{}) + toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) require.NoError(t, err, "Should list tools without user token") // Verify we got tools - resultMap, ok := toolsResp["result"].(map[string]interface{}) + resultMap, ok := toolsResp["result"].(map[string]any) require.True(t, ok, "Expected result in tools response") - tools, ok := resultMap["tools"].([]interface{}) + tools, ok := resultMap["tools"].([]any) require.True(t, ok, "Expected tools array in result") assert.NotEmpty(t, tools, "Should have tools advertised") // Check for common postgres tools var toolNames []string for _, tool := range tools { - if toolMap, ok := tool.(map[string]interface{}); ok { + if toolMap, ok := tool.(map[string]any); ok { if name, ok := toolMap["name"].(string); ok { toolNames = append(toolNames, name) } @@ -867,9 +867,9 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { defer mcpClient.Close() // Try to invoke a tool without user token - queryParams := map[string]interface{}{ + queryParams := map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 1", }, } @@ -880,20 +880,20 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { // MCP protocol returns errors as successful responses with error content require.NotNil(t, result["result"], "Should have result in response") - resultMap := result["result"].(map[string]interface{}) - content := resultMap["content"].([]interface{}) + resultMap := result["result"].(map[string]any) + content := resultMap["content"].([]any) require.NotEmpty(t, content, "Should have content in result") - contentItem := content[0].(map[string]interface{}) + contentItem := content[0].(map[string]any) errorJSON := contentItem["text"].(string) // Parse the error JSON - var errorData map[string]interface{} + var errorData map[string]any err = json.Unmarshal([]byte(errorJSON), &errorData) require.NoError(t, err, "Error should be valid JSON") // Verify error structure - errorInfo := errorData["error"].(map[string]interface{}) + errorInfo := errorData["error"].(map[string]any) assert.Equal(t, "token_required", errorInfo["code"], "Error code should be token_required") errorMessage := errorInfo["message"].(string) @@ -902,12 +902,12 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { assert.Contains(t, errorMessage, "Test Service", "Error should mention service name") // Verify error data - errData := errorInfo["data"].(map[string]interface{}) + errData := errorInfo["data"].(map[string]any) assert.Equal(t, "postgres", errData["service"], "Should identify the service") assert.Contains(t, errData["tokenSetupUrl"].(string), "/my/tokens", "Should include token setup URL") // Verify instructions - instructions := errData["instructions"].(map[string]interface{}) + instructions := errData["instructions"].(map[string]any) assert.Contains(t, instructions["ai"].(string), "CRITICAL", "Should have AI instructions") assert.Contains(t, instructions["human"].(string), "token required", "Should have human instructions") }) @@ -980,13 +980,14 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { defer resp.Body.Close() // Check the response - it might be 200 if following redirects - if resp.StatusCode == 200 { + switch resp.StatusCode { + case 200: // That's fine, it means the token was set and we got the page back t.Log("Token set successfully, got page response") - } else if resp.StatusCode == 302 || resp.StatusCode == 303 { + case 302, 303: // Also fine, redirect means success t.Log("Token set successfully, got redirect") - } else { + default: body, _ := io.ReadAll(resp.Body) t.Fatalf("Unexpected response setting token: status=%d, body=%s", resp.StatusCode, string(body)) } @@ -1000,9 +1001,9 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { defer mcpClient.Close() // Call the query tool with a simple query - queryParams := map[string]interface{}{ + queryParams := map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "sql": "SELECT 1 as test", }, } @@ -1013,11 +1014,11 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { // Verify we got a successful result, not an error require.NotNil(t, result["result"], "Should have result in response") - resultMap := result["result"].(map[string]interface{}) - content := resultMap["content"].([]interface{}) + resultMap := result["result"].(map[string]any) + content := resultMap["content"].([]any) require.NotEmpty(t, content, "Should have content in result") - contentItem := content[0].(map[string]interface{}) + contentItem := content[0].(map[string]any) resultText := contentItem["text"].(string) // The result should contain actual query results, not an error @@ -1119,8 +1120,8 @@ func stopServer(cmd *exec.Cmd) { } } -func waitForHealthCheck(t *testing.T, seconds int) bool { - for i := 0; i < seconds; i++ { +func waitForHealthCheck(seconds int) bool { + for range seconds { if checkHealth() { return true } @@ -1142,7 +1143,7 @@ func checkHealth() bool { } func registerTestClient(t *testing.T) string { - clientReq := map[string]interface{}{ + clientReq := map[string]any{ "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback"}, "scope": "openid email profile read write", } @@ -1163,7 +1164,7 @@ func registerTestClient(t *testing.T) string { t.Fatalf("Client registration failed: %d - %s", resp.StatusCode, string(body)) } - var clientResp map[string]interface{} + var clientResp map[string]any _ = json.NewDecoder(resp.Body).Decode(&clientResp) return clientResp["client_id"].(string) } @@ -1269,7 +1270,7 @@ func getOAuthAccessToken(t *testing.T) string { require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed") - var tokenData map[string]interface{} + var tokenData map[string]any err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) require.NoError(t, err) diff --git a/integration/security_test.go b/integration/security_test.go index ebce8c0..2305c92 100644 --- a/integration/security_test.go +++ b/integration/security_test.go @@ -100,9 +100,9 @@ func TestSecurityScenarios(t *testing.T) { // Testing SQL injection payload // Try to inject via the query parameter - _, err := client.SendMCPRequest("tools/call", map[string]interface{}{ + _, err := client.SendMCPRequest("tools/call", map[string]any{ "name": "query", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "query": payload, }, }) @@ -221,10 +221,11 @@ func TestSecurityScenarios(t *testing.T) { } resp.Body.Close() - if resp.StatusCode == 200 { + switch resp.StatusCode { + case 200: t.Errorf("CRITICAL: Auth bypass! 'test-token' without Bearer returned 200") - } else if resp.StatusCode == 401 { - } else { + case 401: + default: t.Logf("Unexpected status %d for malformed auth", resp.StatusCode) } }) @@ -275,7 +276,7 @@ func TestSecurityScenarios(t *testing.T) { errorCount := 0 // Make rapid requests to see if there's any rate limiting - for i := 0; i < 10; i++ { + for range 10 { err := client.ValidateBackendConnectivity() if err != nil { errorCount++ diff --git a/integration/sse_test.go b/integration/sse_test.go index 5193820..92eadfb 100644 --- a/integration/sse_test.go +++ b/integration/sse_test.go @@ -34,9 +34,9 @@ func TestSSEServerIntegration(t *testing.T) { t.Log("Connected to SSE MCP server") // Let's list available tools first - params := map[string]interface{}{ + params := map[string]any{ "method": "tools/list", - "params": map[string]interface{}{}, + "params": map[string]any{}, } result, err := client.SendMCPRequest("tools/list", params) @@ -49,9 +49,9 @@ func TestSSEServerIntegration(t *testing.T) { t.Run("SSE tool invocation", func(t *testing.T) { // The mock server should provide echo_text tool - params := map[string]interface{}{ + params := map[string]any{ "name": "echo_text", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "text": "Hello from SSE test!", }, } @@ -60,7 +60,7 @@ func TestSSEServerIntegration(t *testing.T) { require.NoError(t, err, "Failed to call echo_text tool") // Check for successful response - if errorMap, hasError := result["error"].(map[string]interface{}); hasError { + if errorMap, hasError := result["error"].(map[string]any); hasError { t.Fatalf("Got error response: %v", errorMap) } @@ -68,7 +68,7 @@ func TestSSEServerIntegration(t *testing.T) { assert.NotNil(t, result["result"]) // Verify the echo result contains our text - if resultData, ok := result["result"].(map[string]interface{}); ok { + if resultData, ok := result["result"].(map[string]any); ok { if toolResult, ok := resultData["toolResult"].(string); ok { assert.Equal(t, "Hello from SSE test!", toolResult) } @@ -84,9 +84,9 @@ func TestSSEServerIntegration(t *testing.T) { require.NoError(t, err, "Failed to reconnect to SSE server") // Verify we can still make requests - params := map[string]interface{}{ + params := map[string]any{ "method": "tools/list", - "params": map[string]interface{}{}, + "params": map[string]any{}, } result, err := client.SendMCPRequest("tools/list", params) @@ -96,9 +96,9 @@ func TestSSEServerIntegration(t *testing.T) { t.Run("SSE streaming functionality", func(t *testing.T) { // Test that SSE streaming works by calling the sample_stream tool - params := map[string]interface{}{ + params := map[string]any{ "name": "sample_stream", - "arguments": map[string]interface{}{}, + "arguments": map[string]any{}, } // The mock server provides sample_stream tool @@ -109,23 +109,23 @@ func TestSSEServerIntegration(t *testing.T) { // Verify we got a successful result assert.NotContains(t, result, "error", "Should not have error for sample_stream") - if resultData, ok := result["result"].(map[string]interface{}); ok { + if resultData, ok := result["result"].(map[string]any); ok { assert.Equal(t, "Tool executed successfully", resultData["toolResult"]) } }) t.Run("SSE error handling", func(t *testing.T) { // Test calling a non-existent tool - params := map[string]interface{}{ + params := map[string]any{ "name": "non_existent_tool_xyz", - "arguments": map[string]interface{}{}, + "arguments": map[string]any{}, } result, err := client.SendMCPRequest("tools/call", params) require.NoError(t, err, "Should not get connection error for non-existent tool") // Should get an error in the response - errorMap, hasError := result["error"].(map[string]interface{}) + errorMap, hasError := result["error"].(map[string]any) assert.True(t, hasError, "Expected error for non-existent tool") if hasError { assert.NotEmpty(t, errorMap["message"], "Error should have a message") @@ -138,13 +138,13 @@ func TestSSEServerIntegration(t *testing.T) { // Test that we can handle multiple concurrent requests done := make(chan bool, COUNT) - for i := 0; i < COUNT; i++ { + for i := range COUNT { go func(index int) { defer func() { done <- true }() - params := map[string]interface{}{ + params := map[string]any{ "name": "echo_text", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "text": string(rune('A' + index)), }, } @@ -157,7 +157,7 @@ func TestSSEServerIntegration(t *testing.T) { // Wait for all requests to complete timeout := time.After(10 * time.Second) - for i := 0; i < COUNT; i++ { + for range COUNT { select { case <-done: // Good @@ -186,9 +186,9 @@ func TestSSEServerRestart(t *testing.T) { require.NoError(t, err, "Failed initial connection") // Make a successful request - params := map[string]interface{}{ + params := map[string]any{ "method": "tools/list", - "params": map[string]interface{}{}, + "params": map[string]any{}, } result, err := client.SendMCPRequest("tools/list", params) diff --git a/integration/streamable_test.go b/integration/streamable_test.go index f4e37a1..9c63ea2 100644 --- a/integration/streamable_test.go +++ b/integration/streamable_test.go @@ -38,9 +38,9 @@ func TestStreamableServerIntegration(t *testing.T) { t.Log("Connected to Streamable MCP server") // List available tools - params := map[string]interface{}{ + params := map[string]any{ "method": "tools/list", - "params": map[string]interface{}{}, + "params": map[string]any{}, } result, err := client.SendMCPRequest("tools/list", params) @@ -51,8 +51,8 @@ func TestStreamableServerIntegration(t *testing.T) { assert.NotContains(t, result, "error", "Expected no error in response") // Verify tools are present - if resultData, ok := result["result"].(map[string]interface{}); ok { - if tools, ok := resultData["tools"].([]interface{}); ok { + if resultData, ok := result["result"].(map[string]any); ok { + if tools, ok := resultData["tools"].([]any); ok { assert.Equal(t, 2, len(tools), "Expected 2 tools") } } @@ -60,16 +60,16 @@ func TestStreamableServerIntegration(t *testing.T) { t.Run("Streamable tool invocation with JSON response", func(t *testing.T) { // Call the get_time tool - params := map[string]interface{}{ + params := map[string]any{ "name": "get_time", - "arguments": map[string]interface{}{}, + "arguments": map[string]any{}, } result, err := client.SendMCPRequest("tools/call", params) require.NoError(t, err, "Failed to call get_time tool") // Check for successful response - if errorMap, hasError := result["error"].(map[string]interface{}); hasError { + if errorMap, hasError := result["error"].(map[string]any); hasError { t.Fatalf("Got error response: %v", errorMap) } @@ -77,7 +77,7 @@ func TestStreamableServerIntegration(t *testing.T) { assert.NotNil(t, result["result"]) // Verify the time result - if resultData, ok := result["result"].(map[string]interface{}); ok { + if resultData, ok := result["result"].(map[string]any); ok { if toolResult, ok := resultData["toolResult"].(string); ok { assert.NotEmpty(t, toolResult, "Should have gotten a timestamp") t.Logf("Got time: %s", toolResult) @@ -88,13 +88,13 @@ func TestStreamableServerIntegration(t *testing.T) { t.Run("Streamable POST with actual SSE response", func(t *testing.T) { baseURL := "http://localhost:8080/test-streamable/sse" - request := map[string]interface{}{ + request := map[string]any{ "jsonrpc": "2.0", "id": 1, "method": "tools/call", - "params": map[string]interface{}{ + "params": map[string]any{ "name": "echo_streamable", - "arguments": map[string]interface{}{ + "arguments": map[string]any{ "text": "Hello SSE!", }, }, @@ -119,13 +119,13 @@ func TestStreamableServerIntegration(t *testing.T) { assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) scanner := bufio.NewScanner(resp.Body) - var responses []map[string]interface{} + var responses []map[string]any for scanner.Scan() { line := scanner.Text() - if strings.HasPrefix(line, "data: ") { - data := strings.TrimPrefix(line, "data: ") - var msg map[string]interface{} + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + var msg map[string]any if err := json.Unmarshal([]byte(data), &msg); err == nil { responses = append(responses, msg) } @@ -138,7 +138,7 @@ func TestStreamableServerIntegration(t *testing.T) { found := false for _, response := range responses { if id, ok := response["id"]; ok && id == float64(1) { - if result, ok := response["result"].(map[string]interface{}); ok { + if result, ok := response["result"].(map[string]any); ok { if toolResult, ok := result["toolResult"].(string); ok { assert.Equal(t, "Echo: Hello SSE!", toolResult) found = true @@ -152,16 +152,16 @@ func TestStreamableServerIntegration(t *testing.T) { t.Run("Streamable error handling", func(t *testing.T) { // Test calling a non-existent tool - params := map[string]interface{}{ + params := map[string]any{ "name": "non_existent_tool", - "arguments": map[string]interface{}{}, + "arguments": map[string]any{}, } result, err := client.SendMCPRequest("tools/call", params) require.NoError(t, err, "Should not get connection error for non-existent tool") // Should get an error in the response - errorMap, hasError := result["error"].(map[string]interface{}) + errorMap, hasError := result["error"].(map[string]any) assert.True(t, hasError, "Expected error for non-existent tool") if hasError { assert.Equal(t, float64(-32601), errorMap["code"], "Expected method not found error code") diff --git a/integration/test_utils.go b/integration/test_utils.go index 8f6fe4b..3f5f749 100644 --- a/integration/test_utils.go +++ b/integration/test_utils.go @@ -11,6 +11,7 @@ import ( "net/url" "os" "os/exec" + "slices" "strings" "sync" "syscall" @@ -119,8 +120,8 @@ func (c *MCPSSEClient) ConnectToServer(serverName string) error { tracef("ConnectToServer: SSE line: %s", line) // Look for data lines - if strings.HasPrefix(line, "data: ") { - data := strings.TrimPrefix(line, "data: ") + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after // Check if it's an endpoint message (for inline servers) if strings.Contains(data, `"type":"endpoint"`) { @@ -172,7 +173,7 @@ func (c *MCPSSEClient) Close() { } // SendMCPRequest sends an MCP JSON-RPC request and returns the response -func (c *MCPSSEClient) SendMCPRequest(method string, params interface{}) (map[string]interface{}, error) { +func (c *MCPSSEClient) SendMCPRequest(method string, params any) (map[string]any, error) { // Ensure we have a connection if c.messageEndpoint == "" { if err := c.Connect(); err != nil { @@ -181,7 +182,7 @@ func (c *MCPSSEClient) SendMCPRequest(method string, params interface{}) (map[st } // Send MCP request to the message endpoint - request := map[string]interface{}{ + request := map[string]any{ "jsonrpc": "2.0", "id": 1, "method": method, @@ -223,10 +224,10 @@ func (c *MCPSSEClient) SendMCPRequest(method string, params interface{}) (map[st for c.sseScanner.Scan() { line := c.sseScanner.Text() - if strings.HasPrefix(line, "data: ") { - data := strings.TrimPrefix(line, "data: ") + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after // Try to parse as JSON - var msg map[string]interface{} + var msg map[string]any if err := json.Unmarshal([]byte(data), &msg); err == nil { // Check if this is our response (matching ID) if id, ok := msg["id"]; ok && id == float64(1) { @@ -243,7 +244,7 @@ func (c *MCPSSEClient) SendMCPRequest(method string, params interface{}) (map[st return nil, fmt.Errorf("no response received from SSE stream") } - var result map[string]interface{} + var result map[string]any if err := json.Unmarshal(respBody, &result); err != nil { return nil, fmt.Errorf("failed to parse response: %v - %s", err, string(respBody)) } @@ -355,7 +356,7 @@ func (c *MCPStreamableClient) readSSEMessages() { } // SendMCPRequest sends a JSON-RPC request via POST -func (c *MCPStreamableClient) SendMCPRequest(method string, params interface{}) (map[string]interface{}, error) { +func (c *MCPStreamableClient) SendMCPRequest(method string, params any) (map[string]any, error) { c.mu.Lock() serverName := c.serverName c.mu.Unlock() @@ -368,7 +369,7 @@ func (c *MCPStreamableClient) SendMCPRequest(method string, params interface{}) url := c.baseURL + "/" + serverName + "/sse" // Construct JSON-RPC request - request := map[string]interface{}{ + request := map[string]any{ "jsonrpc": "2.0", "id": 1, "method": method, @@ -409,8 +410,8 @@ func (c *MCPStreamableClient) SendMCPRequest(method string, params interface{}) } // handleJSONResponse processes a regular JSON response -func (c *MCPStreamableClient) handleJSONResponse(body io.Reader) (map[string]interface{}, error) { - var response map[string]interface{} +func (c *MCPStreamableClient) handleJSONResponse(body io.Reader) (map[string]any, error) { + var response map[string]any if err := json.NewDecoder(body).Decode(&response); err != nil { return nil, fmt.Errorf("failed to decode JSON response: %v", err) } @@ -418,15 +419,15 @@ func (c *MCPStreamableClient) handleJSONResponse(body io.Reader) (map[string]int } // handleSSEResponse processes an SSE stream response from a POST -func (c *MCPStreamableClient) handleSSEResponse(body io.Reader) (map[string]interface{}, error) { +func (c *MCPStreamableClient) handleSSEResponse(body io.Reader) (map[string]any, error) { scanner := bufio.NewScanner(body) - var lastResponse map[string]interface{} + var lastResponse map[string]any for scanner.Scan() { line := scanner.Text() - if strings.HasPrefix(line, "data: ") { - data := strings.TrimPrefix(line, "data: ") - var msg map[string]interface{} + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + var msg map[string]any if err := json.Unmarshal([]byte(data), &msg); err == nil { // Keep the last response with an ID (not a notification) if _, hasID := msg["id"]; hasID { @@ -466,14 +467,14 @@ func (c *MCPStreamableClient) close() { c.serverName = "" } -// MockGCPServer provides a mock GCP IAM server for testing -type MockGCPServer struct { +// FakeGCPServer provides a fake GCP OAuth server for testing +type FakeGCPServer struct { server *http.Server port string } -// NewMockGCPServer creates a new mock GCP server -func NewMockGCPServer(port string) *MockGCPServer { +// NewFakeGCPServer creates a new fake GCP server +func NewFakeGCPServer(port string) *FakeGCPServer { mux := http.NewServeMux() mux.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) { @@ -494,7 +495,7 @@ func NewMockGCPServer(port string) *MockGCPServer { if code != "test-auth-code" { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + _ = json.NewEncoder(w).Encode(map[string]any{ "error": "invalid_grant", "error_description": "Invalid authorization code", }) @@ -502,7 +503,7 @@ func NewMockGCPServer(port string) *MockGCPServer { } w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + _ = json.NewEncoder(w).Encode(map[string]any{ "access_token": "test-access-token", "token_type": "Bearer", "expires_in": 3600, @@ -511,7 +512,7 @@ func NewMockGCPServer(port string) *MockGCPServer { mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]interface{}{ + _ = json.NewEncoder(w).Encode(map[string]any{ "email": "test@test.com", "hd": "test.com", }) @@ -522,14 +523,14 @@ func NewMockGCPServer(port string) *MockGCPServer { Handler: mux, } - return &MockGCPServer{ + return &FakeGCPServer{ server: server, port: port, } } -// Start starts the mock GCP server -func (m *MockGCPServer) Start() error { +// Start starts the fake GCP server +func (m *FakeGCPServer) Start() error { go func() { if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { panic(err) @@ -540,8 +541,8 @@ func (m *MockGCPServer) Start() error { return nil } -// Stop stops the mock GCP server -func (m *MockGCPServer) Stop() error { +// Stop stops the fake GCP server +func (m *FakeGCPServer) Stop() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() return m.server.Shutdown(ctx) @@ -551,7 +552,7 @@ func (m *MockGCPServer) Stop() error { type TestEnvironment struct { dbCmd *exec.Cmd mcpCmd *exec.Cmd - mockGCP *MockGCPServer + fakeGCP *FakeGCPServer client *MCPSSEClient } @@ -570,8 +571,8 @@ func SetupTestEnvironment(t *testing.T) *TestEnvironment { // Start mock GCP server t.Log("🚀 Starting mock GCP server...") - env.mockGCP = NewMockGCPServer("9090") - if err := env.mockGCP.Start(); err != nil { + env.fakeGCP = NewFakeGCPServer("9090") + if err := env.fakeGCP.Start(); err != nil { t.Fatalf("Failed to start mock GCP server: %v", err) } @@ -610,8 +611,8 @@ func (env *TestEnvironment) Cleanup() { _ = env.mcpCmd.Process.Kill() } - if env.mockGCP != nil { - _ = env.mockGCP.Stop() + if env.fakeGCP != nil { + _ = env.fakeGCP.Stop() } if env.dbCmd != nil { @@ -661,7 +662,7 @@ func GetTestConfig() TestConfig { func waitForDB(t *testing.T) { waitForSec := 5 - for i := 0; i < waitForSec; i++ { + for range waitForSec { // Check if container is running psCmd := exec.Command("docker", "compose", "ps", "-q", "test-postgres") if output, err := psCmd.Output(); err != nil || len(output) == 0 { @@ -681,14 +682,14 @@ func waitForDB(t *testing.T) { } // trace logs a message if TRACE environment variable is set -func trace(t *testing.T, format string, args ...interface{}) { +func trace(t *testing.T, format string, args ...any) { if os.Getenv("TRACE") == "1" { t.Logf("TRACE: "+format, args...) } } // tracef logs a formatted message to stdout if TRACE is set (for use outside tests) -func tracef(format string, args ...interface{}) { +func tracef(format string, args ...any) { if os.Getenv("TRACE") == "1" { fmt.Printf("TRACE: "+format+"\n", args...) } @@ -778,7 +779,7 @@ func stopMCPFront(cmd *exec.Cmd) { // waitForMCPFront waits for the mcp-front server to be ready func waitForMCPFront(t *testing.T) { t.Helper() - for i := 0; i < 10; i++ { + for range 10 { resp, err := http.Get("http://localhost:8080/health") if err == nil && resp.StatusCode == 200 { resp.Body.Close() @@ -814,13 +815,7 @@ func cleanupContainers(t *testing.T, initialContainers []string) { time.Sleep(2 * time.Second) containers := getMCPContainers() for _, container := range containers { - isInitial := false - for _, initial := range initialContainers { - if container == initial { - isInitial = true - break - } - } + isInitial := slices.Contains(initialContainers, container) if !isInitial { t.Logf("Force stopping container: %s...", container) if err := exec.Command("docker", "stop", container).Run(); err != nil { diff --git a/internal/auth/admin.go b/internal/adminauth/admin.go similarity index 80% rename from internal/auth/admin.go rename to internal/adminauth/admin.go index b1c0a3e..61c8826 100644 --- a/internal/auth/admin.go +++ b/internal/adminauth/admin.go @@ -1,11 +1,11 @@ -package auth +package adminauth import ( "context" "github.com/dgellow/mcp-front/internal/config" + emailutil "github.com/dgellow/mcp-front/internal/email" "github.com/dgellow/mcp-front/internal/storage" - "github.com/dgellow/mcp-front/internal/utils" ) // IsAdmin checks if a user is admin (either config-based or promoted) @@ -15,7 +15,7 @@ func IsAdmin(ctx context.Context, email string, adminConfig *config.AdminConfig, } // Normalize the input email - normalizedEmail := utils.NormalizeEmail(email) + normalizedEmail := emailutil.Normalize(email) // Check if user is a config admin (super admin) if IsConfigAdmin(normalizedEmail, adminConfig) { @@ -27,7 +27,7 @@ func IsAdmin(ctx context.Context, email string, adminConfig *config.AdminConfig, users, err := store.GetAllUsers(ctx) if err == nil { for _, user := range users { - if utils.NormalizeEmail(user.Email) == normalizedEmail && user.IsAdmin { + if emailutil.Normalize(user.Email) == normalizedEmail && user.IsAdmin { return true } } @@ -44,12 +44,12 @@ func IsConfigAdmin(email string, adminConfig *config.AdminConfig) bool { } // Email should already be normalized by the caller, but normalize anyway for safety - normalizedEmail := utils.NormalizeEmail(email) + normalizedEmail := emailutil.Normalize(email) for _, adminEmail := range adminConfig.AdminEmails { // Admin emails should be normalized during config load, but we normalize here too // to handle any legacy configs or manual edits - if utils.NormalizeEmail(adminEmail) == normalizedEmail { + if emailutil.Normalize(adminEmail) == normalizedEmail { return true } } diff --git a/internal/auth/admin_test.go b/internal/adminauth/admin_test.go similarity index 99% rename from internal/auth/admin_test.go rename to internal/adminauth/admin_test.go index 23064b5..bd90227 100644 --- a/internal/auth/admin_test.go +++ b/internal/adminauth/admin_test.go @@ -1,4 +1,4 @@ -package auth +package adminauth import ( "context" diff --git a/internal/auth/service_oauth.go b/internal/auth/service_oauth.go new file mode 100644 index 0000000..4ed7a44 --- /dev/null +++ b/internal/auth/service_oauth.go @@ -0,0 +1,304 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/crypto" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/storage" + "golang.org/x/oauth2" +) + +// ServiceOAuthClient handles OAuth flows for external MCP services +type ServiceOAuthClient struct { + storage storage.UserTokenStore + baseURL string + httpClient *http.Client + stateCache map[string]*ServiceOAuthState // In production, use distributed cache +} + +// ServiceOAuthState stores OAuth flow state for external service authentication (mcp-front → external service) +type ServiceOAuthState struct { + Service string + UserEmail string + CreatedAt time.Time +} + +// NewServiceOAuthClient creates a new OAuth client for external services +func NewServiceOAuthClient(storage storage.UserTokenStore, baseURL string) *ServiceOAuthClient { + return &ServiceOAuthClient{ + storage: storage, + baseURL: baseURL, + httpClient: &http.Client{Timeout: 30 * time.Second}, + stateCache: make(map[string]*ServiceOAuthState), + } +} + +// StartOAuthFlow initiates OAuth flow for a service +func (c *ServiceOAuthClient) StartOAuthFlow( + ctx context.Context, + userEmail string, + serviceName string, + serviceConfig *config.MCPClientConfig, +) (string, error) { + if serviceConfig.UserAuthentication == nil || + serviceConfig.UserAuthentication.Type != config.UserAuthTypeOAuth { + return "", fmt.Errorf("service %s does not support OAuth", serviceName) + } + + auth := serviceConfig.UserAuthentication + + // Create OAuth2 config + oauth2Config := &oauth2.Config{ + ClientID: string(auth.ClientID), + ClientSecret: string(auth.ClientSecret), + Endpoint: oauth2.Endpoint{ + AuthURL: auth.AuthorizationURL, + TokenURL: auth.TokenURL, + }, + RedirectURL: fmt.Sprintf("%s/oauth/callback/%s", c.baseURL, serviceName), + Scopes: auth.Scopes, + } + + // Generate state parameter + state := crypto.GenerateSecureToken() + c.stateCache[state] = &ServiceOAuthState{ + Service: serviceName, + UserEmail: userEmail, + CreatedAt: time.Now(), + } + + // Clean up old states (older than 10 minutes) + c.cleanupOldStates() + + // Generate authorization URL + authURL := oauth2Config.AuthCodeURL(state, oauth2.AccessTypeOffline) + + log.LogInfoWithFields("service_oauth", "Starting OAuth flow", map[string]any{ + "service": serviceName, + "user": userEmail, + "authURL": authURL, + "redirect": oauth2Config.RedirectURL, + }) + + return authURL, nil +} + +// HandleCallback processes OAuth callback +func (c *ServiceOAuthClient) HandleCallback( + ctx context.Context, + serviceName string, + code string, + state string, + serviceConfig *config.MCPClientConfig, +) (userEmail string, err error) { + // Validate state + oauthState, exists := c.stateCache[state] + if !exists { + return "", fmt.Errorf("invalid state parameter") + } + delete(c.stateCache, state) // One-time use + + // Validate service matches + if oauthState.Service != serviceName { + return "", fmt.Errorf("service mismatch in OAuth callback") + } + + auth := serviceConfig.UserAuthentication + if auth == nil || auth.Type != config.UserAuthTypeOAuth { + return "", fmt.Errorf("service %s does not support OAuth", serviceName) + } + + // Create OAuth2 config + oauth2Config := &oauth2.Config{ + ClientID: string(auth.ClientID), + ClientSecret: string(auth.ClientSecret), + Endpoint: oauth2.Endpoint{ + AuthURL: auth.AuthorizationURL, + TokenURL: auth.TokenURL, + }, + RedirectURL: fmt.Sprintf("%s/oauth/callback/%s", c.baseURL, serviceName), + Scopes: auth.Scopes, + } + + // Exchange code for token + token, err := oauth2Config.Exchange(ctx, code) + if err != nil { + log.LogErrorWithFields("service_oauth", "Failed to exchange code for token", map[string]any{ + "service": serviceName, + "error": err.Error(), + }) + return "", fmt.Errorf("failed to exchange code: %w", err) + } + + // Store the token + storedToken := &storage.StoredToken{ + Type: storage.TokenTypeOAuth, + OAuthData: &storage.OAuthTokenData{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + ExpiresAt: token.Expiry, + TokenType: token.TokenType, + Scopes: auth.Scopes, + }, + UpdatedAt: time.Now(), + } + + if err := c.storage.SetUserToken(ctx, oauthState.UserEmail, serviceName, storedToken); err != nil { + log.LogErrorWithFields("service_oauth", "Failed to store OAuth token", map[string]any{ + "service": serviceName, + "user": oauthState.UserEmail, + "error": err.Error(), + }) + return "", fmt.Errorf("failed to store token: %w", err) + } + + log.LogInfoWithFields("service_oauth", "OAuth flow completed successfully", map[string]any{ + "service": serviceName, + "user": oauthState.UserEmail, + }) + + return oauthState.UserEmail, nil +} + +// RefreshToken refreshes an OAuth token if needed +func (c *ServiceOAuthClient) RefreshToken( + ctx context.Context, + userEmail string, + serviceName string, + serviceConfig *config.MCPClientConfig, +) error { + // Get current token + storedToken, err := c.storage.GetUserToken(ctx, userEmail, serviceName) + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + if storedToken.Type != storage.TokenTypeOAuth || storedToken.OAuthData == nil { + return fmt.Errorf("token is not an OAuth token") + } + + // Check if refresh is needed (refresh if expires within 5 minutes) + if time.Until(storedToken.OAuthData.ExpiresAt) > 5*time.Minute { + return nil // Token still valid + } + + if storedToken.OAuthData.RefreshToken == "" { + return fmt.Errorf("no refresh token available") + } + + auth := serviceConfig.UserAuthentication + if auth == nil || auth.Type != config.UserAuthTypeOAuth { + return fmt.Errorf("service configuration missing OAuth settings") + } + + // Create OAuth2 config + oauth2Config := &oauth2.Config{ + ClientID: string(auth.ClientID), + ClientSecret: string(auth.ClientSecret), + Endpoint: oauth2.Endpoint{ + AuthURL: auth.AuthorizationURL, + TokenURL: auth.TokenURL, + }, + Scopes: auth.Scopes, + } + + // Create token source for refresh + oldToken := &oauth2.Token{ + AccessToken: storedToken.OAuthData.AccessToken, + RefreshToken: storedToken.OAuthData.RefreshToken, + Expiry: storedToken.OAuthData.ExpiresAt, + TokenType: storedToken.OAuthData.TokenType, + } + + tokenSource := oauth2Config.TokenSource(ctx, oldToken) + newToken, err := tokenSource.Token() + if err != nil { + log.LogErrorWithFields("service_oauth", "Failed to refresh token", map[string]any{ + "service": serviceName, + "user": userEmail, + "error": err.Error(), + }) + return fmt.Errorf("failed to refresh token: %w", err) + } + + // Update stored token + storedToken.OAuthData.AccessToken = newToken.AccessToken + if newToken.RefreshToken != "" { + storedToken.OAuthData.RefreshToken = newToken.RefreshToken + } + storedToken.OAuthData.ExpiresAt = newToken.Expiry + storedToken.UpdatedAt = time.Now() + + if err := c.storage.SetUserToken(ctx, userEmail, serviceName, storedToken); err != nil { + return fmt.Errorf("failed to store refreshed token: %w", err) + } + + log.LogInfoWithFields("service_oauth", "Token refreshed successfully", map[string]any{ + "service": serviceName, + "user": userEmail, + "expiry": newToken.Expiry, + }) + + return nil +} + +// GetConnectURL generates the OAuth connect URL for a service +func (c *ServiceOAuthClient) GetConnectURL(serviceName string, returnPath string) string { + params := url.Values{} + params.Set("service", serviceName) + if returnPath != "" { + params.Set("return", returnPath) + } + return fmt.Sprintf("%s/oauth/connect?%s", c.baseURL, params.Encode()) +} + +// cleanupOldStates removes expired state entries +func (c *ServiceOAuthClient) cleanupOldStates() { + cutoff := time.Now().Add(-10 * time.Minute) + for state, oauthState := range c.stateCache { + if oauthState.CreatedAt.Before(cutoff) { + delete(c.stateCache, state) + } + } +} + +// ParseTokenResponse parses a token response for custom OAuth implementations +func ParseTokenResponse(body []byte) (*oauth2.Token, error) { + var resp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + } + + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + token := &oauth2.Token{ + AccessToken: resp.AccessToken, + RefreshToken: resp.RefreshToken, + TokenType: resp.TokenType, + } + + if resp.ExpiresIn > 0 { + token.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) + } + + if resp.Scope != "" { + token = token.WithExtra(map[string]any{ + "scope": strings.Split(resp.Scope, " "), + }) + } + + return token, nil +} diff --git a/internal/browserauth/session.go b/internal/browserauth/session.go new file mode 100644 index 0000000..a1876d9 --- /dev/null +++ b/internal/browserauth/session.go @@ -0,0 +1,15 @@ +package browserauth + +import "time" + +// SessionCookie represents the data stored in encrypted browser session cookies +type SessionCookie struct { + Email string `json:"email"` + Expires time.Time `json:"expires"` +} + +// AuthorizationState represents the OAuth authorization code flow state parameter +type AuthorizationState struct { + Nonce string `json:"nonce"` + ReturnURL string `json:"return_url"` +} diff --git a/internal/client/client.go b/internal/client/client.go index ba975ec..68a438c 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -81,7 +81,7 @@ func DefaultTransportCreator(conf *config.MCPClientConfig) (MCPClientInterface, envs = append(envs, fmt.Sprintf("%s=%s", k, v)) } - log.LogInfoWithFields("client", "Starting stdio MCP process", map[string]interface{}{ + log.LogInfoWithFields("client", "Starting stdio MCP process", map[string]any{ "command": conf.Command, "args": conf.Args, "env": envs, @@ -89,7 +89,7 @@ func DefaultTransportCreator(conf *config.MCPClientConfig) (MCPClientInterface, mcpClient, err := client.NewStdioMCPClient(conf.Command, envs, conf.Args...) if err != nil { - log.LogErrorWithFields("client", "Failed to start stdio MCP process", map[string]interface{}{ + log.LogErrorWithFields("client", "Failed to start stdio MCP process", map[string]any{ "command": conf.Command, "args": conf.Args, "error": err.Error(), @@ -97,7 +97,7 @@ func DefaultTransportCreator(conf *config.MCPClientConfig) (MCPClientInterface, return nil, err } - log.LogInfoWithFields("client", "Successfully started stdio MCP process", map[string]interface{}{ + log.LogInfoWithFields("client", "Successfully started stdio MCP process", map[string]any{ "command": conf.Command, }) @@ -162,7 +162,7 @@ func (c *Client) addToolsToServer( tokenStore storage.UserTokenStore, serverName string, setupBaseURL string, - tokenSetup *config.TokenSetupConfig, + userAuth *config.UserAuthentication, session server.ClientSession, ) error { toolsRequest := mcp.ListToolsRequest{} @@ -198,7 +198,7 @@ func (c *Client) addToolsToServer( } } - log.LogInfoWithFields("client", "Starting tool discovery", map[string]interface{}{ + log.LogInfoWithFields("client", "Starting tool discovery", map[string]any{ "server": c.name, }) @@ -213,7 +213,7 @@ func (c *Client) addToolsToServer( return fmt.Errorf("session does not support session-specific tools") } sessionTools = make(map[string]server.ServerTool) - log.LogInfoWithFields("client", "Using session-specific tool registration", map[string]interface{}{ + log.LogInfoWithFields("client", "Using session-specific tool registration", map[string]any{ "server": c.name, "sessionID": session.SessionID(), }) @@ -222,7 +222,7 @@ func (c *Client) addToolsToServer( for { tools, err := c.client.ListTools(ctx, toolsRequest) if err != nil { - log.LogErrorWithFields("client", "Failed to list tools", map[string]interface{}{ + log.LogErrorWithFields("client", "Failed to list tools", map[string]any{ "server": c.name, "error": err.Error(), }) @@ -236,7 +236,7 @@ func (c *Client) addToolsToServer( for _, tool := range tools.Tools { if filterFunc(tool.Name) { - log.LogDebugWithFields("client", "Adding tool", map[string]interface{}{ + log.LogDebugWithFields("client", "Adding tool", map[string]any{ "server": c.name, "tool": tool.Name, "description": tool.Description, @@ -251,7 +251,7 @@ func (c *Client) addToolsToServer( userEmail, serverName, setupBaseURL, - tokenSetup, + userAuth, ) } else { handler = c.client.CallTool @@ -275,14 +275,14 @@ func (c *Client) addToolsToServer( if len(sessionTools) > 0 { sessionWithTools.SetSessionTools(sessionTools) - log.LogInfoWithFields("client", "Registered session-specific tools", map[string]interface{}{ + log.LogInfoWithFields("client", "Registered session-specific tools", map[string]any{ "server": c.name, "sessionID": session.SessionID(), "toolCount": len(sessionTools), }) } - log.LogInfoWithFields("client", "Tool discovery completed", map[string]interface{}{ + log.LogInfoWithFields("client", "Tool discovery completed", map[string]any{ "server": c.name, "totalTools": totalTools, }) @@ -291,7 +291,7 @@ func (c *Client) addToolsToServer( } func (c *Client) addPromptsToServer(ctx context.Context, mcpServer *server.MCPServer) error { - log.LogInfoWithFields("client", "Starting prompt discovery", map[string]interface{}{ + log.LogInfoWithFields("client", "Starting prompt discovery", map[string]any{ "server": c.name, }) @@ -300,7 +300,7 @@ func (c *Client) addPromptsToServer(ctx context.Context, mcpServer *server.MCPSe for { prompts, err := c.client.ListPrompts(ctx, promptsRequest) if err != nil { - log.LogErrorWithFields("client", "Failed to list prompts", map[string]interface{}{ + log.LogErrorWithFields("client", "Failed to list prompts", map[string]any{ "server": c.name, "error": err.Error(), }) @@ -321,7 +321,7 @@ func (c *Client) addPromptsToServer(ctx context.Context, mcpServer *server.MCPSe promptsRequest.Params.Cursor = prompts.NextCursor } - log.LogInfoWithFields("client", "Prompt discovery completed", map[string]interface{}{ + log.LogInfoWithFields("client", "Prompt discovery completed", map[string]any{ "server": c.name, "totalPrompts": totalPrompts, }) @@ -330,7 +330,7 @@ func (c *Client) addPromptsToServer(ctx context.Context, mcpServer *server.MCPSe } func (c *Client) addResourcesToServer(ctx context.Context, mcpServer *server.MCPServer) error { - log.LogInfoWithFields("client", "Starting resource discovery", map[string]interface{}{ + log.LogInfoWithFields("client", "Starting resource discovery", map[string]any{ "server": c.name, }) @@ -339,7 +339,7 @@ func (c *Client) addResourcesToServer(ctx context.Context, mcpServer *server.MCP for { resources, err := c.client.ListResources(ctx, resourcesRequest) if err != nil { - log.LogErrorWithFields("client", "Failed to list resources", map[string]interface{}{ + log.LogErrorWithFields("client", "Failed to list resources", map[string]any{ "server": c.name, "error": err.Error(), }) @@ -367,7 +367,7 @@ func (c *Client) addResourcesToServer(ctx context.Context, mcpServer *server.MCP } - log.LogInfoWithFields("client", "Resource discovery completed", map[string]interface{}{ + log.LogInfoWithFields("client", "Resource discovery completed", map[string]any{ "server": c.name, "totalResources": totalResources, }) @@ -412,11 +412,11 @@ func (c *Client) wrapToolHandler( userEmail string, serverName string, setupBaseURL string, - tokenSetup *config.TokenSetupConfig, + userAuth *config.UserAuthentication, ) func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(toolCtx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Log tool invocation - log.LogInfoWithFields("client", "Tool invocation requested", map[string]interface{}{ + log.LogInfoWithFields("client", "Tool invocation requested", map[string]any{ "server": serverName, "tool": request.Params.Name, "user": userEmail, @@ -427,7 +427,7 @@ func (c *Client) wrapToolHandler( if userEmail == "" { // This shouldn't happen with proper config validation // (requiresUserToken requires OAuth to be configured) - log.LogErrorWithFields("client", "User token required but no user email provided", map[string]interface{}{ + log.LogErrorWithFields("client", "User token required but no user email provided", map[string]any{ "service": serverName, "tool": request.Params.Name, }) @@ -435,7 +435,6 @@ func (c *Client) wrapToolHandler( errorData := createTokenRequiredError( serverName, setupBaseURL, - tokenSetup, "configuration error: this service requires user tokens but OAuth is not properly configured.", ) @@ -449,14 +448,14 @@ func (c *Client) wrapToolHandler( tokenSetupURL := fmt.Sprintf("%s/my/tokens", setupBaseURL) var errorMessage string - if tokenSetup != nil { + if userAuth != nil { errorMessage = fmt.Sprintf( "token required: %s requires a user token to access the API. "+ "please visit %s to set up your %s token. %s", - tokenSetup.DisplayName, + userAuth.DisplayName, tokenSetupURL, - tokenSetup.DisplayName, - tokenSetup.Instructions, + userAuth.DisplayName, + userAuth.Instructions, ) } else { errorMessage = fmt.Sprintf( @@ -469,7 +468,6 @@ func (c *Client) wrapToolHandler( errorData := createTokenRequiredError( serverName, setupBaseURL, - tokenSetup, errorMessage, ) @@ -482,14 +480,14 @@ func (c *Client) wrapToolHandler( result, err := originalHandler(toolCtx, request) if err != nil { - log.LogErrorWithFields("client", "Tool invocation failed", map[string]interface{}{ + log.LogErrorWithFields("client", "Tool invocation failed", map[string]any{ "server": serverName, "tool": request.Params.Name, "user": userEmail, "error": err.Error(), }) } else { - log.LogInfoWithFields("client", "Tool invocation completed", map[string]interface{}{ + log.LogInfoWithFields("client", "Tool invocation completed", map[string]any{ "server": serverName, "tool": request.Params.Name, "user": userEmail, @@ -509,18 +507,18 @@ func (c *Client) Close() error { } // createTokenRequiredError creates the structured error for missing user tokens -func createTokenRequiredError(serverName, setupBaseURL string, tokenSetup *config.TokenSetupConfig, message string) map[string]interface{} { +func createTokenRequiredError(serverName, setupBaseURL string, message string) map[string]any { tokenSetupURL := fmt.Sprintf("%s/my/tokens", setupBaseURL) - return map[string]interface{}{ + return map[string]any{ "jsonrpc": "2.0", - "error": map[string]interface{}{ + "error": map[string]any{ "code": "token_required", "message": message, - "data": map[string]interface{}{ + "data": map[string]any{ "service": serverName, "tokenSetupUrl": tokenSetupURL, - "instructions": map[string]interface{}{ + "instructions": map[string]any{ "ai": "CRITICAL: You MUST pass these error details EXACTLY as provided to the human. " + "You are FORBIDDEN from modifying URLs or instructions. " + "Tell the human they need to visit the tokenSetupUrl to configure their token.", diff --git a/internal/client/session_manager.go b/internal/client/session_manager.go index cd09486..1c129f1 100644 --- a/internal/client/session_manager.go +++ b/internal/client/session_manager.go @@ -115,6 +115,7 @@ func (sm *StdioSessionManager) GetOrCreateSession( config *config.MCPClientConfig, info mcp.Implementation, baseURL string, + userToken string, ) (*StdioSession, error) { // Try to get existing session first if session, ok := sm.GetSession(key); ok { @@ -125,7 +126,7 @@ func (sm *StdioSessionManager) GetOrCreateSession( return nil, err } - return sm.createSession(ctx, key, config, info, baseURL) + return sm.createSession(key, config, userToken) } // GetSession retrieves an existing session @@ -139,7 +140,7 @@ func (sm *StdioSessionManager) GetSession(key SessionKey) (*StdioSession, bool) lastAccessed := session.lastAccessed.Load() session.lastAccessed.Store(&now) - log.LogTraceWithFields("session_manager", "Session accessed", map[string]interface{}{ + log.LogTraceWithFields("session_manager", "Session accessed", map[string]any{ "sessionID": key.SessionID, "server": key.ServerName, "user": key.UserEmail, @@ -150,7 +151,7 @@ func (sm *StdioSessionManager) GetSession(key SessionKey) (*StdioSession, bool) select { case <-session.ctx.Done(): // Process died, remove it - log.LogTraceWithFields("session_manager", "Session context cancelled, removing", map[string]interface{}{ + log.LogTraceWithFields("session_manager", "Session context cancelled, removing", map[string]any{ "sessionID": key.SessionID, "server": key.ServerName, "user": key.UserEmail, @@ -162,7 +163,7 @@ func (sm *StdioSessionManager) GetSession(key SessionKey) (*StdioSession, bool) } } - log.LogTraceWithFields("session_manager", "Session not found", map[string]interface{}{ + log.LogTraceWithFields("session_manager", "Session not found", map[string]any{ "sessionID": key.SessionID, "server": key.ServerName, "user": key.UserEmail, @@ -177,7 +178,7 @@ func (sm *StdioSessionManager) RemoveSession(key SessionKey) { session, ok := sm.sessions[key] if ok { delete(sm.sessions, key) - log.LogTraceWithFields("session_manager", "Removing session from map", map[string]interface{}{ + log.LogTraceWithFields("session_manager", "Removing session from map", map[string]any{ "sessionID": key.SessionID, "server": key.ServerName, "user": key.UserEmail, @@ -192,7 +193,7 @@ func (sm *StdioSessionManager) RemoveSession(key SessionKey) { // Close the client if err := session.client.Close(); err != nil { - log.LogErrorWithFields("session_manager", "Failed to close client", map[string]interface{}{ + log.LogErrorWithFields("session_manager", "Failed to close client", map[string]any{ "error": err.Error(), "sessionID": key.SessionID, "server": key.ServerName, @@ -200,13 +201,13 @@ func (sm *StdioSessionManager) RemoveSession(key SessionKey) { }) } - log.LogInfoWithFields("session_manager", "Removed session", map[string]interface{}{ + log.LogInfoWithFields("session_manager", "Removed session", map[string]any{ "sessionID": key.SessionID, "server": key.ServerName, "user": key.UserEmail, }) - log.LogTraceWithFields("session_manager", "Session removed with details", map[string]interface{}{ + log.LogTraceWithFields("session_manager", "Session removed with details", map[string]any{ "sessionID": key.SessionID, "server": key.ServerName, "user": key.UserEmail, @@ -251,7 +252,7 @@ func (s *StdioSession) DiscoverAndRegisterCapabilities( tokenStore storage.UserTokenStore, serverName string, setupBaseURL string, - tokenSetup *config.TokenSetupConfig, + userAuth *config.UserAuthentication, session server.ClientSession, ) error { // Initialize the client @@ -268,7 +269,7 @@ func (s *StdioSession) DiscoverAndRegisterCapabilities( Version: "1.0", } initRequest.Params.Capabilities = mcp.ClientCapabilities{ - Experimental: make(map[string]interface{}), + Experimental: make(map[string]any), Roots: nil, Sampling: nil, } @@ -280,20 +281,20 @@ func (s *StdioSession) DiscoverAndRegisterCapabilities( log.Logf("<%s> Successfully initialized MCP client", serverName) // Start capability discovery - log.LogInfoWithFields("client", "Starting MCP capability discovery", map[string]interface{}{ + log.LogInfoWithFields("client", "Starting MCP capability discovery", map[string]any{ "server": serverName, }) - log.LogTraceWithFields("client", "Starting capability discovery", map[string]interface{}{ + log.LogTraceWithFields("client", "Starting capability discovery", map[string]any{ "server": serverName, "sessionID": session.SessionID(), "userEmail": userEmail, "requiresUserToken": requiresToken, - "hasTokenSetup": tokenSetup != nil, + "hasTokenSetup": userAuth != nil, }) // Discover and register tools - if err := s.client.addToolsToServer(ctx, mcpServer, userEmail, requiresToken, tokenStore, serverName, setupBaseURL, tokenSetup, session); err != nil { + if err := s.client.addToolsToServer(ctx, mcpServer, userEmail, requiresToken, tokenStore, serverName, setupBaseURL, userAuth, session); err != nil { return err } @@ -306,12 +307,12 @@ func (s *StdioSession) DiscoverAndRegisterCapabilities( // Discover and register resource templates _ = s.client.addResourceTemplatesToServer(ctx, mcpServer) - log.LogInfoWithFields("client", "MCP capability discovery completed", map[string]interface{}{ + log.LogInfoWithFields("client", "MCP capability discovery completed", map[string]any{ "server": serverName, "userTokenRequired": requiresToken, }) - log.LogTraceWithFields("client", "Capability discovery completed", map[string]interface{}{ + log.LogTraceWithFields("client", "Capability discovery completed", map[string]any{ "server": serverName, "sessionID": session.SessionID(), "userEmail": userEmail, @@ -343,7 +344,7 @@ func (sm *StdioSessionManager) checkUserLimits(userEmail string) error { count := sm.getUserSessionCount(userEmail) if count >= sm.maxPerUser { - log.LogWarnWithFields("session_manager", "User session limit exceeded", map[string]interface{}{ + log.LogWarnWithFields("session_manager", "User session limit exceeded", map[string]any{ "user": userEmail, "count": count, "limit": sm.maxPerUser, @@ -352,7 +353,7 @@ func (sm *StdioSessionManager) checkUserLimits(userEmail string) error { ErrUserLimitExceeded, userEmail, count, sm.maxPerUser) } - log.LogTraceWithFields("session_manager", "User session limit check passed", map[string]interface{}{ + log.LogTraceWithFields("session_manager", "User session limit check passed", map[string]any{ "user": userEmail, "currentSessions": count, "maxPerUser": sm.maxPerUser, @@ -379,14 +380,21 @@ func (sm *StdioSessionManager) getUserSessionCount(userEmail string) int { // createSession creates a new stdio session func (sm *StdioSessionManager) createSession( - ctx context.Context, key SessionKey, config *config.MCPClientConfig, - info mcp.Implementation, - baseURL string, + userToken string, ) (*StdioSession, error) { + // Create an independent context for the stdio session. We intentionally use + // context.Background() instead of the HTTP request context because stdio + // sessions are long-lived processes that must persist across multiple HTTP + // requests. The session will be cleaned up by the timeout-based cleanup + // routine, not by HTTP request cancellation. sessionCtx, cancel := context.WithCancel(context.Background()) + if userToken != "" && config.RequiresUserToken { + config = config.ApplyUserToken(userToken) + } + client, err := sm.createClient(key.ServerName, config) if err != nil { cancel() @@ -409,13 +417,13 @@ func (sm *StdioSessionManager) createSession( sm.sessions[key] = session sm.mu.Unlock() - log.LogInfoWithFields("session_manager", "Created new session", map[string]interface{}{ + log.LogInfoWithFields("session_manager", "Created new session", map[string]any{ "sessionID": key.SessionID, "server": key.ServerName, "user": key.UserEmail, }) - log.LogTraceWithFields("session_manager", "Session created with details", map[string]interface{}{ + log.LogTraceWithFields("session_manager", "Session created with details", map[string]any{ "sessionID": key.SessionID, "server": key.ServerName, "user": key.UserEmail, @@ -465,7 +473,7 @@ func (sm *StdioSessionManager) cleanupTimedOutSessions() { sm.mu.RUnlock() if totalSessions > 0 || len(timedOut) > 0 { - log.LogTraceWithFields("session_manager", "Session cleanup cycle", map[string]interface{}{ + log.LogTraceWithFields("session_manager", "Session cleanup cycle", map[string]any{ "totalSessions": totalSessions, "activeSessions": activeSessions, "timedOutSessions": len(timedOut), @@ -474,7 +482,7 @@ func (sm *StdioSessionManager) cleanupTimedOutSessions() { } for _, key := range timedOut { - log.LogInfoWithFields("session_manager", "Removing timed out session", map[string]interface{}{ + log.LogInfoWithFields("session_manager", "Removing timed out session", map[string]any{ "sessionID": key.SessionID, "server": key.ServerName, "user": key.UserEmail, diff --git a/internal/client/session_manager_test.go b/internal/client/session_manager_test.go index 19154c0..74ba543 100644 --- a/internal/client/session_manager_test.go +++ b/internal/client/session_manager_test.go @@ -44,12 +44,12 @@ func TestStdioSessionManager_CreateAndRetrieve(t *testing.T) { } // Create session - session1, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost") + session1, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "") require.NoError(t, err) require.NotNil(t, session1) // Retrieve same session - session2, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost") + session2, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "") require.NoError(t, err) require.NotNil(t, session2) @@ -81,22 +81,22 @@ func TestStdioSessionManager_UserLimits(t *testing.T) { // Create first session key1 := SessionKey{UserEmail: userEmail, ServerName: "server", SessionID: "1"} - _, err := sm.GetOrCreateSession(context.Background(), key1, config, info, "http://localhost") + _, err := sm.GetOrCreateSession(context.Background(), key1, config, info, "http://localhost", "") require.NoError(t, err) // Create second session (at limit) key2 := SessionKey{UserEmail: userEmail, ServerName: "server", SessionID: "2"} - _, err = sm.GetOrCreateSession(context.Background(), key2, config, info, "http://localhost") + _, err = sm.GetOrCreateSession(context.Background(), key2, config, info, "http://localhost", "") require.NoError(t, err) // Try to create third session (should fail) key3 := SessionKey{UserEmail: userEmail, ServerName: "server", SessionID: "3"} - _, err = sm.GetOrCreateSession(context.Background(), key3, config, info, "http://localhost") + _, err = sm.GetOrCreateSession(context.Background(), key3, config, info, "http://localhost", "") assert.ErrorIs(t, err, ErrUserLimitExceeded) // Different user should work key4 := SessionKey{UserEmail: "other@example.com", ServerName: "server", SessionID: "4"} - _, err = sm.GetOrCreateSession(context.Background(), key4, config, info, "http://localhost") + _, err = sm.GetOrCreateSession(context.Background(), key4, config, info, "http://localhost", "") require.NoError(t, err) } @@ -122,7 +122,7 @@ func TestStdioSessionManager_RemoveSession(t *testing.T) { info := mcp.Implementation{Name: "test", Version: "1.0"} // Create session - session, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost") + session, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "") require.NoError(t, err) require.NotNil(t, session) @@ -161,7 +161,7 @@ func TestStdioSessionManager_Timeout(t *testing.T) { info := mcp.Implementation{Name: "test", Version: "1.0"} // Create session - session, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost") + session, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "") require.NoError(t, err) require.NotNil(t, session) @@ -192,7 +192,7 @@ func TestStdioSessionManager_ConcurrentAccess(t *testing.T) { // Run concurrent operations var wg sync.WaitGroup - for i := 0; i < 10; i++ { + for i := range 10 { wg.Add(1) go func(i int) { defer wg.Done() @@ -204,7 +204,7 @@ func TestStdioSessionManager_ConcurrentAccess(t *testing.T) { } // Create session - _, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost") + _, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "") assert.NoError(t, err) // Get session @@ -236,14 +236,14 @@ func TestStdioSessionManager_NoLimitsForAnonymous(t *testing.T) { info := mcp.Implementation{Name: "test", Version: "1.0"} // Create multiple anonymous sessions (empty userEmail) - for i := 0; i < 5; i++ { + for i := range 5 { key := SessionKey{ UserEmail: "", // Anonymous ServerName: "server", SessionID: fmt.Sprintf("session-%d", i), } - _, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost") + _, err := sm.GetOrCreateSession(context.Background(), key, config, info, "http://localhost", "") require.NoError(t, err, "Anonymous session %d should succeed", i) } } diff --git a/internal/config/load.go b/internal/config/load.go index 0c6bfa4..ba87dc0 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -10,47 +10,47 @@ import ( ) // Load loads and processes the config with immediate env var resolution -func Load(path string) (*Config, error) { +func Load(path string) (Config, error) { data, err := os.ReadFile(path) if err != nil { - return nil, fmt.Errorf("reading config file: %w", err) + return Config{}, fmt.Errorf("reading config file: %w", err) } - var rawConfig map[string]interface{} + var rawConfig map[string]any if err := json.Unmarshal(data, &rawConfig); err != nil { - return nil, fmt.Errorf("parsing config JSON: %w", err) + return Config{}, fmt.Errorf("parsing config JSON: %w", err) } version, ok := rawConfig["version"].(string) if !ok { - return nil, fmt.Errorf("config version is required") + return Config{}, fmt.Errorf("config version is required") } if !strings.HasPrefix(version, "v0.0.1-DEV_EDITION") { - return nil, fmt.Errorf("unsupported config version: %s", version) + return Config{}, fmt.Errorf("unsupported config version: %s", version) } if err := validateRawConfig(rawConfig); err != nil { - return nil, fmt.Errorf("config validation failed: %w", err) + return Config{}, fmt.Errorf("config validation failed: %w", err) } // Parse directly into typed Config struct // The custom UnmarshalJSON methods will resolve env vars immediately var config Config if err := json.Unmarshal(data, &config); err != nil { - return nil, fmt.Errorf("parsing config: %w", err) + return Config{}, fmt.Errorf("parsing config: %w", err) } if err := ValidateConfig(&config); err != nil { - return nil, fmt.Errorf("config validation failed: %w", err) + return Config{}, fmt.Errorf("config validation failed: %w", err) } - return &config, nil + return config, nil } // validateRawConfig validates the config structure before environment resolution -func validateRawConfig(rawConfig map[string]interface{}) error { - if proxy, ok := rawConfig["proxy"].(map[string]interface{}); ok { - if auth, ok := proxy["auth"].(map[string]interface{}); ok { +func validateRawConfig(rawConfig map[string]any) error { + if proxy, ok := rawConfig["proxy"].(map[string]any); ok { + if auth, ok := proxy["auth"].(map[string]any); ok { if kind, ok := auth["kind"].(string); ok && kind == "oauth" { secrets := []struct { name string @@ -68,7 +68,7 @@ func validateRawConfig(rawConfig map[string]interface{}) error { return fmt.Errorf("%s must use environment variable reference for security", secret.name) } // Verify it's an env ref - if refMap, isMap := value.(map[string]interface{}); isMap { + if refMap, isMap := value.(map[string]any); isMap { if _, hasEnv := refMap["$env"]; !hasEnv { return fmt.Errorf("%s must use {\"$env\": \"VAR_NAME\"} format", secret.name) } @@ -97,16 +97,13 @@ func ValidateConfig(config *Config) error { return fmt.Errorf("proxy.addr is required") } - if oauth, ok := config.Proxy.Auth.(*OAuthAuthConfig); ok { + if oauth := config.Proxy.Auth; oauth != nil { if err := validateOAuthConfig(oauth); err != nil { return fmt.Errorf("oauth config: %w", err) } } - hasOAuth := false - if _, ok := config.Proxy.Auth.(*OAuthAuthConfig); ok { - hasOAuth = true - } + hasOAuth := config.Proxy.Auth != nil for name, server := range config.MCPServers { if err := validateMCPServer(name, server); err != nil { @@ -204,9 +201,9 @@ func validateMCPServer(name string, server *MCPClientConfig) error { return fmt.Errorf("server %s has invalid transportType: %s", name, server.TransportType) } - // Validate token setup if required - if server.RequiresUserToken && server.TokenSetup == nil { - return fmt.Errorf("server %s requires user token but has no tokenSetup", name) + // Validate user authentication if required + if server.RequiresUserToken && server.UserAuthentication == nil { + return fmt.Errorf("server %s requires user token but has no userAuthentication", name) } return nil diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 1a6a03c..7e13dd4 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -25,7 +25,8 @@ func TestValidateConfig_UserTokensRequireOAuth(t *testing.T) { TransportType: MCPClientTypeSSE, URL: "https://notion.example.com", RequiresUserToken: true, - TokenSetup: &TokenSetupConfig{ + UserAuthentication: &UserAuthentication{ + Type: UserAuthTypeManual, DisplayName: "Notion", }, }, @@ -56,7 +57,8 @@ func TestValidateConfig_UserTokensRequireOAuth(t *testing.T) { TransportType: MCPClientTypeSSE, URL: "https://notion.example.com", RequiresUserToken: true, - TokenSetup: &TokenSetupConfig{ + UserAuthentication: &UserAuthentication{ + Type: UserAuthTypeManual, DisplayName: "Notion", }, }, diff --git a/internal/config/secret_test.go b/internal/config/secret_test.go new file mode 100644 index 0000000..b1a8613 --- /dev/null +++ b/internal/config/secret_test.go @@ -0,0 +1,112 @@ +package config + +import ( + "encoding/json" + "fmt" + "strings" + "testing" +) + +func TestSecretRedaction(t *testing.T) { + tests := []struct { + name string + secret Secret + want string + }{ + { + name: "non-empty secret", + secret: Secret("super-secret-password"), + want: "***", + }, + { + name: "empty secret", + secret: Secret(""), + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test String() method + if got := tt.secret.String(); got != tt.want { + t.Errorf("Secret.String() = %v, want %v", got, tt.want) + } + + // Test fmt.Sprintf behavior + formatted := fmt.Sprintf("value: %s", tt.secret) + expectedFormatted := "value: " + tt.want + if formatted != expectedFormatted { + t.Errorf("fmt.Sprintf = %v, want %v", formatted, expectedFormatted) + } + + // Test fmt.Printf (capture output) + output := fmt.Sprintf("password: %v", tt.secret) + if tt.secret != "" && strings.Contains(output, string(tt.secret)) { + t.Errorf("fmt.Printf leaked secret: %v", output) + } + }) + } +} + +func TestSecretJSONMarshal(t *testing.T) { + type configWithSecrets struct { + Username string `json:"username"` + Password Secret `json:"password"` + APIKey Secret `json:"apiKey"` + } + + cfg := configWithSecrets{ + Username: "admin", + Password: Secret("super-secret-password"), + APIKey: Secret("sk-1234567890abcdef"), + } + + data, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + jsonStr := string(data) + + // Check that secrets are redacted + if strings.Contains(jsonStr, "super-secret-password") { + t.Errorf("JSON contains unredacted password: %s", jsonStr) + } + if strings.Contains(jsonStr, "sk-1234567890abcdef") { + t.Errorf("JSON contains unredacted API key: %s", jsonStr) + } + + // Check that username is not redacted + if !strings.Contains(jsonStr, "admin") { + t.Errorf("JSON doesn't contain username: %s", jsonStr) + } + + // Check expected JSON structure + expected := `{"username":"admin","password":"***","apiKey":"***"}` + if jsonStr != expected { + t.Errorf("JSON = %s, want %s", jsonStr, expected) + } +} + +func TestSecretInStruct(t *testing.T) { + auth := ServiceAuth{ + Type: ServiceAuthTypeBasic, + Username: "testuser", + HashedPassword: Secret("$2a$10$abcdef..."), + UserToken: Secret("token-12345"), + } + + // Test struct string representation + str := fmt.Sprintf("%+v", auth) + if strings.Contains(str, "$2a$10$abcdef") { + t.Errorf("Struct representation leaked hashed password: %s", str) + } + if strings.Contains(str, "token-12345") { + t.Errorf("Struct representation leaked token: %s", str) + } + + // Individual field access should still redact + if auth.HashedPassword.String() != "***" { + t.Errorf("HashedPassword.String() = %v, want ***", auth.HashedPassword.String()) + } +} diff --git a/internal/config/types.go b/internal/config/types.go index 2ecd1af..c7216ff 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -8,6 +8,25 @@ import ( "time" ) +// Secret is a string type that redacts itself when printed +type Secret string + +// String implements fmt.Stringer to redact the secret +func (s Secret) String() string { + if s == "" { + return "" + } + return "***" +} + +// MarshalJSON implements json.Marshaler to prevent secrets in JSON logs +func (s Secret) MarshalJSON() ([]byte, error) { + if s == "" { + return json.Marshal("") + } + return json.Marshal("***") +} + // MCPClientType represents the transport type for MCP clients type MCPClientType string @@ -46,15 +65,6 @@ type Options struct { ToolFilter *ToolFilterConfig `json:"toolFilter,omitempty"` } -// TokenSetupConfig provides information for users to set up their tokens -type TokenSetupConfig struct { - DisplayName string `json:"displayName"` - Instructions string `json:"instructions"` - HelpURL string `json:"helpUrl,omitempty"` - TokenFormat string `json:"tokenFormat,omitempty"` - CompiledRegex *regexp.Regexp `json:"-"` -} - // ServiceAuthType represents the type of service authentication type ServiceAuthType string @@ -63,23 +73,65 @@ const ( ServiceAuthTypeBasic ServiceAuthType = "basic" ) +// UserAuthType represents the type of user authentication +type UserAuthType string + +const ( + // UserAuthTypeManual indicates that users manually provide API tokens/keys + // through the web UI. These tokens are stored encrypted and injected into + // MCP servers as configured. + UserAuthTypeManual UserAuthType = "manual" + + // UserAuthTypeOAuth indicates OAuth 2.0 authorization code flow is used. + // Users click "Connect with X" and are redirected to the service's OAuth + // consent page. The resulting access tokens are stored, automatically + // refreshed, and injected into MCP servers. + UserAuthTypeOAuth UserAuthType = "oauth" +) + // ServiceAuth represents authentication method for service-to-service communication type ServiceAuth struct { Type ServiceAuthType `json:"type"` // For basic auth - Username string `json:"username,omitempty"` - Password json.RawMessage `json:"password,omitempty"` + Username string `json:"username,omitempty"` + PasswordRaw json.RawMessage `json:"password,omitempty"` // For bearer auth Tokens []string `json:"tokens,omitempty"` // User token to inject when requiresUserToken is true - UserToken json.RawMessage `json:"userToken,omitempty"` + UserTokenRaw json.RawMessage `json:"userToken,omitempty"` + + // Computed fields + HashedPassword Secret `json:"-"` // bcrypt hash for basic auth + UserToken Secret `json:"-"` // parsed user token +} + +// UserAuthentication represents authentication configuration for end users +type UserAuthentication struct { + Type UserAuthType `json:"type"` + DisplayName string `json:"displayName"` + + // For OAuth + ClientIDRaw json.RawMessage `json:"clientId,omitempty"` + ClientSecretRaw json.RawMessage `json:"clientSecret,omitempty"` + AuthorizationURL string `json:"authorizationUrl,omitempty"` + TokenURL string `json:"tokenUrl,omitempty"` + Scopes []string `json:"scopes,omitempty"` + + // For Manual + Instructions string `json:"instructions,omitempty"` + HelpURL string `json:"helpUrl,omitempty"` + Validation string `json:"validation,omitempty"` + + // Common + TokenFormat string `json:"tokenFormat,omitempty"` // Computed fields - HashedPassword string `json:"-"` // bcrypt hash for basic auth - ResolvedUserToken string `json:"-"` // resolved user token + ClientID Secret `json:"-"` + ClientSecret Secret `json:"-"` + ValidationRegex *regexp.Regexp `json:"-"` } // MCPClientConfig represents the configuration for an MCP client after parsing. @@ -125,8 +177,8 @@ type MCPClientConfig struct { Options *Options `json:"options,omitempty"` // User token requirements - RequiresUserToken bool `json:"requiresUserToken,omitempty"` - TokenSetup *TokenSetupConfig `json:"tokenSetup,omitempty"` + RequiresUserToken bool `json:"requiresUserToken,omitempty"` + UserAuthentication *UserAuthentication `json:"userAuthentication,omitempty"` // Service-to-service authentication ServiceAuths []ServiceAuth `json:"serviceAuths,omitempty"` @@ -150,30 +202,30 @@ type AdminConfig struct { // OAuthAuthConfig represents OAuth 2.1 configuration with resolved values type OAuthAuthConfig struct { - Kind AuthKind `json:"kind"` - Issuer string `json:"issuer"` - GCPProject string `json:"gcpProject"` - AllowedDomains []string `json:"allowedDomains"` // For Google OAuth email validation - AllowedOrigins []string `json:"allowedOrigins"` // For CORS validation - TokenTTL string `json:"tokenTtl"` - Storage string `json:"storage"` // "memory" or "firestore" - FirestoreDatabase string `json:"firestoreDatabase,omitempty"` // Optional: Firestore database name - FirestoreCollection string `json:"firestoreCollection,omitempty"` // Optional: Firestore collection name - GoogleClientID string `json:"googleClientId"` - GoogleClientSecret string `json:"googleClientSecret"` - GoogleRedirectURI string `json:"googleRedirectUri"` - JWTSecret string `json:"jwtSecret"` - EncryptionKey string `json:"encryptionKey"` + Kind AuthKind `json:"kind"` + Issuer string `json:"issuer"` + GCPProject string `json:"gcpProject"` + AllowedDomains []string `json:"allowedDomains"` // For Google OAuth email validation + AllowedOrigins []string `json:"allowedOrigins"` // For CORS validation + TokenTTL time.Duration `json:"tokenTtl"` + Storage string `json:"storage"` // "memory" or "firestore" + FirestoreDatabase string `json:"firestoreDatabase,omitempty"` // Optional: Firestore database name + FirestoreCollection string `json:"firestoreCollection,omitempty"` // Optional: Firestore collection name + GoogleClientID string `json:"googleClientId"` + GoogleClientSecret Secret `json:"googleClientSecret"` + GoogleRedirectURI string `json:"googleRedirectUri"` + JWTSecret Secret `json:"jwtSecret"` + EncryptionKey Secret `json:"encryptionKey"` } // ProxyConfig represents the proxy configuration with resolved values type ProxyConfig struct { - BaseURL string `json:"baseURL"` - Addr string `json:"addr"` - Name string `json:"name"` - Auth interface{} `json:"-"` // OAuthAuthConfig or BearerTokenAuthConfig - Admin *AdminConfig `json:"admin,omitempty"` - Sessions *SessionConfig `json:"sessions,omitempty"` + BaseURL string `json:"baseURL"` + Addr string `json:"addr"` + Name string `json:"name"` + Auth *OAuthAuthConfig `json:"auth,omitempty"` // Only OAuth is supported + Admin *AdminConfig `json:"admin,omitempty"` + Sessions *SessionConfig `json:"sessions,omitempty"` } // Config represents the config structure with resolved values diff --git a/internal/config/unmarshal.go b/internal/config/unmarshal.go index ed2192b..438b3bf 100644 --- a/internal/config/unmarshal.go +++ b/internal/config/unmarshal.go @@ -7,7 +7,7 @@ import ( "strings" "time" - "github.com/dgellow/mcp-front/internal/utils" + emailutil "github.com/dgellow/mcp-front/internal/email" "golang.org/x/crypto/bcrypt" "github.com/dgellow/mcp-front/internal/log" @@ -17,18 +17,18 @@ import ( func (c *MCPClientConfig) UnmarshalJSON(data []byte) error { // Use a raw type to avoid recursion type rawConfig struct { - TransportType MCPClientType `json:"transportType,omitempty"` - Command json.RawMessage `json:"command,omitempty"` - Args []json.RawMessage `json:"args,omitempty"` - Env map[string]json.RawMessage `json:"env,omitempty"` - URL json.RawMessage `json:"url,omitempty"` - Headers map[string]json.RawMessage `json:"headers,omitempty"` - Timeout string `json:"timeout,omitempty"` - Options *Options `json:"options,omitempty"` - RequiresUserToken bool `json:"requiresUserToken,omitempty"` - TokenSetup *TokenSetupConfig `json:"tokenSetup,omitempty"` - ServiceAuths []ServiceAuth `json:"serviceAuths,omitempty"` - InlineConfig json.RawMessage `json:"inline,omitempty"` + TransportType MCPClientType `json:"transportType,omitempty"` + Command json.RawMessage `json:"command,omitempty"` + Args []json.RawMessage `json:"args,omitempty"` + Env map[string]json.RawMessage `json:"env,omitempty"` + URL json.RawMessage `json:"url,omitempty"` + Headers map[string]json.RawMessage `json:"headers,omitempty"` + Timeout string `json:"timeout,omitempty"` + Options *Options `json:"options,omitempty"` + RequiresUserToken bool `json:"requiresUserToken,omitempty"` + UserAuthentication *UserAuthentication `json:"userAuthentication,omitempty"` + ServiceAuths []ServiceAuth `json:"serviceAuths,omitempty"` + InlineConfig json.RawMessage `json:"inline,omitempty"` } var raw rawConfig @@ -39,7 +39,7 @@ func (c *MCPClientConfig) UnmarshalJSON(data []byte) error { c.TransportType = raw.TransportType c.Options = raw.Options c.RequiresUserToken = raw.RequiresUserToken - c.TokenSetup = raw.TokenSetup + c.UserAuthentication = raw.UserAuthentication c.ServiceAuths = raw.ServiceAuths c.InlineConfig = raw.InlineConfig @@ -108,15 +108,6 @@ func (c *MCPClientConfig) UnmarshalJSON(data []byte) error { c.HeadersNeedToken = needsToken } - // Compile token format regex if present - if c.TokenSetup != nil && c.TokenSetup.TokenFormat != "" { - regex, err := regexp.Compile(c.TokenSetup.TokenFormat) - if err != nil { - return fmt.Errorf("compiling token format regex: %w", err) - } - c.TokenSetup.CompiledRegex = regex - } - return nil } @@ -149,39 +140,96 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error { o.Kind = raw.Kind o.AllowedDomains = raw.AllowedDomains o.AllowedOrigins = raw.AllowedOrigins - o.TokenTTL = raw.TokenTTL o.Storage = raw.Storage o.FirestoreDatabase = raw.FirestoreDatabase o.FirestoreCollection = raw.FirestoreCollection - // Parse fields that can be references - fields := []struct { - name string - raw json.RawMessage - target *string - allowUserToken bool - }{ - {"issuer", raw.Issuer, &o.Issuer, false}, - {"gcpProject", raw.GCPProject, &o.GCPProject, false}, - {"googleClientId", raw.GoogleClientID, &o.GoogleClientID, false}, - {"googleClientSecret", raw.GoogleClientSecret, &o.GoogleClientSecret, false}, - {"googleRedirectUri", raw.GoogleRedirectURI, &o.GoogleRedirectURI, false}, - {"jwtSecret", raw.JWTSecret, &o.JWTSecret, false}, - {"encryptionKey", raw.EncryptionKey, &o.EncryptionKey, false}, - } - - for _, field := range fields { - if field.raw == nil { - continue - } - parsed, err := ParseConfigValue(field.raw) + // Parse TokenTTL duration + if raw.TokenTTL != "" { + tokenTTL, err := time.ParseDuration(raw.TokenTTL) + if err != nil { + return fmt.Errorf("parsing tokenTtl: %w", err) + } + o.TokenTTL = tokenTTL + } + + // Parse string fields + if raw.Issuer != nil { + parsed, err := ParseConfigValue(raw.Issuer) if err != nil { - return fmt.Errorf("parsing %s: %w", field.name, err) + return fmt.Errorf("parsing issuer: %w", err) } - if parsed.needsUserToken && !field.allowUserToken { - return fmt.Errorf("%s cannot be a user token reference", field.name) + if parsed.needsUserToken { + return fmt.Errorf("issuer cannot be a user token reference") } - *field.target = parsed.value + o.Issuer = parsed.value + } + + if raw.GCPProject != nil { + parsed, err := ParseConfigValue(raw.GCPProject) + if err != nil { + return fmt.Errorf("parsing gcpProject: %w", err) + } + if parsed.needsUserToken { + return fmt.Errorf("gcpProject cannot be a user token reference") + } + o.GCPProject = parsed.value + } + + if raw.GoogleClientID != nil { + parsed, err := ParseConfigValue(raw.GoogleClientID) + if err != nil { + return fmt.Errorf("parsing googleClientId: %w", err) + } + if parsed.needsUserToken { + return fmt.Errorf("googleClientId cannot be a user token reference") + } + o.GoogleClientID = parsed.value + } + + if raw.GoogleRedirectURI != nil { + parsed, err := ParseConfigValue(raw.GoogleRedirectURI) + if err != nil { + return fmt.Errorf("parsing googleRedirectUri: %w", err) + } + if parsed.needsUserToken { + return fmt.Errorf("googleRedirectUri cannot be a user token reference") + } + o.GoogleRedirectURI = parsed.value + } + + // Parse secret fields + if raw.GoogleClientSecret != nil { + parsed, err := ParseConfigValue(raw.GoogleClientSecret) + if err != nil { + return fmt.Errorf("parsing googleClientSecret: %w", err) + } + if parsed.needsUserToken { + return fmt.Errorf("googleClientSecret cannot be a user token reference") + } + o.GoogleClientSecret = Secret(parsed.value) + } + + if raw.JWTSecret != nil { + parsed, err := ParseConfigValue(raw.JWTSecret) + if err != nil { + return fmt.Errorf("parsing jwtSecret: %w", err) + } + if parsed.needsUserToken { + return fmt.Errorf("jwtSecret cannot be a user token reference") + } + o.JWTSecret = Secret(parsed.value) + } + + if raw.EncryptionKey != nil { + parsed, err := ParseConfigValue(raw.EncryptionKey) + if err != nil { + return fmt.Errorf("parsing encryptionKey: %w", err) + } + if parsed.needsUserToken { + return fmt.Errorf("encryptionKey cannot be a user token reference") + } + o.EncryptionKey = Secret(parsed.value) } // Validate JWT secret length @@ -224,8 +272,8 @@ func (p *ProxyConfig) UnmarshalJSON(data []byte) error { // Normalize admin emails for consistent comparison if p.Admin != nil && len(p.Admin.AdminEmails) > 0 { normalizedEmails := make([]string, len(p.Admin.AdminEmails)) - for i, email := range p.Admin.AdminEmails { - normalizedEmails[i] = utils.NormalizeEmail(email) + for i, emailAddr := range p.Admin.AdminEmails { + normalizedEmails[i] = emailutil.Normalize(emailAddr) } p.Admin.AdminEmails = normalizedEmails } @@ -347,34 +395,27 @@ func (c *MCPClientConfig) ApplyUserToken(userToken string) *MCPClientConfig { // UnmarshalJSON implements custom unmarshaling for ServiceAuth func (s *ServiceAuth) UnmarshalJSON(data []byte) error { - // First unmarshal without custom processing - type rawServiceAuth struct { - Type ServiceAuthType `json:"type"` - Username string `json:"username,omitempty"` - Password json.RawMessage `json:"password,omitempty"` - Tokens []string `json:"tokens,omitempty"` - UserToken json.RawMessage `json:"userToken,omitempty"` - } - + // Use type alias to avoid recursion + type rawServiceAuth ServiceAuth var raw rawServiceAuth + if err := json.Unmarshal(data, &raw); err != nil { return err } - log.LogTraceWithFields("config", "Unmarshaling service auth", map[string]interface{}{ - "type": raw.Type, - }) + // Copy all fields + *s = ServiceAuth(raw) - s.Type = raw.Type - s.Username = raw.Username - s.Tokens = raw.Tokens + log.LogTraceWithFields("config", "Unmarshaling service auth", map[string]any{ + "type": s.Type, + }) // Parse password if provided (for basic auth) - if raw.Password != nil { - log.LogTraceWithFields("config", "Parsing password for basic auth", map[string]interface{}{ + if s.PasswordRaw != nil { + log.LogTraceWithFields("config", "Parsing password for basic auth", map[string]any{ "username": s.Username, }) - parsed, err := ParseConfigValue(raw.Password) + parsed, err := ParseConfigValue(s.PasswordRaw) if err != nil { return fmt.Errorf("parsing password: %w", err) } @@ -383,29 +424,29 @@ func (s *ServiceAuth) UnmarshalJSON(data []byte) error { } // Hash the password using bcrypt - log.LogTraceWithFields("config", "Hashing password for basic auth", map[string]interface{}{ + log.LogTraceWithFields("config", "Hashing password for basic auth", map[string]any{ "username": s.Username, }) hashed, err := bcrypt.GenerateFromPassword([]byte(parsed.value), bcrypt.DefaultCost) if err != nil { return fmt.Errorf("hashing password: %w", err) } - s.HashedPassword = string(hashed) + s.HashedPassword = Secret(hashed) } // Parse user token if provided - if raw.UserToken != nil { - log.LogTraceWithFields("config", "Parsing user token for service auth", map[string]interface{}{ + if s.UserTokenRaw != nil { + log.LogTraceWithFields("config", "Parsing user token for service auth", map[string]any{ "type": s.Type, }) - parsed, err := ParseConfigValue(raw.UserToken) + parsed, err := ParseConfigValue(s.UserTokenRaw) if err != nil { return fmt.Errorf("parsing userToken: %w", err) } if parsed.needsUserToken { return fmt.Errorf("userToken cannot be a user token reference") } - s.ResolvedUserToken = parsed.value + s.UserToken = Secret(parsed.value) } // Validate required fields based on type @@ -414,7 +455,7 @@ func (s *ServiceAuth) UnmarshalJSON(data []byte) error { if s.Username == "" { return fmt.Errorf("username is required for basic auth") } - if raw.Password == nil { + if s.PasswordRaw == nil { return fmt.Errorf("password is required for basic auth") } case ServiceAuthTypeBearer: @@ -428,6 +469,61 @@ func (s *ServiceAuth) UnmarshalJSON(data []byte) error { return nil } +// UnmarshalJSON implements custom unmarshaling for UserAuthentication +func (u *UserAuthentication) UnmarshalJSON(data []byte) error { + // First unmarshal to get the type + // Use type alias to avoid recursion + type rawAuth UserAuthentication + var raw rawAuth + + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Copy all fields + *u = UserAuthentication(raw) + + // Set default token format if not specified + if u.TokenFormat == "" { + u.TokenFormat = "{{token}}" + } + + switch u.Type { + case UserAuthTypeOAuth: + // Parse OAuth credentials + if u.ClientIDRaw != nil { + parsed, err := ParseConfigValue(u.ClientIDRaw) + if err != nil { + return fmt.Errorf("parsing clientId: %w", err) + } + u.ClientID = Secret(parsed.value) + } + + if u.ClientSecretRaw != nil { + parsed, err := ParseConfigValue(u.ClientSecretRaw) + if err != nil { + return fmt.Errorf("parsing clientSecret: %w", err) + } + u.ClientSecret = Secret(parsed.value) + } + + case UserAuthTypeManual: + // Compile validation regex if present + if u.Validation != "" { + regex, err := regexp.Compile(u.Validation) + if err != nil { + return fmt.Errorf("compiling validation regex: %w", err) + } + u.ValidationRegex = regex + } + + default: + return fmt.Errorf("unknown user auth type: %s", u.Type) + } + + return nil +} + // UnmarshalJSON implements custom unmarshaling for SessionConfig func (s *SessionConfig) UnmarshalJSON(data []byte) error { var raw struct { diff --git a/internal/config/unmarshal_test.go b/internal/config/unmarshal_test.go index 48ac83d..dc0682a 100644 --- a/internal/config/unmarshal_test.go +++ b/internal/config/unmarshal_test.go @@ -129,10 +129,11 @@ func TestMCPClientConfig_UnmarshalJSON(t *testing.T) { "AUTH_HEADER": {"$userToken": "Bearer {{token}}"} }, "requiresUserToken": true, - "tokenSetup": { + "userAuthentication": { + "type": "manual", "displayName": "Test Token", "instructions": "Enter your test token", - "tokenFormat": "^test_[a-z]+$" + "validation": "^test_[a-z]+$" } }` @@ -155,13 +156,14 @@ func TestMCPClientConfig_UnmarshalJSON(t *testing.T) { "AUTH_HEADER": true, }, config.EnvNeedsToken) - // Check token setup + // Check user authentication assert.True(t, config.RequiresUserToken) - assert.NotNil(t, config.TokenSetup) - assert.Equal(t, "Test Token", config.TokenSetup.DisplayName) - assert.NotNil(t, config.TokenSetup.CompiledRegex) - assert.True(t, config.TokenSetup.CompiledRegex.MatchString("test_abc")) - assert.False(t, config.TokenSetup.CompiledRegex.MatchString("test_123")) + assert.NotNil(t, config.UserAuthentication) + assert.Equal(t, UserAuthTypeManual, config.UserAuthentication.Type) + assert.Equal(t, "Test Token", config.UserAuthentication.DisplayName) + assert.NotNil(t, config.UserAuthentication.ValidationRegex) + assert.True(t, config.UserAuthentication.ValidationRegex.MatchString("test_abc")) + assert.False(t, config.UserAuthentication.ValidationRegex.MatchString("test_123")) } func TestMCPClientConfig_ApplyUserToken(t *testing.T) { @@ -270,9 +272,9 @@ func TestOAuthAuthConfig_UnmarshalJSON(t *testing.T) { assert.Equal(t, []string{"example.com"}, config.AllowedDomains) assert.Equal(t, []string{"https://claude.ai", "https://example.com"}, config.AllowedOrigins) assert.Equal(t, "test-client-id", config.GoogleClientID) - assert.Equal(t, "test-secret-value", config.GoogleClientSecret) - assert.Equal(t, "this-is-a-very-long-jwt-secret-key", config.JWTSecret) - assert.Equal(t, "exactly-32-bytes-long-encryptkey", config.EncryptionKey) + assert.Equal(t, Secret("test-secret-value"), config.GoogleClientSecret) + assert.Equal(t, Secret("this-is-a-very-long-jwt-secret-key"), config.JWTSecret) + assert.Equal(t, Secret("exactly-32-bytes-long-encryptkey"), config.EncryptionKey) } func TestOAuthAuthConfig_ValidationErrors(t *testing.T) { diff --git a/internal/config/validation.go b/internal/config/validation.go index 412e53c..e4b4fb9 100644 --- a/internal/config/validation.go +++ b/internal/config/validation.go @@ -35,7 +35,7 @@ func ValidateFile(path string) (*ValidationResult, error) { } // Check JSON syntax - var rawConfig map[string]interface{} + var rawConfig map[string]any if err := json.Unmarshal(data, &rawConfig); err != nil { result.Errors = append(result.Errors, ValidationError{ Message: fmt.Sprintf("invalid JSON: %v", err), @@ -51,12 +51,12 @@ func ValidateFile(path string) (*ValidationResult, error) { if !ok { result.Errors = append(result.Errors, ValidationError{ Path: "version", - Message: "version field is required", + Message: "version field is required. Hint: Add \"version\": \"v0.0.1-DEV_EDITION\"", }) } else if !strings.HasPrefix(version, "v0.0.1-DEV_EDITION") { result.Errors = append(result.Errors, ValidationError{ Path: "version", - Message: fmt.Sprintf("unsupported version: %s", version), + Message: fmt.Sprintf("unsupported version '%s' - use 'v0.0.1-DEV_EDITION' or 'v0.0.1-DEV_EDITION-'", version), }) } @@ -70,8 +70,8 @@ func ValidateFile(path string) (*ValidationResult, error) { } // validateProxyStructure checks the proxy configuration structure -func validateProxyStructure(rawConfig map[string]interface{}, result *ValidationResult) { - proxy, ok := rawConfig["proxy"].(map[string]interface{}) +func validateProxyStructure(rawConfig map[string]any, result *ValidationResult) { + proxy, ok := rawConfig["proxy"].(map[string]any) if !ok { result.Errors = append(result.Errors, ValidationError{ Path: "proxy", @@ -84,29 +84,29 @@ func validateProxyStructure(rawConfig map[string]interface{}, result *Validation if _, ok := proxy["baseURL"]; !ok { result.Errors = append(result.Errors, ValidationError{ Path: "proxy.baseURL", - Message: "baseURL is required", + Message: "baseURL is required. Example: \"https://api.example.com\"", }) } if _, ok := proxy["addr"]; !ok { result.Errors = append(result.Errors, ValidationError{ Path: "proxy.addr", - Message: "addr is required", + Message: "addr is required. Example: \":8080\" or \"0.0.0.0:8080\"", }) } // Check auth if present - if auth, ok := proxy["auth"].(map[string]interface{}); ok { + if auth, ok := proxy["auth"].(map[string]any); ok { validateAuthStructure(auth, result) } // Check admin if present - if admin, ok := proxy["admin"].(map[string]interface{}); ok { + if admin, ok := proxy["admin"].(map[string]any); ok { validateAdminStructure(admin, result) // If admin is enabled, ensure OAuth is configured if enabled, ok := admin["enabled"].(bool); ok && enabled { hasOAuth := false - if auth, ok := proxy["auth"].(map[string]interface{}); ok { + if auth, ok := proxy["auth"].(map[string]any); ok { if kind, ok := auth["kind"].(string); ok && kind == "oauth" { hasOAuth = true } @@ -122,12 +122,12 @@ func validateProxyStructure(rawConfig map[string]interface{}, result *Validation } // validateAuthStructure checks auth configuration structure -func validateAuthStructure(auth map[string]interface{}, result *ValidationResult) { +func validateAuthStructure(auth map[string]any, result *ValidationResult) { kind, ok := auth["kind"].(string) if !ok { result.Errors = append(result.Errors, ValidationError{ Path: "proxy.auth.kind", - Message: "auth kind is required", + Message: "auth kind is required. Use \"oauth\" for Google OAuth authentication", }) return } @@ -158,13 +158,13 @@ func validateAuthStructure(auth map[string]interface{}, result *ValidationResult }) } } - if domains, ok := auth["allowedDomains"].([]interface{}); !ok || len(domains) == 0 { + if domains, ok := auth["allowedDomains"].([]any); !ok || len(domains) == 0 { result.Errors = append(result.Errors, ValidationError{ Path: "proxy.auth.allowedDomains", Message: "at least one allowed domain is required for OAuth", }) } - if origins, ok := auth["allowedOrigins"].([]interface{}); !ok || len(origins) == 0 { + if origins, ok := auth["allowedOrigins"].([]any); !ok || len(origins) == 0 { result.Errors = append(result.Errors, ValidationError{ Path: "proxy.auth.allowedOrigins", Message: "at least one allowed origin is required for OAuth (CORS configuration)", @@ -173,13 +173,13 @@ func validateAuthStructure(auth map[string]interface{}, result *ValidationResult default: result.Errors = append(result.Errors, ValidationError{ Path: "proxy.auth.kind", - Message: fmt.Sprintf("unknown auth kind: %s (only 'oauth' is supported for proxy auth)", kind), + Message: fmt.Sprintf("unknown auth kind '%s' - only 'oauth' is supported for proxy auth", kind), }) } } // validateAdminStructure checks admin configuration structure -func validateAdminStructure(admin map[string]interface{}, result *ValidationResult) { +func validateAdminStructure(admin map[string]any, result *ValidationResult) { enabled, ok := admin["enabled"].(bool) if !ok { result.Errors = append(result.Errors, ValidationError{ @@ -191,7 +191,7 @@ func validateAdminStructure(admin map[string]interface{}, result *ValidationResu if enabled { // Check adminEmails when enabled - emails, ok := admin["adminEmails"].([]interface{}) + emails, ok := admin["adminEmails"].([]any) if !ok { result.Errors = append(result.Errors, ValidationError{ Path: "proxy.admin.adminEmails", @@ -208,7 +208,7 @@ func validateAdminStructure(admin map[string]interface{}, result *ValidationResu if _, ok := email.(string); !ok { result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("proxy.admin.adminEmails[%d]", i), - Message: "admin email must be a string", + Message: fmt.Sprintf("admin email must be a string, got %T - use \"user@example.com\" format", email), }) } } @@ -220,8 +220,8 @@ func validateAdminStructure(admin map[string]interface{}, result *ValidationResu } // validateServersStructure checks MCP servers configuration -func validateServersStructure(rawConfig map[string]interface{}, result *ValidationResult) { - servers, ok := rawConfig["mcpServers"].(map[string]interface{}) +func validateServersStructure(rawConfig map[string]any, result *ValidationResult) { + servers, ok := rawConfig["mcpServers"].(map[string]any) if !ok { result.Errors = append(result.Errors, ValidationError{ Path: "mcpServers", @@ -231,8 +231,8 @@ func validateServersStructure(rawConfig map[string]interface{}, result *Validati } hasOAuth := false - if proxy, ok := rawConfig["proxy"].(map[string]interface{}); ok { - if auth, ok := proxy["auth"].(map[string]interface{}); ok { + if proxy, ok := rawConfig["proxy"].(map[string]any); ok { + if auth, ok := proxy["auth"].(map[string]any); ok { if kind, ok := auth["kind"].(string); ok && kind == "oauth" { hasOAuth = true } @@ -240,7 +240,7 @@ func validateServersStructure(rawConfig map[string]interface{}, result *Validati } for name, server := range servers { - srv, ok := server.(map[string]interface{}) + srv, ok := server.(map[string]any) if !ok { result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("mcpServers.%s", name), @@ -254,7 +254,7 @@ func validateServersStructure(rawConfig map[string]interface{}, result *Validati if !ok { result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("mcpServers.%s.transportType", name), - Message: "transportType is required", + Message: "transportType is required. Options: stdio, sse, streamable-http, inline", }) continue } @@ -265,7 +265,7 @@ func validateServersStructure(rawConfig map[string]interface{}, result *Validati if _, ok := srv["command"]; !ok { result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("mcpServers.%s.command", name), - Message: "command is required for stdio transport", + Message: "command is required for stdio transport. Example: [\"npx\", \"-y\", \"@your/mcp-server\"]", }) } case "sse", "streamable-http": @@ -285,7 +285,7 @@ func validateServersStructure(rawConfig map[string]interface{}, result *Validati default: result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("mcpServers.%s.transportType", name), - Message: fmt.Sprintf("invalid transportType: %s", transportType), + Message: fmt.Sprintf("invalid transportType '%s' - supported types: stdio, sse, streamable-http, inline", transportType), }) } @@ -297,16 +297,18 @@ func validateServersStructure(rawConfig map[string]interface{}, result *Validati Message: "server requires user token but OAuth is not configured. Hint: User tokens require OAuth authentication - set proxy.auth.kind to 'oauth'", }) } - if _, ok := srv["tokenSetup"]; !ok { + if userAuth, ok := srv["userAuthentication"]; !ok { result.Errors = append(result.Errors, ValidationError{ - Path: fmt.Sprintf("mcpServers.%s.tokenSetup", name), - Message: "tokenSetup is required when requiresUserToken is true. Hint: Add tokenSetup with displayName and instructions for users to obtain their token", + Path: fmt.Sprintf("mcpServers.%s.userAuthentication", name), + Message: "userAuthentication is required when requiresUserToken is true. Hint: Add userAuthentication with type, displayName and instructions", }) + } else { + validateUserAuthentication(userAuth, fmt.Sprintf("mcpServers.%s.userAuthentication", name), result) } } // Check service auth configuration - if serviceAuths, ok := srv["serviceAuths"].([]interface{}); ok { + if serviceAuths, ok := srv["serviceAuths"].([]any); ok { requiresUserToken := false if requiresToken, ok := srv["requiresUserToken"].(bool); ok { requiresUserToken = requiresToken @@ -317,9 +319,9 @@ func validateServersStructure(rawConfig map[string]interface{}, result *Validati } // validateServiceAuths validates service authentication configuration -func validateServiceAuths(serviceAuths []interface{}, serverName string, requiresUserToken bool, result *ValidationResult) { +func validateServiceAuths(serviceAuths []any, serverName string, requiresUserToken bool, result *ValidationResult) { for i, authInterface := range serviceAuths { - auth, ok := authInterface.(map[string]interface{}) + auth, ok := authInterface.(map[string]any) if !ok { result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("mcpServers.%s.serviceAuths[%d]", serverName, i), @@ -332,7 +334,7 @@ func validateServiceAuths(serviceAuths []interface{}, serverName string, require if !ok { result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("mcpServers.%s.serviceAuths[%d].type", serverName, i), - Message: "service auth type is required", + Message: "service auth type is required. Options: basic, bearer", }) continue } @@ -355,7 +357,7 @@ func validateServiceAuths(serviceAuths []interface{}, serverName string, require validatePasswordReference(auth["password"], fmt.Sprintf("mcpServers.%s.serviceAuths[%d].password", serverName, i), result) } case "bearer": - tokens, ok := auth["tokens"].([]interface{}) + tokens, ok := auth["tokens"].([]any) if !ok { result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("mcpServers.%s.serviceAuths[%d].tokens", serverName, i), @@ -370,7 +372,7 @@ func validateServiceAuths(serviceAuths []interface{}, serverName string, require default: result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("mcpServers.%s.serviceAuths[%d].type", serverName, i), - Message: fmt.Sprintf("unknown service auth type: %s (supported: basic, bearer)", authType), + Message: fmt.Sprintf("unknown service auth type '%s' - supported types: basic, bearer", authType), }) } @@ -379,7 +381,7 @@ func validateServiceAuths(serviceAuths []interface{}, serverName string, require if _, ok := auth["userToken"]; !ok { result.Errors = append(result.Errors, ValidationError{ Path: fmt.Sprintf("mcpServers.%s.serviceAuths[%d].userToken", serverName, i), - Message: "userToken is required when server has requiresUserToken: true", + Message: "userToken is required when server has requiresUserToken: true. Hint: Use {\"$userToken\": \"{{token}}\"} to inject user's token", }) } else { validateUserTokenReference(auth["userToken"], fmt.Sprintf("mcpServers.%s.serviceAuths[%d].userToken", serverName, i), result) @@ -388,54 +390,184 @@ func validateServiceAuths(serviceAuths []interface{}, serverName string, require } } +// validateEnvVarReference validates that a field uses proper env var reference format +func validateEnvVarReference(value any, fieldName, path string) *ValidationError { + switch v := value.(type) { + case string: + // Check if it looks like a bash-style env var + bashStyleRegex := regexp.MustCompile(`\$\{?([A-Z_][A-Z0-9_]*)\}?`) + if matches := bashStyleRegex.FindStringSubmatch(v); len(matches) > 1 { + varName := matches[1] + return &ValidationError{ + Path: path, + Message: fmt.Sprintf("found bash-style syntax '%s' - use {\"$env\": \"%s\"} instead. Hint: JSON syntax prevents accidental shell expansion and ensures security", v, varName), + } + } + // Plain string value + return &ValidationError{ + Path: path, + Message: fmt.Sprintf("%s must use environment variable reference {\"$env\": \"YOUR_ENV_VAR\"} instead of plain text '%s'. Hint: This prevents secrets from being stored in config files", fieldName, v), + } + case map[string]any: + if _, hasEnv := v["$env"]; !hasEnv { + return &ValidationError{ + Path: path, + Message: fmt.Sprintf("%s must use {\"$env\": \"YOUR_ENV_VAR\"} format, not %v", fieldName, v), + } + } + // Valid env reference + return nil + default: + return &ValidationError{ + Path: path, + Message: fmt.Sprintf("%s must be an environment variable reference {\"$env\": \"YOUR_ENV_VAR\"}, not %T", fieldName, value), + } + } +} + // validatePasswordReference validates that password uses env var reference -func validatePasswordReference(password interface{}, path string, result *ValidationResult) { - switch p := password.(type) { +func validatePasswordReference(password any, path string, result *ValidationResult) { + if err := validateEnvVarReference(password, "password", path); err != nil { + result.Errors = append(result.Errors, *err) + } +} + +// validateUserTokenReference validates that userToken uses proper reference format +func validateUserTokenReference(userToken any, path string, result *ValidationResult) { + switch v := userToken.(type) { case string: + // Plain string is not allowed for userToken result.Errors = append(result.Errors, ValidationError{ Path: path, - Message: "password must use environment variable reference {\"$env\": \"VAR_NAME\"} for security", + Message: fmt.Sprintf("userToken must use {\"$userToken\": \"{{token}}\"} format instead of plain text '%s'. Hint: This injects the user's authenticated token at runtime", v), }) - case map[string]interface{}: - if _, hasEnv := p["$env"]; !hasEnv { - result.Errors = append(result.Errors, ValidationError{ - Path: path, - Message: "password must use {\"$env\": \"VAR_NAME\"} format", - }) + case map[string]any: + if _, hasUserToken := v["$userToken"]; !hasUserToken { + // Check if they're trying to use env var syntax + if _, hasEnv := v["$env"]; hasEnv { + result.Errors = append(result.Errors, ValidationError{ + Path: path, + Message: "userToken cannot use {\"$env\": \"...\"} syntax - use {\"$userToken\": \"{{token}}\"} to inject user's authenticated token", + }) + } else { + result.Errors = append(result.Errors, ValidationError{ + Path: path, + Message: fmt.Sprintf("userToken must use {\"$userToken\": \"{{token}}\"} format, not %v", v), + }) + } } + // Valid userToken reference default: result.Errors = append(result.Errors, ValidationError{ Path: path, - Message: "password must be an environment variable reference", + Message: fmt.Sprintf("userToken must be a reference object {\"$userToken\": \"{{token}}\"}, not %T", userToken), }) } } -// validateUserTokenReference validates that userToken uses env var reference -func validateUserTokenReference(userToken interface{}, path string, result *ValidationResult) { - switch t := userToken.(type) { - case string: +// validateUserAuthentication validates user authentication configuration +func validateUserAuthentication(userAuth any, path string, result *ValidationResult) { + auth, ok := userAuth.(map[string]any) + if !ok { result.Errors = append(result.Errors, ValidationError{ Path: path, - Message: "userToken must use environment variable reference {\"$env\": \"VAR_NAME\"} for security", + Message: "userAuthentication must be an object", + }) + return + } + + // Check required type field + authType, ok := auth["type"].(string) + if !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: path + ".type", + Message: "type is required. Options: oauth (for automated OAuth flow) or manual (for user-provided tokens)", }) - case map[string]interface{}: - if _, hasEnv := t["$env"]; !hasEnv { + return + } + + // Validate based on type + switch authType { + case "oauth": + // OAuth requires OAuth configuration + if oauth, ok := auth["oauth"].(map[string]any); !ok { result.Errors = append(result.Errors, ValidationError{ - Path: path, - Message: "userToken must use {\"$env\": \"VAR_NAME\"} format", + Path: path + ".oauth", + Message: "oauth configuration is required when type is oauth", + }) + } else { + validateOAuthServiceConfig(oauth, path+".oauth", result) + } + case "manual": + // Manual requires displayName and instructions + if _, ok := auth["displayName"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: path + ".displayName", + Message: "displayName is required for manual authentication. Example: \"GitHub Personal Access Token\"", + }) + } + if _, ok := auth["instructions"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: path + ".instructions", + Message: "instructions are required for manual authentication. Example: \"Create a token at https://github.com/settings/tokens\"", }) } default: result.Errors = append(result.Errors, ValidationError{ - Path: path, - Message: "userToken must be an environment variable reference", + Path: path + ".type", + Message: fmt.Sprintf("invalid authentication type '%s' - must be 'oauth' or 'manual'", authType), + }) + } +} + +// validateOAuthServiceConfig validates OAuth service configuration +func validateOAuthServiceConfig(oauth map[string]any, path string, result *ValidationResult) { + // Check required fields + requiredFields := []string{"clientId", "authorizationUrl", "tokenUrl", "scopes"} + for _, field := range requiredFields { + if _, ok := oauth[field]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: path + "." + field, + Message: fmt.Sprintf("%s is required for OAuth configuration", field), + }) + } + } + + // Validate scopes is an array + if scopes, ok := oauth["scopes"].([]any); ok { + if len(scopes) == 0 { + result.Errors = append(result.Errors, ValidationError{ + Path: path + ".scopes", + Message: "at least one scope is required", + }) + } + } else { + result.Errors = append(result.Errors, ValidationError{ + Path: path + ".scopes", + Message: "scopes must be an array", }) } + + // Check client secret uses env var + if clientSecret, ok := oauth["clientSecret"]; ok { + validateSecretReference(clientSecret, path+".clientSecret", result) + } else { + result.Errors = append(result.Errors, ValidationError{ + Path: path + ".clientSecret", + Message: "clientSecret is required for OAuth configuration", + }) + } +} + +// validateSecretReference validates that secret uses env var reference +func validateSecretReference(secret any, path string, result *ValidationResult) { + if err := validateEnvVarReference(secret, "clientSecret", path); err != nil { + result.Errors = append(result.Errors, *err) + } } // checkBashStyleSyntax recursively checks for bash-style env var syntax -func checkBashStyleSyntax(value interface{}, path string, result *ValidationResult) { +func checkBashStyleSyntax(value any, path string, result *ValidationResult) { bashStyleRegex := regexp.MustCompile(`\$\{?[A-Z_][A-Z0-9_]*\}?`) switch v := value.(type) { @@ -449,7 +581,7 @@ func checkBashStyleSyntax(value interface{}, path string, result *ValidationResu }) } } - case map[string]interface{}: + case map[string]any: // Skip if this is already an env/userToken ref if _, hasEnv := v["$env"]; hasEnv { return @@ -467,7 +599,7 @@ func checkBashStyleSyntax(value interface{}, path string, result *ValidationResu } checkBashStyleSyntax(val, newPath, result) } - case []interface{}: + case []any: for i, item := range v { newPath := fmt.Sprintf("%s[%d]", path, i) checkBashStyleSyntax(item, newPath, result) diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go index eb11ea2..084bc13 100644 --- a/internal/config/validation_test.go +++ b/internal/config/validation_test.go @@ -79,7 +79,7 @@ func TestValidateFile(t *testing.T) { }, "mcpServers": {} }`, - wantErrors: []string{"version field is required"}, + wantErrors: []string{"version field is required. Hint: Add \"version\": \"v0.0.1-DEV_EDITION\""}, wantErrCount: 1, }, { @@ -92,7 +92,7 @@ func TestValidateFile(t *testing.T) { }, "mcpServers": {} }`, - wantErrors: []string{"unsupported version: v2.0.0"}, + wantErrors: []string{"unsupported version 'v2.0.0' - use 'v0.0.1-DEV_EDITION' or 'v0.0.1-DEV_EDITION-'"}, wantErrCount: 1, }, { @@ -112,8 +112,8 @@ func TestValidateFile(t *testing.T) { "mcpServers": {} }`, wantErrors: []string{ - "baseURL is required", - "addr is required", + "baseURL is required. Example: \"https://api.example.com\"", + "addr is required. Example: \":8080\" or \"0.0.0.0:8080\"", }, wantErrCount: 2, }, @@ -156,7 +156,7 @@ func TestValidateFile(t *testing.T) { } } }`, - wantErrors: []string{"transportType is required"}, + wantErrors: []string{"transportType is required. Options: stdio, sse, streamable-http, inline"}, wantErrCount: 1, }, { @@ -173,7 +173,7 @@ func TestValidateFile(t *testing.T) { } } }`, - wantErrors: []string{"command is required for stdio transport"}, + wantErrors: []string{"command is required for stdio transport. Example: [\"npx\", \"-y\", \"@your/mcp-server\"]"}, wantErrCount: 1, }, { @@ -211,7 +211,7 @@ func TestValidateFile(t *testing.T) { }`, wantErrors: []string{ "server requires user token but OAuth is not configured. Hint: User tokens require OAuth authentication - set proxy.auth.kind to 'oauth'", - "tokenSetup is required when requiresUserToken is true. Hint: Add tokenSetup with displayName and instructions for users to obtain their token", + "userAuthentication is required when requiresUserToken is true. Hint: Add userAuthentication with type, displayName and instructions", }, wantErrCount: 2, }, @@ -242,7 +242,7 @@ func TestValidateFile(t *testing.T) { } } }`, - wantErrors: []string{"tokenSetup is required when requiresUserToken is true. Hint: Add tokenSetup with displayName and instructions for users to obtain their token"}, + wantErrors: []string{"userAuthentication is required when requiresUserToken is true. Hint: Add userAuthentication with type, displayName and instructions"}, wantErrCount: 1, }, { @@ -270,6 +270,125 @@ func TestValidateFile(t *testing.T) { }, wantErrCount: 8, }, + { + name: "valid_manual_user_authentication", + config: `{ + "version": "v0.0.1-DEV_EDITION", + "proxy": { + "baseURL": "http://localhost:8080", + "addr": ":8080", + "auth": { + "kind": "oauth", + "issuer": "https://example.com", + "googleClientId": "id", + "googleClientSecret": "secret", + "googleRedirectUri": "https://example.com/callback", + "jwtSecret": "secret123456789012345678901234567890", + "encryptionKey": "key12345678901234567890123456789", + "allowedDomains": ["example.com"], + "allowedOrigins": ["https://claude.ai"] + } + }, + "mcpServers": { + "notion": { + "transportType": "stdio", + "command": "docker", + "requiresUserToken": true, + "userAuthentication": { + "type": "manual", + "displayName": "Notion", + "instructions": "Get your token from Notion settings" + } + } + } + }`, + wantErrors: []string{}, + wantErrCount: 0, + }, + { + name: "valid_oauth_user_authentication", + config: `{ + "version": "v0.0.1-DEV_EDITION", + "proxy": { + "baseURL": "http://localhost:8080", + "addr": ":8080", + "auth": { + "kind": "oauth", + "issuer": "https://example.com", + "googleClientId": "id", + "googleClientSecret": "secret", + "googleRedirectUri": "https://example.com/callback", + "jwtSecret": "secret123456789012345678901234567890", + "encryptionKey": "key12345678901234567890123456789", + "allowedDomains": ["example.com"], + "allowedOrigins": ["https://claude.ai"] + } + }, + "mcpServers": { + "linear": { + "transportType": "stdio", + "command": "npx", + "requiresUserToken": true, + "userAuthentication": { + "type": "oauth", + "displayName": "Linear", + "oauth": { + "clientId": "client123", + "clientSecret": {"$env": "LINEAR_CLIENT_SECRET"}, + "authorizationUrl": "https://linear.app/oauth/authorize", + "tokenUrl": "https://api.linear.app/oauth/token", + "scopes": ["read", "write"] + } + } + } + } + }`, + wantErrors: []string{}, + wantErrCount: 0, + }, + { + name: "invalid_oauth_user_authentication_missing_fields", + config: `{ + "version": "v0.0.1-DEV_EDITION", + "proxy": { + "baseURL": "http://localhost:8080", + "addr": ":8080", + "auth": { + "kind": "oauth", + "issuer": "https://example.com", + "googleClientId": "id", + "googleClientSecret": "secret", + "googleRedirectUri": "https://example.com/callback", + "jwtSecret": "secret123456789012345678901234567890", + "encryptionKey": "key12345678901234567890123456789", + "allowedDomains": ["example.com"], + "allowedOrigins": ["https://claude.ai"] + } + }, + "mcpServers": { + "linear": { + "transportType": "stdio", + "command": "npx", + "requiresUserToken": true, + "userAuthentication": { + "type": "oauth", + "displayName": "Linear", + "oauth": { + "clientId": "client123" + } + } + } + } + }`, + wantErrors: []string{ + "authorizationUrl is required for OAuth configuration", + "tokenUrl is required for OAuth configuration", + "scopes is required for OAuth configuration", + "scopes must be an array", + "clientSecret is required for OAuth configuration", + }, + wantErrCount: 5, + }, } for _, tt := range tests { @@ -340,6 +459,151 @@ func TestValidateFile_FileNotFound(t *testing.T) { assert.Contains(t, err.Error(), "reading config file") } +func TestValidateFile_ImprovedErrorMessages(t *testing.T) { + tests := []struct { + name string + config string + wantErrorMsg string + }{ + { + name: "plain_text_password", + config: `{ + "version": "v0.0.1-DEV_EDITION", + "proxy": { + "baseURL": "http://localhost:8080", + "addr": ":8080" + }, + "mcpServers": { + "service": { + "transportType": "stdio", + "command": "test", + "serviceAuths": [{ + "type": "basic", + "username": "user", + "password": "my-secret-password" + }] + } + } + }`, + wantErrorMsg: "password must use environment variable reference {\"$env\": \"YOUR_ENV_VAR\"} instead of plain text 'my-secret-password'", + }, + { + name: "bash_style_password", + config: `{ + "version": "v0.0.1-DEV_EDITION", + "proxy": { + "baseURL": "http://localhost:8080", + "addr": ":8080" + }, + "mcpServers": { + "service": { + "transportType": "stdio", + "command": "test", + "serviceAuths": [{ + "type": "basic", + "username": "user", + "password": "$DB_PASSWORD" + }] + } + } + }`, + wantErrorMsg: "found bash-style syntax '$DB_PASSWORD' - use {\"$env\": \"DB_PASSWORD\"} instead", + }, + { + name: "plain_text_clientSecret", + config: `{ + "version": "v0.0.1-DEV_EDITION", + "proxy": { + "baseURL": "http://localhost:8080", + "addr": ":8080", + "auth": { + "kind": "oauth", + "issuer": "https://example.com", + "googleClientId": "id", + "googleClientSecret": "secret", + "googleRedirectUri": "https://example.com/callback", + "jwtSecret": "secret123456789012345678901234567890", + "encryptionKey": "key12345678901234567890123456789", + "allowedDomains": ["example.com"], + "allowedOrigins": ["https://claude.ai"] + } + }, + "mcpServers": { + "linear": { + "transportType": "stdio", + "command": "npx", + "requiresUserToken": true, + "userAuthentication": { + "type": "oauth", + "displayName": "Linear", + "oauth": { + "clientId": "client123", + "clientSecret": "super-secret", + "authorizationUrl": "https://linear.app/oauth/authorize", + "tokenUrl": "https://api.linear.app/oauth/token", + "scopes": ["read"] + } + } + } + } + }`, + wantErrorMsg: "clientSecret must use environment variable reference {\"$env\": \"YOUR_ENV_VAR\"} instead of plain text 'super-secret'", + }, + { + name: "bash_style_userToken", + config: `{ + "version": "v0.0.1-DEV_EDITION", + "proxy": { + "baseURL": "http://localhost:8080", + "addr": ":8080" + }, + "mcpServers": { + "service": { + "transportType": "stdio", + "command": "test", + "requiresUserToken": true, + "serviceAuths": [{ + "type": "basic", + "username": "user", + "password": {"$env": "PASS"}, + "userToken": "${USER_TOKEN}" + }] + } + } + }`, + wantErrorMsg: "userToken must use {\"$userToken\": \"{{token}}\"} format instead of plain text '${USER_TOKEN}'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temp file + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + err := os.WriteFile(configPath, []byte(tt.config), 0644) + require.NoError(t, err) + + // Validate + result, err := ValidateFile(configPath) + assert.NoError(t, err) + assert.NotNil(t, result) + + // Check that we have at least one error + assert.GreaterOrEqual(t, len(result.Errors), 1, "Expected at least one validation error") + + // Check that one of the errors contains our expected message + found := false + for _, e := range result.Errors { + if strings.Contains(e.Message, tt.wantErrorMsg) { + found = true + break + } + } + assert.True(t, found, "Expected error message containing '%s', but got errors: %v", tt.wantErrorMsg, result.Errors) + }) + } +} + func contains(s, substr string) bool { return strings.Contains(s, substr) } diff --git a/internal/cookie/cookie.go b/internal/cookie/cookie.go index 2ad9b1f..3be28d4 100644 --- a/internal/cookie/cookie.go +++ b/internal/cookie/cookie.go @@ -4,7 +4,7 @@ import ( "net/http" "time" - "github.com/dgellow/mcp-front/internal" + "github.com/dgellow/mcp-front/internal/envutil" "github.com/dgellow/mcp-front/internal/log" ) @@ -16,7 +16,7 @@ const ( // SetSession sets a session cookie with appropriate security settings func SetSession(w http.ResponseWriter, value string, maxAge time.Duration) { - secure := !internal.IsDevelopmentMode() + secure := !envutil.IsDev() http.SetCookie(w, &http.Cookie{ Name: SessionCookie, Value: value, @@ -27,7 +27,7 @@ func SetSession(w http.ResponseWriter, value string, maxAge time.Duration) { MaxAge: int(maxAge.Seconds()), }) - log.LogTraceWithFields("cookie", "Session cookie set", map[string]interface{}{ + log.LogTraceWithFields("cookie", "Session cookie set", map[string]any{ "maxAge": maxAge.String(), "secure": secure, "sameSite": "Lax", @@ -41,7 +41,7 @@ func SetCSRF(w http.ResponseWriter, value string) { Value: value, Path: "/", HttpOnly: false, // CSRF tokens need to be readable by JavaScript - Secure: !internal.IsDevelopmentMode(), + Secure: !envutil.IsDev(), SameSite: http.SameSiteStrictMode, MaxAge: int((24 * time.Hour).Seconds()), // 24 hours }) diff --git a/internal/crypto/signed_token.go b/internal/crypto/signed_token.go new file mode 100644 index 0000000..97c837d --- /dev/null +++ b/internal/crypto/signed_token.go @@ -0,0 +1,98 @@ +package crypto + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "time" +) + +// TokenSigner provides HMAC-signed JSON tokens with optional expiry +type TokenSigner struct { + signingKey []byte + ttl time.Duration +} + +// NewTokenSigner creates a new token signer +func NewTokenSigner(signingKey []byte, ttl time.Duration) TokenSigner { + return TokenSigner{ + signingKey: signingKey, + ttl: ttl, + } +} + +// TokenData wraps user data with metadata +type TokenData struct { + Data json.RawMessage `json:"data"` + ExpiresAt time.Time `json:"expires_at,omitempty"` +} + +// Sign marshals data to JSON, signs it with HMAC, and returns a base64-encoded token +func (ts *TokenSigner) Sign(v any) (string, error) { + // Marshal user data + userData, err := json.Marshal(v) + if err != nil { + return "", fmt.Errorf("failed to marshal data: %w", err) + } + + // Wrap with metadata + tokenData := TokenData{ + Data: userData, + } + if ts.ttl > 0 { + tokenData.ExpiresAt = time.Now().Add(ts.ttl) + } + + // Marshal complete token + jsonData, err := json.Marshal(tokenData) + if err != nil { + return "", fmt.Errorf("failed to marshal token data: %w", err) + } + + // Create signature + signature := SignData(string(jsonData), ts.signingKey) + + // Combine data and signature + combined := fmt.Sprintf("%s.%s", base64.URLEncoding.EncodeToString(jsonData), signature) + return combined, nil +} + +// Verify validates the signature, checks expiry, and unmarshals the data +func (ts *TokenSigner) Verify(token string, v any) error { + // Split data and signature + parts := strings.Split(token, ".") + if len(parts) != 2 { + return fmt.Errorf("invalid token format") + } + + // Decode JSON data + jsonData, err := base64.URLEncoding.DecodeString(parts[0]) + if err != nil { + return fmt.Errorf("failed to decode token data: %w", err) + } + + // Verify signature + signature := parts[1] + if !ValidateSignedData(string(jsonData), signature, ts.signingKey) { + return fmt.Errorf("invalid signature") + } + + // Unmarshal token data + var tokenData TokenData + if err := json.Unmarshal(jsonData, &tokenData); err != nil { + return fmt.Errorf("failed to unmarshal token data: %w", err) + } + + // Check expiry + if !tokenData.ExpiresAt.IsZero() && time.Now().After(tokenData.ExpiresAt) { + return fmt.Errorf("token expired") + } + + // Unmarshal user data + if err := json.Unmarshal(tokenData.Data, v); err != nil { + return fmt.Errorf("failed to unmarshal user data: %w", err) + } + + return nil +} diff --git a/internal/email/email.go b/internal/email/email.go new file mode 100644 index 0000000..56ba63e --- /dev/null +++ b/internal/email/email.go @@ -0,0 +1,18 @@ +package emailutil + +import "strings" + +// Normalize normalizes an email address for consistent comparison +// by converting to lowercase and trimming whitespace +func Normalize(email string) string { + return strings.ToLower(strings.TrimSpace(email)) +} + +// ExtractDomain extracts domain from email address +func ExtractDomain(email string) string { + parts := strings.Split(email, "@") + if len(parts) != 2 { + return "" + } + return parts[1] +} diff --git a/internal/utils/email_test.go b/internal/email/email_test.go similarity index 88% rename from internal/utils/email_test.go rename to internal/email/email_test.go index ae529a8..f801f3a 100644 --- a/internal/utils/email_test.go +++ b/internal/email/email_test.go @@ -1,4 +1,4 @@ -package utils +package emailutil import ( "testing" @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNormalizeEmail(t *testing.T) { +func TestNormalize(t *testing.T) { tests := []struct { name string input string @@ -61,8 +61,8 @@ func TestNormalizeEmail(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := NormalizeEmail(tt.input) - assert.Equal(t, tt.expected, result, "NormalizeEmail(%q)", tt.input) + result := Normalize(tt.input) + assert.Equal(t, tt.expected, result, "Normalize(%q)", tt.input) }) } } diff --git a/internal/env.go b/internal/envutil/envutil.go similarity index 62% rename from internal/env.go rename to internal/envutil/envutil.go index ff92429..ef5acfc 100644 --- a/internal/env.go +++ b/internal/envutil/envutil.go @@ -1,13 +1,13 @@ -package internal +package envutil import ( "os" "strings" ) -// IsDevelopmentMode checks if we're running in development mode +// IsDev checks if we're running in development mode // where security requirements can be relaxed for testing -func IsDevelopmentMode() bool { +func IsDev() bool { env := strings.ToLower(os.Getenv("MCP_FRONT_ENV")) return env == "development" || env == "dev" } diff --git a/internal/googleauth/google.go b/internal/googleauth/google.go new file mode 100644 index 0000000..defa32c --- /dev/null +++ b/internal/googleauth/google.go @@ -0,0 +1,121 @@ +package googleauth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "os" + "slices" + "strings" + + "github.com/dgellow/mcp-front/internal/config" + emailutil "github.com/dgellow/mcp-front/internal/email" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +// UserInfo represents Google user information +type UserInfo struct { + Email string `json:"email"` + HostedDomain string `json:"hd"` + Name string `json:"name"` + Picture string `json:"picture"` + VerifiedEmail bool `json:"verified_email"` +} + +// GoogleAuthURL generates a Google OAuth authorization URL +func GoogleAuthURL(oauthConfig config.OAuthAuthConfig, state string) string { + googleOAuth := newGoogleOAuth2Config(oauthConfig) + return googleOAuth.AuthCodeURL(state, + oauth2.AccessTypeOffline, + oauth2.ApprovalForce, + ) +} + +// ExchangeCodeForToken exchanges the authorization code for a token +func ExchangeCodeForToken(ctx context.Context, oauthConfig config.OAuthAuthConfig, code string) (*oauth2.Token, error) { + googleOAuth := newGoogleOAuth2Config(oauthConfig) + return googleOAuth.Exchange(ctx, code) +} + +// ValidateUser validates the Google OAuth token and checks domain membership +func ValidateUser(ctx context.Context, oauthConfig config.OAuthAuthConfig, token *oauth2.Token) (UserInfo, error) { + googleOAuth := newGoogleOAuth2Config(oauthConfig) + client := googleOAuth.Client(ctx, token) + userInfoURL := "https://www.googleapis.com/oauth2/v2/userinfo" + if customURL := os.Getenv("GOOGLE_USERINFO_URL"); customURL != "" { + userInfoURL = customURL + } + resp, err := client.Get(userInfoURL) + if err != nil { + return UserInfo{}, fmt.Errorf("failed to get user info: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return UserInfo{}, fmt.Errorf("failed to get user info: status %d", resp.StatusCode) + } + + var userInfo UserInfo + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + return UserInfo{}, fmt.Errorf("failed to decode user info: %w", err) + } + + // Validate domain if configured + if len(oauthConfig.AllowedDomains) > 0 { + userDomain := emailutil.ExtractDomain(userInfo.Email) + if !slices.Contains(oauthConfig.AllowedDomains, userDomain) { + return UserInfo{}, fmt.Errorf("domain '%s' is not allowed. Contact your administrator", userDomain) + } + } + + return userInfo, nil +} + +// ParseClientRequest parses MCP client registration metadata +func ParseClientRequest(metadata map[string]any) ([]string, []string, error) { + // Extract redirect URIs + redirectURIs := []string{} + if uris, ok := metadata["redirect_uris"].([]any); ok { + for _, uri := range uris { + if uriStr, ok := uri.(string); ok { + redirectURIs = append(redirectURIs, uriStr) + } + } + } + + if len(redirectURIs) == 0 { + return nil, nil, fmt.Errorf("no valid redirect URIs provided") + } + + // Extract scopes, default to read/write if not provided + scopes := []string{"read", "write"} // Default MCP scopes + if clientScopes, ok := metadata["scope"].(string); ok { + if strings.TrimSpace(clientScopes) != "" { + scopes = strings.Fields(clientScopes) + } + } + + return redirectURIs, scopes, nil +} + +// newGoogleOAuth2Config creates the OAuth2 config from our Config +func newGoogleOAuth2Config(oauthConfig config.OAuthAuthConfig) oauth2.Config { + // Use custom OAuth endpoints if provided (for testing) + endpoint := google.Endpoint + if authURL := os.Getenv("GOOGLE_OAUTH_AUTH_URL"); authURL != "" { + endpoint.AuthURL = authURL + } + if tokenURL := os.Getenv("GOOGLE_OAUTH_TOKEN_URL"); tokenURL != "" { + endpoint.TokenURL = tokenURL + } + + return oauth2.Config{ + ClientID: oauthConfig.GoogleClientID, + ClientSecret: string(oauthConfig.GoogleClientSecret), + RedirectURL: oauthConfig.GoogleRedirectURI, + Scopes: []string{"openid", "profile", "email"}, + Endpoint: endpoint, + } +} diff --git a/internal/inline/server.go b/internal/inline/server.go index c087a7f..1590148 100644 --- a/internal/inline/server.go +++ b/internal/inline/server.go @@ -36,9 +36,9 @@ func NewServer(name string, config Config, resolvedTools []ResolvedToolConfig) * // Tool represents an MCP tool type Tool struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema map[string]interface{} `json:"inputSchema"` + Name string `json:"name"` + Description string `json:"description"` + InputSchema map[string]any `json:"inputSchema"` } // ServerCapabilities represents server capabilities @@ -51,7 +51,7 @@ func (s *Server) GetCapabilities() ServerCapabilities { tools := make(map[string]Tool) for name, tool := range s.tools { - var inputSchema map[string]interface{} + var inputSchema map[string]any if len(tool.InputSchema) > 0 { if err := json.Unmarshal(tool.InputSchema, &inputSchema); err != nil { log.LogError("Failed to unmarshal input schema for tool %s: %v", name, err) @@ -76,7 +76,7 @@ func (s *Server) GetDescription() string { } // HandleToolCall executes a tool and returns the result -func (s *Server) HandleToolCall(ctx context.Context, toolName string, args map[string]interface{}) (interface{}, error) { +func (s *Server) HandleToolCall(ctx context.Context, toolName string, args map[string]any) (any, error) { tool, exists := s.tools[toolName] if !exists { return nil, fmt.Errorf("tool %s not found", toolName) @@ -114,25 +114,25 @@ func (s *Server) HandleToolCall(ctx context.Context, toolName string, args map[s // Execute err := cmd.Run() if err != nil { - log.LogErrorWithFields("inline", "Tool execution failed", map[string]interface{}{ + log.LogErrorWithFields("inline", "Tool execution failed", map[string]any{ "tool": toolName, "error": err.Error(), "stderr": stderr.String(), }) - return map[string]interface{}{ + return map[string]any{ "error": err.Error(), "stderr": stderr.String(), }, fmt.Errorf("command failed: %w", err) } // Try to parse as JSON first - var result interface{} + var result any if err := json.Unmarshal(stdout.Bytes(), &result); err == nil { return result, nil } // Return as text if not JSON - return map[string]interface{}{ + return map[string]any{ "output": stdout.String(), "stderr": stderr.String(), }, nil diff --git a/internal/inline/server_test.go b/internal/inline/server_test.go index 5cf6db7..9f3a704 100644 --- a/internal/inline/server_test.go +++ b/internal/inline/server_test.go @@ -90,17 +90,17 @@ func TestServer_HandleToolCall(t *testing.T) { tests := []struct { name string toolName string - args map[string]interface{} + args map[string]any wantError bool - validate func(t *testing.T, result interface{}, err error) + validate func(t *testing.T, result any, err error) }{ { name: "echo tool", toolName: "echo", - args: map[string]interface{}{}, + args: map[string]any{}, wantError: false, - validate: func(t *testing.T, result interface{}, err error) { - resultMap, ok := result.(map[string]interface{}) + validate: func(t *testing.T, result any, err error) { + resultMap, ok := result.(map[string]any) require.True(t, ok) output := resultMap["output"].(string) assert.Equal(t, "test-message\n", output) @@ -109,9 +109,9 @@ func TestServer_HandleToolCall(t *testing.T) { { name: "nonexistent tool", toolName: "nonexistent", - args: map[string]interface{}{}, + args: map[string]any{}, wantError: true, - validate: func(t *testing.T, result interface{}, err error) { + validate: func(t *testing.T, result any, err error) { assert.Error(t, err) assert.Contains(t, err.Error(), "tool nonexistent not found") }, @@ -119,11 +119,11 @@ func TestServer_HandleToolCall(t *testing.T) { { name: "cat nonexistent file", toolName: "cat", - args: map[string]interface{}{}, + args: map[string]any{}, wantError: true, - validate: func(t *testing.T, result interface{}, err error) { + validate: func(t *testing.T, result any, err error) { assert.Error(t, err) - resultMap, ok := result.(map[string]interface{}) + resultMap, ok := result.(map[string]any) require.True(t, ok) stderr := resultMap["stderr"].(string) assert.Contains(t, stderr, "No such file") @@ -132,10 +132,10 @@ func TestServer_HandleToolCall(t *testing.T) { { name: "environment variable test", toolName: "env_test", - args: map[string]interface{}{}, + args: map[string]any{}, wantError: false, - validate: func(t *testing.T, result interface{}, err error) { - resultMap, ok := result.(map[string]interface{}) + validate: func(t *testing.T, result any, err error) { + resultMap, ok := result.(map[string]any) require.True(t, ok) output := resultMap["output"].(string) assert.Contains(t, output, "TEST_VAR=test-value") @@ -175,12 +175,12 @@ func TestServer_HandleToolCall_JSON(t *testing.T) { server := NewServer("test", Config{}, resolvedTools) ctx := context.Background() - result, err := server.HandleToolCall(ctx, "json_output", map[string]interface{}{}) + result, err := server.HandleToolCall(ctx, "json_output", map[string]any{}) require.NoError(t, err) // Should parse as JSON - resultMap, ok := result.(map[string]interface{}) + resultMap, ok := result.(map[string]any) require.True(t, ok) assert.Equal(t, "ok", resultMap["status"]) assert.Equal(t, float64(42), resultMap["value"]) @@ -205,13 +205,13 @@ func TestServer_HandleToolCall_Timeout(t *testing.T) { server := NewServer("test", Config{}, resolvedTools) ctx := context.Background() - result, err := server.HandleToolCall(ctx, "slow_command", map[string]interface{}{}) + result, err := server.HandleToolCall(ctx, "slow_command", map[string]any{}) assert.Error(t, err) assert.Contains(t, err.Error(), "command failed") // Check that we got a timeout-related error in stderr or error message - if resultMap, ok := result.(map[string]interface{}); ok { + if resultMap, ok := result.(map[string]any); ok { stderr, _ := resultMap["stderr"].(string) errorMsg, _ := resultMap["error"].(string) // The actual error message varies by OS, but it should indicate termination diff --git a/internal/log/log.go b/internal/log/log.go index a192d46..045db75 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -114,7 +114,7 @@ func SetLogLevel(level string) error { currentLevel.Store(newLevel) updateHandler() - LogInfoWithFields("logging", "Log level changed", map[string]interface{}{ + LogInfoWithFields("logging", "Log level changed", map[string]any{ "new_level": level, }) @@ -142,30 +142,30 @@ func GetLogLevel() string { } // Convenience functions using standard slog with component context -func Logf(format string, args ...interface{}) { +func Logf(format string, args ...any) { logger.Info(fmt.Sprintf(format, args...)) } -func LogError(format string, args ...interface{}) { +func LogError(format string, args ...any) { logger.Error(fmt.Sprintf(format, args...)) } -func LogWarn(format string, args ...interface{}) { +func LogWarn(format string, args ...any) { logger.Warn(fmt.Sprintf(format, args...)) } -func LogDebug(format string, args ...interface{}) { +func LogDebug(format string, args ...any) { logger.Debug(fmt.Sprintf(format, args...)) } -func LogTrace(format string, args ...interface{}) { +func LogTrace(format string, args ...any) { if currentLevel.Load().(slog.Level) <= LevelTrace { logger.Log(context.Background(), LevelTrace, fmt.Sprintf(format, args...)) } } // Structured logging functions with component and fields -func LogInfoWithFields(component, message string, fields map[string]interface{}) { +func LogInfoWithFields(component, message string, fields map[string]any) { args := make([]any, 0, len(fields)*2+2) args = append(args, "component", component) for k, v := range fields { @@ -174,7 +174,7 @@ func LogInfoWithFields(component, message string, fields map[string]interface{}) logger.Info(message, args...) } -func LogDebugWithFields(component, message string, fields map[string]interface{}) { +func LogDebugWithFields(component, message string, fields map[string]any) { args := make([]any, 0, len(fields)*2+2) args = append(args, "component", component) for k, v := range fields { @@ -183,7 +183,7 @@ func LogDebugWithFields(component, message string, fields map[string]interface{} logger.Debug(message, args...) } -func LogErrorWithFields(component, message string, fields map[string]interface{}) { +func LogErrorWithFields(component, message string, fields map[string]any) { args := make([]any, 0, len(fields)*2+2) args = append(args, "component", component) for k, v := range fields { @@ -192,7 +192,7 @@ func LogErrorWithFields(component, message string, fields map[string]interface{} logger.Error(message, args...) } -func LogWarnWithFields(component, message string, fields map[string]interface{}) { +func LogWarnWithFields(component, message string, fields map[string]any) { args := make([]any, 0, len(fields)*2+2) args = append(args, "component", component) for k, v := range fields { @@ -201,7 +201,7 @@ func LogWarnWithFields(component, message string, fields map[string]interface{}) logger.Warn(message, args...) } -func LogTraceWithFields(component, message string, fields map[string]interface{}) { +func LogTraceWithFields(component, message string, fields map[string]any) { if currentLevel.Load().(slog.Level) <= LevelTrace { args := make([]any, 0, len(fields)*2+2) args = append(args, "component", component) diff --git a/internal/mcpfront.go b/internal/mcpfront.go new file mode 100644 index 0000000..8802e41 --- /dev/null +++ b/internal/mcpfront.go @@ -0,0 +1,579 @@ +package internal + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + "os/signal" + "syscall" + "time" + + "github.com/dgellow/mcp-front/internal/auth" + "github.com/dgellow/mcp-front/internal/client" + "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/crypto" + "github.com/dgellow/mcp-front/internal/inline" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/oauth" + "github.com/dgellow/mcp-front/internal/server" + "github.com/dgellow/mcp-front/internal/storage" + "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/ory/fosite" +) + +// MCPFront represents the complete MCP proxy application +type MCPFront struct { + config config.Config + httpServer *server.HTTPServer + sessionManager *client.StdioSessionManager + storage storage.Storage +} + +// NewMCPFront creates a new MCP proxy application with all dependencies built +func NewMCPFront(ctx context.Context, cfg config.Config) (*MCPFront, error) { + log.LogInfoWithFields("mcpfront", "Building MCP proxy application", map[string]any{ + "baseURL": cfg.Proxy.BaseURL, + "mcpServers": len(cfg.MCPServers), + }) + + // Parse base URL + baseURL, err := url.Parse(cfg.Proxy.BaseURL) + if err != nil { + return nil, fmt.Errorf("invalid base URL: %w", err) + } + + // Setup storage (always available, independent of OAuth) + store, err := setupStorage(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("failed to setup storage: %w", err) + } + + // Setup authentication (OAuth components and service client) + oauthProvider, sessionEncryptor, authConfig, serviceOAuthClient, err := setupAuthentication(ctx, cfg, store) + if err != nil { + return nil, fmt.Errorf("failed to setup authentication: %w", err) + } + + // Create session manager for stdio servers with configurable timeouts + sessionTimeout := 5 * time.Minute + cleanupInterval := 1 * time.Minute + maxPerUser := 10 + + // Use config values if available + if cfg.Proxy.Sessions != nil { + if cfg.Proxy.Sessions.Timeout > 0 { + sessionTimeout = cfg.Proxy.Sessions.Timeout + log.LogInfoWithFields("mcpfront", "Using configured session timeout", map[string]any{ + "timeout": sessionTimeout, + }) + } + if cfg.Proxy.Sessions.CleanupInterval > 0 { + cleanupInterval = cfg.Proxy.Sessions.CleanupInterval + log.LogInfoWithFields("mcpfront", "Using configured cleanup interval", map[string]any{ + "interval": cleanupInterval, + }) + } + maxPerUser = cfg.Proxy.Sessions.MaxPerUser + } + + sessionManager := client.NewStdioSessionManager( + client.WithTimeout(sessionTimeout), + client.WithMaxPerUser(maxPerUser), + client.WithCleanupInterval(cleanupInterval), + ) + + // Create user token service + userTokenService := server.NewUserTokenService(store, serviceOAuthClient) + + info := mcp.Implementation{ + Name: cfg.Proxy.Name, + Version: "dev", + } + + // Build complete HTTP handler with all routing and dependencies + mux, err := buildHTTPHandler( + cfg, + store, + oauthProvider, + sessionEncryptor, + authConfig, + serviceOAuthClient, + sessionManager, + userTokenService, + baseURL.String(), + info, + ) + if err != nil { + return nil, fmt.Errorf("failed to build HTTP handler: %w", err) + } + + // Create clean HTTP server with just the handler and address + httpServer := server.NewHTTPServer(mux, cfg.Proxy.Addr) + + return &MCPFront{ + config: cfg, + httpServer: httpServer, + sessionManager: sessionManager, + storage: store, + }, nil +} + +// Run starts and manages the complete MCP proxy application lifecycle +func (m *MCPFront) Run() error { + log.LogInfoWithFields("mcpfront", "Starting MCP proxy application", map[string]any{ + "addr": m.config.Proxy.Addr, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Channel to signal errors that should trigger shutdown + errChan := make(chan error, 1) + + // Start HTTP server + go func() { + if err := m.httpServer.Start(); err != nil { + errChan <- fmt.Errorf("HTTP server error: %w", err) + } + }() + + // Start session manager cleanup (if needed) + // The session manager already starts its cleanup goroutine internally, + // but this is where we could start other background services + + // Handle graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + var shutdownReason string + select { + case sig := <-sigChan: + shutdownReason = fmt.Sprintf("signal %v", sig) + log.LogInfoWithFields("mcpfront", "Received shutdown signal", map[string]any{ + "signal": sig.String(), + }) + case err := <-errChan: + shutdownReason = fmt.Sprintf("error: %v", err) + log.LogErrorWithFields("mcpfront", "Shutting down due to error", map[string]any{ + "error": err.Error(), + }) + case <-ctx.Done(): + shutdownReason = "context cancelled" + log.LogInfoWithFields("mcpfront", "Context cancelled, shutting down", nil) + } + + // Graceful shutdown + log.LogInfoWithFields("mcpfront", "Starting graceful shutdown", map[string]any{ + "reason": shutdownReason, + "timeout": "30s", + }) + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() + + // Stop HTTP server + if err := m.httpServer.Stop(shutdownCtx); err != nil { + log.LogErrorWithFields("mcpfront", "HTTP server shutdown error", map[string]any{ + "error": err.Error(), + }) + return err + } + + // Shutdown session manager + if m.sessionManager != nil { + m.sessionManager.Shutdown() + } + + log.LogInfoWithFields("mcpfront", "Application shutdown complete", map[string]any{ + "reason": shutdownReason, + }) + return nil +} + +// setupStorage creates storage based on configuration, independent of OAuth +func setupStorage(ctx context.Context, cfg config.Config) (storage.Storage, error) { + // Check if OAuth config provides storage configuration + if oauthAuth := cfg.Proxy.Auth; oauthAuth != nil { + if oauthAuth.Storage == "firestore" { + log.LogInfoWithFields("storage", "Using Firestore storage", map[string]any{ + "project": oauthAuth.GCPProject, + "database": oauthAuth.FirestoreDatabase, + "collection": oauthAuth.FirestoreCollection, + }) + // Create encryptor for Firestore storage + encryptor, err := crypto.NewEncryptor([]byte(oauthAuth.EncryptionKey)) + if err != nil { + return nil, fmt.Errorf("failed to create encryptor: %w", err) + } + firestoreStorage, err := storage.NewFirestoreStorage( + ctx, + oauthAuth.GCPProject, + oauthAuth.FirestoreDatabase, + oauthAuth.FirestoreCollection, + encryptor, + ) + if err != nil { + return nil, fmt.Errorf("failed to create Firestore storage: %w", err) + } + return firestoreStorage, nil + } + } + + // Default to memory storage + log.LogInfoWithFields("storage", "Using in-memory storage", map[string]any{}) + return storage.NewMemoryStorage(), nil +} + +// setupAuthentication creates individual OAuth components using clean constructors +func setupAuthentication(ctx context.Context, cfg config.Config, store storage.Storage) (fosite.OAuth2Provider, crypto.Encryptor, config.OAuthAuthConfig, *auth.ServiceOAuthClient, error) { + oauthAuth := cfg.Proxy.Auth + if oauthAuth == nil { + // OAuth not configured + return nil, nil, config.OAuthAuthConfig{}, nil, nil + } + + log.LogDebug("initializing OAuth 2.1 components with clean constructors") + + // Generate or validate JWT secret using clean constructor + jwtSecret, err := oauth.GenerateJWTSecret(string(oauthAuth.JWTSecret)) + if err != nil { + return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to setup JWT secret: %w", err) + } + + // Create session encryptor using clean constructor + encryptionKey := []byte(oauthAuth.EncryptionKey) + sessionEncryptor, err := oauth.NewSessionEncryptor(encryptionKey) + if err != nil { + return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create session encryptor: %w", err) + } + + // Create OAuth provider using clean constructor + oauthProvider, err := oauth.NewOAuthProvider(*oauthAuth, store, jwtSecret) + if err != nil { + return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create OAuth provider: %w", err) + } + + // Create OAuth client for service authentication and token refresh + serviceOAuthClient := auth.NewServiceOAuthClient(store, cfg.Proxy.BaseURL) + + // Initialize admin users if admin is enabled + if cfg.Proxy.Admin != nil && cfg.Proxy.Admin.Enabled { + for _, adminEmail := range cfg.Proxy.Admin.AdminEmails { + // Upsert admin user + if err := store.UpsertUser(ctx, adminEmail); err != nil { + log.LogWarnWithFields("mcpfront", "Failed to initialize admin user", map[string]any{ + "email": adminEmail, + "error": err.Error(), + }) + continue + } + // Set as admin + if err := store.SetUserAdmin(ctx, adminEmail, true); err != nil { + log.LogWarnWithFields("mcpfront", "Failed to set user as admin", map[string]any{ + "email": adminEmail, + "error": err.Error(), + }) + } + } + } + + return oauthProvider, sessionEncryptor, *oauthAuth, serviceOAuthClient, nil +} + +// buildHTTPHandler creates the complete HTTP handler with all routing and middleware +func buildHTTPHandler( + cfg config.Config, + storage storage.Storage, + oauthProvider fosite.OAuth2Provider, + sessionEncryptor crypto.Encryptor, + authConfig config.OAuthAuthConfig, + serviceOAuthClient *auth.ServiceOAuthClient, + sessionManager *client.StdioSessionManager, + userTokenService *server.UserTokenService, + baseURL string, + info mcp.Implementation, +) (*http.ServeMux, error) { + // Create mux and register all routes with dependency injection + mux := http.NewServeMux() + + // Build common middleware + corsMiddleware := server.NewCORSMiddleware(authConfig.AllowedOrigins) + oauthLogger := server.NewLoggerMiddleware("oauth") + mcpLogger := server.NewLoggerMiddleware("mcp") + tokenLogger := server.NewLoggerMiddleware("tokens") + adminLogger := server.NewLoggerMiddleware("admin") + mcpRecover := server.NewRecoverMiddleware("mcp") + oauthRecover := server.NewRecoverMiddleware("oauth") + + // Register health endpoint + mux.Handle("/health", server.NewHealthHandler()) + + // Create browser state token for SSO middleware (used by both OAuth and admin routes) + var browserStateToken *crypto.TokenSigner + if authConfig.EncryptionKey != "" { + token := crypto.NewTokenSigner([]byte(authConfig.EncryptionKey), 10*time.Minute) + browserStateToken = &token + } + + // Register OAuth endpoints if OAuth is enabled + if oauthProvider != nil { + // Build OAuth middleware + oauthMiddleware := []server.MiddlewareFunc{ + corsMiddleware, + oauthLogger, + oauthRecover, + } + + // Create OAuth auth handlers with dependency injection + authHandlers := server.NewAuthHandlers( + oauthProvider, + authConfig, + storage, + sessionEncryptor, + cfg.MCPServers, + serviceOAuthClient, + ) + + // Register OAuth endpoints + mux.Handle("/.well-known/oauth-authorization-server", server.ChainMiddleware(http.HandlerFunc(authHandlers.WellKnownHandler), oauthMiddleware...)) + mux.Handle("/authorize", server.ChainMiddleware(http.HandlerFunc(authHandlers.AuthorizeHandler), oauthMiddleware...)) + mux.Handle("/oauth/callback", server.ChainMiddleware(http.HandlerFunc(authHandlers.GoogleCallbackHandler), oauthMiddleware...)) + mux.Handle("/token", server.ChainMiddleware(http.HandlerFunc(authHandlers.TokenHandler), oauthMiddleware...)) + mux.Handle("/register", server.ChainMiddleware(http.HandlerFunc(authHandlers.RegisterHandler), oauthMiddleware...)) + + // Register protected token endpoints + tokenMiddleware := []server.MiddlewareFunc{ + corsMiddleware, + tokenLogger, + server.NewBrowserSSOMiddleware(authConfig, sessionEncryptor, browserStateToken), + mcpRecover, + } + + // Create token handlers + tokenHandlers := server.NewTokenHandlers(storage, cfg.MCPServers, true, serviceOAuthClient) + + // Token management UI endpoints + mux.Handle("/my/tokens", server.ChainMiddleware(http.HandlerFunc(tokenHandlers.ListTokensHandler), tokenMiddleware...)) + mux.Handle("/my/tokens/set", server.ChainMiddleware(http.HandlerFunc(tokenHandlers.SetTokenHandler), tokenMiddleware...)) + mux.Handle("/my/tokens/delete", server.ChainMiddleware(http.HandlerFunc(tokenHandlers.DeleteTokenHandler), tokenMiddleware...)) + + // OAuth interstitial page and completion endpoint + mux.Handle("/oauth/services", server.ChainMiddleware(http.HandlerFunc(authHandlers.ServiceSelectionHandler), tokenMiddleware...)) + mux.Handle("/oauth/complete", server.ChainMiddleware(http.HandlerFunc(authHandlers.CompleteOAuthHandler), tokenMiddleware...)) + + // Register service OAuth endpoints + serviceAuthHandlers := server.NewServiceAuthHandlers(serviceOAuthClient, cfg.MCPServers, storage) + mux.HandleFunc("/oauth/callback/", serviceAuthHandlers.CallbackHandler) + mux.Handle("/oauth/connect", server.ChainMiddleware(http.HandlerFunc(serviceAuthHandlers.ConnectHandler), tokenMiddleware...)) + mux.Handle("/oauth/disconnect", server.ChainMiddleware(http.HandlerFunc(serviceAuthHandlers.DisconnectHandler), tokenMiddleware...)) + } + + // Setup MCP server endpoints + sseServers := make(map[string]*mcpserver.SSEServer) // Track SSE servers for stdio servers + + for serverName, serverConfig := range cfg.MCPServers { + log.LogInfoWithFields("server", "Registering MCP server", map[string]any{ + "name": serverName, + "transport_type": serverConfig.TransportType, + "requires_user_token": serverConfig.RequiresUserToken, + }) + + var handler http.Handler + var err error + var mcpServer *mcpserver.MCPServer + var sseServer *mcpserver.SSEServer + + // For inline servers, create a custom handler + if serverConfig.TransportType == config.MCPClientTypeInline { + handler, err = buildInlineHandler(serverName, serverConfig) + if err != nil { + return nil, fmt.Errorf("failed to create inline handler for %s: %w", serverName, err) + } + } else { + // For stdio/SSE servers + if isStdioServer(serverConfig) { + sseServer, mcpServer, err = buildStdioSSEServer(serverName, baseURL, sessionManager) + if err != nil { + return nil, fmt.Errorf("failed to create SSE server for %s: %w", serverName, err) + } + sseServers[serverName] = sseServer + } + + // Create MCP handler for stdio/SSE servers + handler = server.NewMCPHandler( + serverName, + serverConfig, + storage, + baseURL, + info, + sessionManager, + sseServers[serverName], // Pass the shared SSE server (nil for non-stdio) + mcpServer, // Pass the shared MCP server (nil for non-stdio) + userTokenService.GetUserToken, + ) + } + + // Setup middlewares for this MCP server + mcpMiddlewares := []server.MiddlewareFunc{ + mcpLogger, + corsMiddleware, + } + + // Add OAuth validation if OAuth is enabled + if oauthProvider != nil { + mcpMiddlewares = append(mcpMiddlewares, oauth.NewValidateTokenMiddleware(oauthProvider)) + } + + // Add service auth middleware if configured + if len(serverConfig.ServiceAuths) > 0 { + mcpMiddlewares = append(mcpMiddlewares, server.NewServiceAuthMiddleware(serverConfig.ServiceAuths)) + } + + // Recovery middleware should be last (outermost) + mcpMiddlewares = append(mcpMiddlewares, mcpRecover) + + // Register handler - SSE server needs to handle all paths under the server name + mux.Handle("/"+serverName+"/", server.ChainMiddleware(handler, mcpMiddlewares...)) + } + + // Setup admin routes if admin is enabled + if cfg.Proxy.Admin != nil && cfg.Proxy.Admin.Enabled { + log.LogInfoWithFields("server", "Admin UI enabled", map[string]any{ + "admin_emails": cfg.Proxy.Admin.AdminEmails, + }) + + // Get encryption key from OAuth config + var encryptionKey string + if oauthAuth := cfg.Proxy.Auth; oauthAuth != nil { + encryptionKey = string(oauthAuth.EncryptionKey) + } + + // Create admin handlers + adminHandlers := server.NewAdminHandlers(storage, cfg, sessionManager, encryptionKey) + + // Build admin middleware + adminMiddleware := []server.MiddlewareFunc{ + corsMiddleware, + adminLogger, + } + + // Add browser SSO if OAuth is enabled + if oauthProvider != nil { + // Reuse the same browserStateToken created earlier for consistency + adminMiddleware = append(adminMiddleware, server.NewBrowserSSOMiddleware(authConfig, sessionEncryptor, browserStateToken)) + } + + // Add admin check middleware + adminMiddleware = append(adminMiddleware, server.NewAdminMiddleware(cfg.Proxy.Admin, storage)) + + // Recovery middleware last + adminMiddleware = append(adminMiddleware, mcpRecover) + + // Register admin routes + mux.Handle("/admin", server.ChainMiddleware(http.HandlerFunc(adminHandlers.DashboardHandler), adminMiddleware...)) + mux.Handle("/admin/users", server.ChainMiddleware(http.HandlerFunc(adminHandlers.UserActionHandler), adminMiddleware...)) + mux.Handle("/admin/sessions", server.ChainMiddleware(http.HandlerFunc(adminHandlers.SessionActionHandler), adminMiddleware...)) + mux.Handle("/admin/logging", server.ChainMiddleware(http.HandlerFunc(adminHandlers.LoggingActionHandler), adminMiddleware...)) + } + + log.LogInfoWithFields("server", "MCP proxy server initialized", nil) + return mux, nil +} + +// buildInlineHandler creates an inline MCP handler +func buildInlineHandler(serverName string, serverConfig *config.MCPClientConfig) (http.Handler, error) { + // Resolve inline config + inlineConfig, resolvedTools, err := inline.ResolveConfig(serverConfig.InlineConfig) + if err != nil { + return nil, fmt.Errorf("failed to resolve inline config: %w", err) + } + + // Create inline server + inlineServer := inline.NewServer(serverName, inlineConfig, resolvedTools) + + // Create inline handler + handler := inline.NewHandler(serverName, inlineServer) + + log.LogInfoWithFields("server", "Created inline MCP server", map[string]any{ + "name": serverName, + "tools": len(resolvedTools), + }) + + return handler, nil +} + +// buildStdioSSEServer creates an SSE server for stdio MCP servers +func buildStdioSSEServer(serverName, baseURL string, sessionManager *client.StdioSessionManager) (*mcpserver.SSEServer, *mcpserver.MCPServer, error) { + // Create hooks for session management + hooks := &mcpserver.Hooks{} + + // Store reference to server name for use in hooks + currentServerName := serverName + + // Setup hooks that will be called when sessions are created/destroyed + hooks.AddOnRegisterSession(func(sessionCtx context.Context, session mcpserver.ClientSession) { + // Extract handler from context + if handler, ok := sessionCtx.Value(server.SessionHandlerKey{}).(*server.SessionRequestHandler); ok { + // Handle session registration (MCP server is already set in handler) + server.HandleSessionRegistration(sessionCtx, session, handler, sessionManager) + } else { + log.LogErrorWithFields("server", "No session handler in context", map[string]any{ + "sessionID": session.SessionID(), + "server": currentServerName, + }) + } + }) + + hooks.AddOnUnregisterSession(func(sessionCtx context.Context, session mcpserver.ClientSession) { + // Extract handler from context + if handler, ok := sessionCtx.Value(server.SessionHandlerKey{}).(*server.SessionRequestHandler); ok { + // Handle session cleanup + key := client.SessionKey{ + UserEmail: handler.GetUserEmail(), + ServerName: handler.GetServerName(), + SessionID: session.SessionID(), + } + sessionManager.RemoveSession(key) + + if storage := handler.GetStorage(); storage != nil { + if err := storage.RevokeSession(sessionCtx, session.SessionID()); err != nil { + log.LogWarnWithFields("server", "Failed to revoke session from storage", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "user": handler.GetUserEmail(), + }) + } + } + + log.LogInfoWithFields("server", "Session unregistered and cleaned up", map[string]any{ + "sessionID": session.SessionID(), + "server": currentServerName, + "user": handler.GetUserEmail(), + }) + } + }) + + // Now create the MCP server with the hooks + mcpServer := mcpserver.NewMCPServer(serverName, "1.0.0", + mcpserver.WithHooks(hooks), + mcpserver.WithPromptCapabilities(true), + mcpserver.WithResourceCapabilities(true, true), + mcpserver.WithToolCapabilities(true), + mcpserver.WithLogging(), + ) + + // Create the SSE server wrapper around the MCP server + sseServer := mcpserver.NewSSEServer(mcpServer, + mcpserver.WithStaticBasePath(serverName), + mcpserver.WithBaseURL(baseURL), + ) + + return sseServer, mcpServer, nil +} + +// isStdioServer checks if this is a stdio-based server +func isStdioServer(cfg *config.MCPClientConfig) bool { + return cfg.TransportType == config.MCPClientTypeStdio +} diff --git a/internal/oauth/auth.go b/internal/oauth/auth.go deleted file mode 100644 index 7cb699a..0000000 --- a/internal/oauth/auth.go +++ /dev/null @@ -1,143 +0,0 @@ -package oauth - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "os" - "strings" - - "github.com/dgellow/mcp-front/internal/log" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" -) - -// authService handles Google OAuth integration and user validation -type authService struct { - googleOAuth *oauth2.Config - allowedDomains []string -} - -// UserInfo represents Google user information -type UserInfo struct { - Email string `json:"email"` - HostedDomain string `json:"hd"` - Name string `json:"name"` - Picture string `json:"picture"` - VerifiedEmail bool `json:"verified_email"` -} - -// newAuthService creates a new auth service instance -func newAuthService(config Config) (*authService, error) { - // Use custom OAuth endpoints if provided (for testing) - endpoint := google.Endpoint - if authURL := os.Getenv("GOOGLE_OAUTH_AUTH_URL"); authURL != "" { - endpoint.AuthURL = authURL - } - if tokenURL := os.Getenv("GOOGLE_OAUTH_TOKEN_URL"); tokenURL != "" { - endpoint.TokenURL = tokenURL - } - - googleConfig := &oauth2.Config{ - ClientID: config.GoogleClientID, - ClientSecret: config.GoogleClientSecret, - RedirectURL: config.GoogleRedirectURI, - Scopes: []string{ - "openid", - "email", - }, - Endpoint: endpoint, - } - - log.Logf("Google OAuth config - ClientID: %s, RedirectURL: %s", config.GoogleClientID, config.GoogleRedirectURI) - - return &authService{ - googleOAuth: googleConfig, - allowedDomains: config.AllowedDomains, - }, nil -} - -// googleAuthURL returns the Google OAuth authorization URL -func (s *authService) googleAuthURL(state string) string { - return s.googleOAuth.AuthCodeURL(state, - oauth2.AccessTypeOffline, - oauth2.ApprovalForce, - ) -} - -// exchangeCodeForToken exchanges the authorization code for a token -func (s *authService) exchangeCodeForToken(ctx context.Context, code string) (*oauth2.Token, error) { - return s.googleOAuth.Exchange(ctx, code) -} - -// validateUser validates the Google OAuth token and checks domain membership -func (s *authService) validateUser(ctx context.Context, token *oauth2.Token) (*UserInfo, error) { - client := s.googleOAuth.Client(ctx, token) - userInfoURL := "https://www.googleapis.com/oauth2/v2/userinfo" - if customURL := os.Getenv("GOOGLE_USERINFO_URL"); customURL != "" { - userInfoURL = customURL - } - resp, err := client.Get(userInfoURL) - if err != nil { - return nil, fmt.Errorf("failed to get user info: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode) - } - - var userInfo UserInfo - if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { - return nil, fmt.Errorf("failed to decode user info: %w", err) - } - - // Validate domain if configured - if len(s.allowedDomains) > 0 { - if userInfo.HostedDomain == "" { - return nil, fmt.Errorf("user %s does not belong to a hosted domain", userInfo.Email) - } - - domainAllowed := false - for _, domain := range s.allowedDomains { - if userInfo.HostedDomain == domain { - domainAllowed = true - break - } - } - - if !domainAllowed { - return nil, fmt.Errorf("user %s domain %s is not in allowed domains", userInfo.Email, userInfo.HostedDomain) - } - } - - return &userInfo, nil -} - -// parseClientRequest parses and validates a client registration request -func (s *authService) parseClientRequest(metadata map[string]interface{}) ([]string, []string, error) { - // Extract redirect URIs - redirectURIs := []string{} - if uris, ok := metadata["redirect_uris"].([]interface{}); ok { - for _, uri := range uris { - if uriStr, ok := uri.(string); ok { - redirectURIs = append(redirectURIs, uriStr) - } - } - } - - if len(redirectURIs) == 0 { - return nil, nil, fmt.Errorf("redirect_uris is required") - } - - // Extract scopes - scopes := []string{"read", "write"} // Default MCP scopes - if clientScopes, ok := metadata["scope"].(string); ok { - if strings.TrimSpace(clientScopes) != "" { - scopes = strings.Fields(clientScopes) - } - } - - return redirectURIs, scopes, nil -} diff --git a/internal/oauth/browser.go b/internal/oauth/browser.go deleted file mode 100644 index b236546..0000000 --- a/internal/oauth/browser.go +++ /dev/null @@ -1,110 +0,0 @@ -package oauth - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "time" - - "github.com/dgellow/mcp-front/internal/cookie" - "github.com/dgellow/mcp-front/internal/crypto" - jsonwriter "github.com/dgellow/mcp-front/internal/json" - "github.com/dgellow/mcp-front/internal/log" -) - -// SessionData represents the data stored in the encrypted session cookie -type SessionData struct { - Email string `json:"email"` - Expires time.Time `json:"expires"` -} - -// SSOMiddleware creates middleware for browser-based SSO authentication -func (s *Server) SSOMiddleware() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check for session cookie - sessionValue, err := cookie.GetSession(r) - if err != nil { - // No cookie, redirect directly to Google OAuth - state := s.generateBrowserState(r.URL.String()) - googleURL := s.authService.googleAuthURL(state) - http.Redirect(w, r, googleURL, http.StatusFound) - return - } - - // Decrypt cookie - decrypted, err := s.sessionEncryptor.Decrypt(sessionValue) - if err != nil { - // Invalid cookie, redirect to OAuth - log.LogDebug("Invalid session cookie: %v", err) - cookie.ClearSession(w) // Clear bad cookie - state := s.generateBrowserState(r.URL.String()) - googleURL := s.authService.googleAuthURL(state) - http.Redirect(w, r, googleURL, http.StatusFound) - return - } - - // Parse session data - var sessionData SessionData - if err := json.Unmarshal([]byte(decrypted), &sessionData); err != nil { - // Invalid format - cookie.ClearSession(w) - jsonwriter.WriteUnauthorized(w, "Invalid session") - return - } - - // Check expiration - if time.Now().After(sessionData.Expires) { - // Expired session - log.LogDebug("Session expired for user %s", sessionData.Email) - cookie.ClearSession(w) - // Redirect directly to Google OAuth - state := s.generateBrowserState(r.URL.String()) - googleURL := s.authService.googleAuthURL(state) - http.Redirect(w, r, googleURL, http.StatusFound) - return - } - - // Valid session, set user in context - ctx := context.WithValue(r.Context(), userContextKey, sessionData.Email) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - -// generateBrowserState creates a secure state parameter for browser SSO -func (s *Server) generateBrowserState(returnURL string) string { - // Generate random nonce - nonce := crypto.GenerateSecureToken() - - // Create signed CSRF token: nonce + HMAC(nonce + returnURL) - // This ensures the token is tied to the specific return URL - data := nonce + ":" + returnURL - signature := crypto.SignData(data, []byte(s.config.EncryptionKey)) - - // Format: "browser:nonce:signature:returnURL" - return fmt.Sprintf("browser:%s:%s:%s", nonce, signature, returnURL) -} - -// setBrowserSessionCookie sets an encrypted session cookie for browser-based authentication -func (s *Server) setBrowserSessionCookie(w http.ResponseWriter, userEmail string) error { - sessionData := SessionData{ - Email: userEmail, - Expires: time.Now().Add(s.config.SessionDuration), - } - - jsonData, err := json.Marshal(sessionData) - if err != nil { - return fmt.Errorf("failed to marshal session data: %w", err) - } - - encrypted, err := s.sessionEncryptor.Encrypt(string(jsonData)) - if err != nil { - return fmt.Errorf("failed to encrypt session: %w", err) - } - - cookie.SetSession(w, encrypted, 24*time.Hour) - - return nil -} diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go deleted file mode 100644 index 5acd728..0000000 --- a/internal/oauth/oauth.go +++ /dev/null @@ -1,578 +0,0 @@ -package oauth - -import ( - "context" - "crypto/rand" - "encoding/json" - "fmt" - "net/http" - "os" - "strings" - "time" - - "github.com/dgellow/mcp-front/internal" - "github.com/dgellow/mcp-front/internal/crypto" - jsonwriter "github.com/dgellow/mcp-front/internal/json" - "github.com/dgellow/mcp-front/internal/log" - "github.com/dgellow/mcp-front/internal/storage" - "github.com/ory/fosite" - "github.com/ory/fosite/compose" -) - -// contextKey is a type for context keys to avoid collisions -type contextKey string - -// userContextKey is the context key for user email -const userContextKey contextKey = "user_email" - -// GetUserFromContext extracts user email from context -func GetUserFromContext(ctx context.Context) (string, bool) { - email, ok := ctx.Value(userContextKey).(string) - return email, ok -} - -// GetUserContextKey returns the context key for user email (for testing) -func GetUserContextKey() contextKey { - return userContextKey -} - -// Server wraps fosite.OAuth2Provider with clean architecture -type Server struct { - provider fosite.OAuth2Provider - storage storage.Storage - authService *authService - config Config - sessionEncryptor crypto.Encryptor // Created once for browser SSO performance -} - -// Config holds OAuth server configuration -type Config struct { - Issuer string - TokenTTL time.Duration - SessionDuration time.Duration // Duration for browser session cookies (default: 24h) - AllowedDomains []string - AllowedOrigins []string // For CORS validation - GoogleClientID string - GoogleClientSecret string - GoogleRedirectURI string - JWTSecret string // Should be provided via environment variable - EncryptionKey string // Should be provided via environment variable - StorageType string // "memory" or "firestore" - GCPProjectID string // Required for firestore storage - FirestoreDatabase string // Optional: Firestore database name (default: "(default)") - FirestoreCollection string // Optional: Collection name for Firestore storage (default: "mcp_front_oauth_clients") -} - -// NewServer creates a new OAuth 2.1 server -func NewServer(config Config, store storage.Storage) (*Server, error) { - // Create session encryptor for browser SSO - key := []byte(config.EncryptionKey) - sessionEncryptor, err := crypto.NewEncryptor(key) - if err != nil { - return nil, fmt.Errorf("failed to create session encryptor: %w", err) - } - log.Logf("Session encryptor initialized for browser SSO") - - // Create auth service (business logic) - authService, err := newAuthService(config) - if err != nil { - return nil, fmt.Errorf("failed to create auth service: %w", err) - } - - // Use provided JWT secret or generate a secure one - var secret []byte - if config.JWTSecret != "" { - secret = []byte(config.JWTSecret) - // Validate JWT secret length for HMAC-SHA512/256 - if len(secret) < 32 { - return nil, fmt.Errorf("JWT secret must be at least 32 bytes long for security, got %d bytes", len(secret)) - } - } else { - secret = make([]byte, 32) - if _, err := rand.Read(secret); err != nil { - return nil, fmt.Errorf("failed to generate JWT secret: %w", err) - } - log.LogWarn("Generated random JWT secret. Set JWT_SECRET env var for persistent tokens across restarts") - } - - // Determine min parameter entropy based on environment - minEntropy := 8 // Production default - enforce secure state parameters (8+ chars) - log.Logf("OAuth server initialization - MCP_FRONT_ENV=%s, isDevelopmentMode=%v", os.Getenv("MCP_FRONT_ENV"), internal.IsDevelopmentMode()) - if internal.IsDevelopmentMode() { - minEntropy = 0 // Development mode - allow weak state parameters for buggy clients - log.LogWarn("MCP_FRONT_ENV=development - weak OAuth state parameters allowed for testing") - } - log.Logf("OAuth MinParameterEntropy set to: %d", minEntropy) - - // Create fosite configuration - fositeConfig := &compose.Config{ - AccessTokenLifespan: config.TokenTTL, - RefreshTokenLifespan: 24 * time.Hour, - AuthorizeCodeLifespan: 10 * time.Minute, - TokenURL: config.Issuer + "/token", - ScopeStrategy: fosite.HierarchicScopeStrategy, - AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, - EnforcePKCEForPublicClients: true, - EnablePKCEPlainChallengeMethod: false, - MinParameterEntropy: minEntropy, - } - - // Create OAuth 2.1 provider - provider := compose.Compose( - fositeConfig, - store.GetMemoryStore(), - &compose.CommonStrategy{ - CoreStrategy: compose.NewOAuth2HMACStrategy(fositeConfig, secret, nil), - }, - nil, // hasher - compose.OAuth2AuthorizeExplicitFactory, - compose.OAuth2ClientCredentialsGrantFactory, - compose.OAuth2PKCEFactory, - compose.OAuth2RefreshTokenGrantFactory, - compose.OAuth2TokenIntrospectionFactory, - ) - - // Set default session duration if not configured - if config.SessionDuration == 0 { - config.SessionDuration = 24 * time.Hour - } - - return &Server{ - provider: provider, - storage: store, - authService: authService, - config: config, - sessionEncryptor: sessionEncryptor, - }, nil -} - -// WellKnownHandler serves OAuth 2.0 Authorization Server Metadata (RFC 8414) -func (s *Server) WellKnownHandler(w http.ResponseWriter, r *http.Request) { - metadata := map[string]interface{}{ - "issuer": s.config.Issuer, - "authorization_endpoint": s.config.Issuer + "/authorize", - "token_endpoint": s.config.Issuer + "/token", - "registration_endpoint": s.config.Issuer + "/register", - "scopes_supported": []string{ - "read", - "write", - }, - "response_types_supported": []string{ - "code", - }, - "grant_types_supported": []string{ - "authorization_code", - "refresh_token", - }, - "code_challenge_methods_supported": []string{ - "S256", - }, - "token_endpoint_auth_methods_supported": []string{ - "none", - "client_secret_post", - }, - "revocation_endpoint": s.config.Issuer + "/revoke", - "introspection_endpoint": s.config.Issuer + "/introspect", - } - - if err := jsonwriter.Write(w, metadata); err != nil { - log.LogErrorWithFields("oauth", "Failed to encode metadata response", map[string]interface{}{ - "error": err.Error(), - }) - } -} - -// AuthorizeHandler handles OAuth 2.0 authorization requests -func (s *Server) AuthorizeHandler(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Debug log the incoming request - log.Logf("Authorization request: %s", r.URL.RawQuery) - clientID := r.URL.Query().Get("client_id") - scopes := r.URL.Query().Get("scope") - redirectURI := r.URL.Query().Get("redirect_uri") - stateParam := r.URL.Query().Get("state") - log.Logf("Client ID: %s, Requested scopes: %s", clientID, scopes) - log.Logf("Requested redirect_uri: %s", redirectURI) - log.Logf("State parameter: '%s' (length: %d)", stateParam, len(stateParam)) - - // In development mode, generate a secure state parameter if missing - // This works around bugs in OAuth clients like MCP Inspector - if internal.IsDevelopmentMode() && len(stateParam) == 0 { - generatedState := crypto.GenerateSecureToken() - log.LogWarn("Development mode: generating state parameter '%s' for buggy client", generatedState) - q := r.URL.Query() - q.Set("state", generatedState) - r.URL.RawQuery = q.Encode() - // Also update the form values - if r.Form == nil { - _ = r.ParseForm() - } - r.Form.Set("state", generatedState) - } - - // Debug: Check what redirect URIs the client actually has - if client, err := s.storage.GetClient(ctx, clientID); err == nil { - log.Logf("Client registered redirect URIs: %v", client.GetRedirectURIs()) - } else { - log.LogError("Client not found: %v", err) - } - - // WORKAROUND: Claude.ai generates a single client ID per OAuth provider domain - // and reuses it forever. If the client registration is lost (server restart, - // storage cleared, etc), Claude.ai has no mechanism to detect this and re-register. - // This auto-registers their client to prevent users from being permanently locked out. - // TODO: Remove once Claude.ai implements proper client registration retry logic - if clientID != "" && - (redirectURI == "https://claude.ai/api/mcp/auth_callback" || - strings.HasPrefix(redirectURI, "https://claude.ai/api/mcp/")) { - - if _, err := s.storage.GetClient(ctx, clientID); err != nil { - log.LogWarn("Auto-registering Claude.ai client %s", clientID) - - // Register Claude.ai's client with their parameters - redirectURIs := []string{redirectURI} - requestedScopes := strings.Fields(scopes) - if len(requestedScopes) == 0 { - requestedScopes = []string{"read", "write"} - } - - s.storage.CreateClient(clientID, redirectURIs, requestedScopes, s.config.Issuer) - } - } - - // Parse and validate the authorization request - ar, err := s.provider.NewAuthorizeRequest(ctx, r) - if err != nil { - log.LogError("Authorization request error: %v", err) - s.provider.WriteAuthorizeError(w, ar, err) - return - } - - // Generate state for Google OAuth flow - state := crypto.GenerateSecureToken() - s.storage.StoreAuthorizeRequest(state, ar) - - // Redirect to Google OAuth - googleURL := s.authService.googleAuthURL(state) - log.Logf("Redirecting to Google OAuth URL: %s", googleURL) - - http.Redirect(w, r, googleURL, http.StatusFound) -} - -// GoogleCallbackHandler handles the callback from Google OAuth -func (s *Server) GoogleCallbackHandler(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - state := r.URL.Query().Get("state") - code := r.URL.Query().Get("code") - - if errMsg := r.URL.Query().Get("error"); errMsg != "" { - errDesc := r.URL.Query().Get("error_description") - log.LogError("Google OAuth error: %s - %s", errMsg, errDesc) - jsonwriter.WriteBadRequest(w, fmt.Sprintf("Authentication failed: %s", errMsg)) - return - } - - if state == "" || code == "" { - log.LogError("Missing state or code in callback") - jsonwriter.WriteBadRequest(w, "Invalid callback parameters") - return - } - - // Check if this is a browser SSO flow - var ar fosite.AuthorizeRequester - var isBrowserFlow bool - var returnURL string - - if strings.HasPrefix(state, "browser:") { - // Browser SSO flow - validate signature and extract return URL - isBrowserFlow = true - // Format: "browser:nonce:signature:returnURL" - parts := strings.SplitN(state, ":", 4) - if len(parts) != 4 { - log.LogError("Invalid browser state format: %s", state) - jsonwriter.WriteBadRequest(w, "Invalid state parameter") - return - } - // parts[0] = "browser", parts[1] = nonce, parts[2] = signature, parts[3] = return URL - nonce := parts[1] - signature := parts[2] - returnURL = parts[3] - - // Validate HMAC signature - data := nonce + ":" + returnURL - if !crypto.ValidateSignedData(data, signature, []byte(s.config.EncryptionKey)) { - log.LogError("Invalid CSRF signature in browser flow") - jsonwriter.WriteBadRequest(w, "Invalid state parameter") - return - } - } else { - // OAuth client flow - retrieve stored authorize request - var found bool - ar, found = s.storage.GetAuthorizeRequest(state) - if !found { - log.LogError("Invalid or expired state: %s", state) - jsonwriter.WriteBadRequest(w, "Invalid or expired authorization request") - return - } - } - - // Exchange code for token with timeout - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - token, err := s.authService.exchangeCodeForToken(ctx, code) - if err != nil { - log.LogError("Google token exchange error: %v", err) - jsonwriter.WriteInternalServerError(w, "Failed to exchange authorization code") - return - } - - // Validate user and get user info - userInfo, err := s.authService.validateUser(ctx, token) - if err != nil { - log.LogError("User validation error: %v", err) - jsonwriter.WriteForbidden(w, "Access denied: user validation failed") - return - } - log.Logf("User validated successfully: %s", userInfo.Email) - - // Handle browser SSO flow - if isBrowserFlow { - // Set session cookie - if err := s.setBrowserSessionCookie(w, userInfo.Email); err != nil { - log.LogError("Failed to set browser session cookie: %v", err) - jsonwriter.WriteInternalServerError(w, "Failed to create session") - return - } - - // Redirect to the original URL - log.Logf("Browser SSO successful for %s, redirecting to %s", userInfo.Email, returnURL) - http.Redirect(w, r, returnURL, http.StatusFound) - return - } - - // Handle OAuth client flow - // Create session with user info - session := NewSession(userInfo) - log.Logf("Session created for user: %s", userInfo.Email) - - log.Logf("Creating authorize response for client: %s", ar.GetClient().GetID()) - response, err := s.provider.NewAuthorizeResponse(ctx, ar, session) - if err != nil { - log.LogError("Failed to create authorize response: %v (type: %T)", err, err) - // Log more details about the error - if fositeErr, ok := err.(*fosite.RFC6749Error); ok { - log.LogError("Fosite error details - Code: %s, Description: %s, Debug: %s", - fositeErr.ErrorField, fositeErr.DescriptionField, fositeErr.DebugField) - } - s.provider.WriteAuthorizeError(w, ar, err) - return - } - - // Continue with normal OAuth flow - s.provider.WriteAuthorizeResponse(w, ar, response) -} - -// TokenHandler handles OAuth 2.0 token requests -func (s *Server) TokenHandler(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Create session for the token exchange - // Note: We create our custom Session type here, and fosite will populate it - // with the session data from the authorization code during NewAccessRequest - session := &Session{DefaultSession: &fosite.DefaultSession{}} - - // Handle token request - this retrieves the session from the authorization code - accessRequest, err := s.provider.NewAccessRequest(ctx, r, session) - if err != nil { - log.LogError("Access request error: %v", err) - s.provider.WriteAccessError(w, accessRequest, err) - return - } - - // At this point, accessRequest.GetSession() contains the session data from - // the authorization phase (including our custom UserInfo). Fosite handles - // the session propagation internally when creating the access token. - - // Generate tokens - response, err := s.provider.NewAccessResponse(ctx, accessRequest) - if err != nil { - log.LogError("Access response error: %v", err) - s.provider.WriteAccessError(w, accessRequest, err) - return - } - - // Write token response - s.provider.WriteAccessResponse(w, accessRequest, response) -} - -// buildClientRegistrationResponse creates the registration response for a client -func (s *Server) buildClientRegistrationResponse(client *fosite.DefaultClient, tokenEndpointAuthMethod string, clientSecret string) map[string]interface{} { - response := map[string]interface{}{ - "client_id": client.GetID(), - "client_id_issued_at": time.Now().Unix(), - "redirect_uris": client.GetRedirectURIs(), - "grant_types": client.GetGrantTypes(), - "response_types": client.GetResponseTypes(), - "scope": strings.Join(client.GetScopes(), " "), // Space-separated string - "token_endpoint_auth_method": tokenEndpointAuthMethod, - } - - // Include client_secret only for confidential clients - if clientSecret != "" { - response["client_secret"] = clientSecret - } - - return response -} - -// RegisterHandler handles dynamic client registration (RFC 7591) -func (s *Server) RegisterHandler(w http.ResponseWriter, r *http.Request) { - log.Logf("Register handler called: %s %s", r.Method, r.URL.Path) - - if r.Method != http.MethodPost { - jsonwriter.WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed") - return - } - - // Parse client metadata - var metadata map[string]interface{} - if err := json.NewDecoder(r.Body).Decode(&metadata); err != nil { - jsonwriter.WriteBadRequest(w, "Invalid request body") - return - } - - // Parse client request - redirectURIs, scopes, err := s.authService.parseClientRequest(metadata) - if err != nil { - log.LogError("Client request parsing error: %v", err) - jsonwriter.WriteBadRequest(w, err.Error()) - return - } - - // Check if client requests client_secret_post authentication - tokenEndpointAuthMethod := "none" - var client *fosite.DefaultClient - var plaintextSecret string - clientID := crypto.GenerateSecureToken() - - if authMethod, ok := metadata["token_endpoint_auth_method"].(string); ok && authMethod == "client_secret_post" { - // Generate and hash secret for confidential client - generatedSecret, err := crypto.GenerateClientSecret() - if err != nil { - log.LogError("Failed to generate client secret: %v", err) - jsonwriter.WriteInternalServerError(w, "Internal server error") - return - } - - hashedSecret, err := crypto.HashClientSecret(generatedSecret) - if err != nil { - log.LogError("Failed to hash client secret: %v", err) - jsonwriter.WriteInternalServerError(w, "Internal server error") - return - } - - client = s.storage.CreateConfidentialClient(clientID, hashedSecret, redirectURIs, scopes, s.config.Issuer) - tokenEndpointAuthMethod = "client_secret_post" - plaintextSecret = generatedSecret // Save for response - } else { - // Create public client (default behavior) - client = s.storage.CreateClient(clientID, redirectURIs, scopes, s.config.Issuer) - } - - // Build response using helper function - response := s.buildClientRegistrationResponse(client, tokenEndpointAuthMethod, plaintextSecret) - - if err := jsonwriter.WriteResponse(w, http.StatusCreated, response); err != nil { - log.LogErrorWithFields("oauth", "Failed to encode register response", map[string]interface{}{ - "error": err.Error(), - }) - } -} - -// DebugClientsHandler shows all registered clients (for debugging) -func (s *Server) DebugClientsHandler(w http.ResponseWriter, r *http.Request) { - clients := make(map[string]interface{}) - - // Get all clients thread-safely - allClients := s.storage.GetAllClients() - for clientID, client := range allClients { - clients[clientID] = map[string]interface{}{ - "redirect_uris": client.GetRedirectURIs(), - "scopes": client.GetScopes(), - "grant_types": client.GetGrantTypes(), - "response_types": client.GetResponseTypes(), - } - } - - response := map[string]interface{}{ - "total_clients": len(clients), - "clients": clients, - } - - if err := jsonwriter.Write(w, response); err != nil { - log.LogErrorWithFields("oauth", "Failed to encode debug response", map[string]interface{}{ - "error": err.Error(), - }) - } -} - -// ValidateTokenMiddleware creates middleware that validates OAuth tokens -func (s *Server) ValidateTokenMiddleware() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Extract token from Authorization header - auth := r.Header.Get("Authorization") - if auth == "" { - jsonwriter.WriteUnauthorized(w, "Missing authorization header") - return - } - - parts := strings.Split(auth, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - jsonwriter.WriteUnauthorized(w, "Invalid authorization header format") - return - } - - token := parts[1] - - // Validate token and extract session - // IMPORTANT: Fosite's IntrospectToken behavior is non-intuitive: - // - The session parameter passed to IntrospectToken is NOT populated with data - // - This is documented fosite behavior, not a bug - // - The actual session data must be retrieved from the returned AccessRequester - // See: https://github.com/ory/fosite/issues/256 - session := &Session{DefaultSession: &fosite.DefaultSession{}} - _, accessRequest, err := s.provider.IntrospectToken(ctx, token, fosite.AccessToken, session) - if err != nil { - jsonwriter.WriteUnauthorized(w, "Invalid or expired token") - return - } - - // Get the actual session from the access request (not the input session parameter) - // This is the correct way to retrieve session data after token introspection - var userEmail string - if accessRequest != nil { - if reqSession, ok := accessRequest.GetSession().(*Session); ok { - if reqSession.UserInfo != nil && reqSession.UserInfo.Email != "" { - userEmail = reqSession.UserInfo.Email - } - } - } - - // Pass user info through context - if userEmail != "" { - ctx = context.WithValue(ctx, userContextKey, userEmail) - r = r.WithContext(ctx) - } - - next.ServeHTTP(w, r) - }) - } -} - -// logf is a simple logging helper diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go deleted file mode 100644 index cd968ec..0000000 --- a/internal/oauth/oauth_test.go +++ /dev/null @@ -1,532 +0,0 @@ -package oauth - -import ( - "bytes" - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/dgellow/mcp-front/internal/storage" -) - -func TestNewServer(t *testing.T) { - config := Config{ - Issuer: "https://test.example.com", - TokenTTL: time.Hour, - AllowedDomains: []string{"example.com"}, - GoogleClientID: "test-client-id", - GoogleClientSecret: "test-client-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-secret-32-bytes-long-for-testing", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - } - - store := storage.NewMemoryStorage() - server, err := NewServer(config, store) - if err != nil { - t.Fatalf("Failed to create OAuth server: %v", err) - } - - if server.provider == nil { - t.Error("OAuth provider not initialized") - } - - if server.storage == nil { - t.Error("Storage not initialized") - } - - if server.config.Issuer != config.Issuer { - t.Error("Config not properly stored") - } -} - -func TestNewServerWithoutJWTSecret(t *testing.T) { - config := Config{ - Issuer: "https://test.example.com", - TokenTTL: time.Hour, - AllowedDomains: []string{"example.com"}, - GoogleClientID: "test-client-id", - GoogleClientSecret: "test-client-secret", - GoogleRedirectURI: "https://test.example.com/callback", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - // JWTSecret is empty - should generate random secret - } - - store := storage.NewMemoryStorage() - server, err := NewServer(config, store) - if err != nil { - t.Fatalf("Failed to create OAuth server: %v", err) - } - - if server == nil { - t.Error("Server should be created even without JWT secret") - } -} - -func TestWellKnownHandler(t *testing.T) { - config := Config{ - Issuer: "https://test.example.com", - TokenTTL: time.Hour, - JWTSecret: "test-secret-32-bytes-long-for-testing", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - } - - store := storage.NewMemoryStorage() - server, err := NewServer(config, store) - if err != nil { - t.Fatalf("Failed to create server: %v", err) - } - - req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil) - w := httptest.NewRecorder() - - server.WellKnownHandler(w, req) - - if w.Code != http.StatusOK { - t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) - } - - if w.Header().Get("Content-Type") != "application/json" { - t.Errorf("Expected Content-Type application/json, got %s", w.Header().Get("Content-Type")) - } - - var metadata map[string]interface{} - if err := json.NewDecoder(w.Body).Decode(&metadata); err != nil { - t.Fatalf("Failed to decode response: %v", err) - } - - // Check required fields - requiredFields := []string{ - "issuer", - "authorization_endpoint", - "token_endpoint", - "registration_endpoint", - "scopes_supported", - "response_types_supported", - "grant_types_supported", - "code_challenge_methods_supported", - } - - for _, field := range requiredFields { - if _, ok := metadata[field]; !ok { - t.Errorf("Missing required field: %s", field) - } - } - - // Verify issuer - if issuer, ok := metadata["issuer"].(string); !ok || issuer != config.Issuer { - t.Errorf("Expected issuer %s, got %v", config.Issuer, metadata["issuer"]) - } -} - -func TestRegisterHandler(t *testing.T) { - config := Config{ - Issuer: "https://test.example.com", - TokenTTL: time.Hour, - JWTSecret: "test-secret-32-bytes-long-for-testing", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - } - - store := storage.NewMemoryStorage() - server, err := NewServer(config, store) - if err != nil { - t.Fatalf("Failed to create server: %v", err) - } - - // Create registration request - reqBody := map[string]interface{}{ - "redirect_uris": []string{"https://client.example.com/callback"}, - "scope": "read write", - } - body, _ := json.Marshal(reqBody) - - req := httptest.NewRequest("POST", "/register", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - server.RegisterHandler(w, req) - - if w.Code != http.StatusCreated { - t.Errorf("Expected status %d, got %d: %s", http.StatusCreated, w.Code, w.Body.String()) - } - - var response map[string]interface{} - if err := json.NewDecoder(w.Body).Decode(&response); err != nil { - t.Fatalf("Failed to decode response: %v", err) - } - - // Verify response - if clientID, ok := response["client_id"].(string); !ok || clientID == "" { - t.Error("Missing or empty client_id") - } - - // Verify scope is returned as a string, not array - if scope, ok := response["scope"].(string); !ok || scope != "read write" { - t.Errorf("Expected scope 'read write', got %v", response["scope"]) - } - - // Verify redirect_uris - if uris, ok := response["redirect_uris"].([]interface{}); !ok || len(uris) != 1 { - t.Errorf("Expected redirect_uris with 1 URI, got %v", response["redirect_uris"]) - } -} - -func TestRegisterHandlerInvalidMethod(t *testing.T) { - config := Config{ - Issuer: "https://test.example.com", - TokenTTL: time.Hour, - JWTSecret: "test-secret-32-bytes-long-for-testing", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - } - - store := storage.NewMemoryStorage() - server, err := NewServer(config, store) - if err != nil { - t.Fatalf("Failed to create server: %v", err) - } - - req := httptest.NewRequest("GET", "/register", nil) - w := httptest.NewRecorder() - - server.RegisterHandler(w, req) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("Expected status %d for GET request, got %d", http.StatusMethodNotAllowed, w.Code) - } -} - -func TestClientRegistrationAndDebugEndpoint(t *testing.T) { - config := Config{ - Issuer: "https://test.example.com", - TokenTTL: time.Hour, - JWTSecret: "test-secret-32-bytes-long-for-testing", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - } - - store := storage.NewMemoryStorage() - server, err := NewServer(config, store) - if err != nil { - t.Fatalf("Failed to create server: %v", err) - } - - // Register a client - reqBody := map[string]interface{}{ - "redirect_uris": []string{"http://localhost:3000/callback"}, - "scope": "read write", - } - body, _ := json.Marshal(reqBody) - - req := httptest.NewRequest("POST", "/register", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - server.RegisterHandler(w, req) - - if w.Code != http.StatusCreated { - t.Fatalf("Register failed with status %d: %s", w.Code, w.Body.String()) - } - - var registerResp map[string]interface{} - _ = json.NewDecoder(w.Body).Decode(®isterResp) - clientID := registerResp["client_id"].(string) - - // Check debug endpoint - req = httptest.NewRequest("GET", "/debug/clients", nil) - w = httptest.NewRecorder() - - server.DebugClientsHandler(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("Debug endpoint failed with status %d", w.Code) - } - - var debugResp map[string]interface{} - _ = json.NewDecoder(w.Body).Decode(&debugResp) - - if total, ok := debugResp["total_clients"].(float64); !ok || total != 1 { - t.Errorf("Expected 1 client, got %v", debugResp["total_clients"]) - } - - clients, ok := debugResp["clients"].(map[string]interface{}) - if !ok { - t.Fatal("clients field is not a map") - } - - if _, exists := clients[clientID]; !exists { - t.Errorf("Client %s not found in debug output", clientID) - } - - t.Logf("✅ Successfully registered client %s and verified in debug endpoint", clientID) -} - -func TestStorageArchitecture(t *testing.T) { - store := storage.NewMemoryStorage() - - // Test client creation - t.Run("client_creation", func(t *testing.T) { - clientID := "test-client-123" - redirectURIs := []string{"https://example.com/callback"} - scopes := []string{"read", "write"} - issuer := "https://test.example.com" - - client := store.CreateClient(clientID, redirectURIs, scopes, issuer) - - if client.GetID() != clientID { - t.Errorf("Expected client ID %s, got %s", clientID, client.GetID()) - } - if len(client.GetScopes()) != 2 { - t.Errorf("Expected 2 scopes, got %d", len(client.GetScopes())) - } - if client.GetRedirectURIs()[0] != redirectURIs[0] { - t.Errorf("Expected redirect URI %s, got %s", redirectURIs[0], client.GetRedirectURIs()[0]) - } - }) - - // Test client retrieval - t.Run("client_retrieval", func(t *testing.T) { - clientID := "test-client-456" - redirectURIs := []string{"https://example.com/callback"} - scopes := []string{"read"} - issuer := "https://test.example.com" - - // Create client - originalClient := store.CreateClient(clientID, redirectURIs, scopes, issuer) - - // Retrieve client - retrievedClient, err := store.GetClient(context.Background(), clientID) - if err != nil { - t.Fatalf("Failed to retrieve client: %v", err) - } - - if retrievedClient.GetID() != originalClient.GetID() { - t.Errorf("Retrieved client ID doesn't match original") - } - }) - - // Test thread-safe operations - t.Run("thread_safety", func(t *testing.T) { - // This test verifies the GetAllClients method works correctly - // and doesn't race with concurrent access - clients := store.GetAllClients() - if clients == nil { - t.Error("GetAllClients should return a map, not nil") - } - - // Should be able to call this multiple times safely - clients2 := store.GetAllClients() - if len(clients) != len(clients2) { - t.Error("GetAllClients should return consistent results") - } - }) -} - -func TestAuthServiceArchitecture(t *testing.T) { - config := Config{ - GoogleClientID: "test-client-id", - GoogleClientSecret: "test-client-secret", - GoogleRedirectURI: "https://test.example.com/callback", - AllowedDomains: []string{"example.com"}, - } - - authService, err := newAuthService(config) - if err != nil { - t.Fatalf("Failed to create auth service: %v", err) - } - - // Test client request parsing - t.Run("client_request_parsing", func(t *testing.T) { - metadata := map[string]interface{}{ - "redirect_uris": []interface{}{"https://example.com/callback", "https://example.com/callback2"}, - "scope": "read write execute", - } - - redirectURIs, scopes, err := authService.parseClientRequest(metadata) - if err != nil { - t.Fatalf("Failed to parse client request: %v", err) - } - - if len(redirectURIs) != 2 { - t.Errorf("Expected 2 redirect URIs, got %d", len(redirectURIs)) - } - - if len(scopes) != 3 { - t.Errorf("Expected 3 scopes, got %d", len(scopes)) - } - - if scopes[0] != "read" || scopes[1] != "write" || scopes[2] != "execute" { - t.Errorf("Scopes not parsed correctly: %v", scopes) - } - }) - - // Test missing redirect URIs - t.Run("missing_redirect_uris", func(t *testing.T) { - metadata := map[string]interface{}{ - "scope": "read write", - } - - _, _, err := authService.parseClientRequest(metadata) - if err == nil { - t.Error("Expected error for missing redirect_uris") - } - }) - - // Test default scopes - t.Run("default_scopes", func(t *testing.T) { - metadata := map[string]interface{}{ - "redirect_uris": []interface{}{"https://example.com/callback"}, - } - - _, scopes, err := authService.parseClientRequest(metadata) - if err != nil { - t.Fatalf("Failed to parse client request: %v", err) - } - - // Should get default MCP scopes - if len(scopes) != 2 || scopes[0] != "read" || scopes[1] != "write" { - t.Errorf("Expected default scopes [read, write], got %v", scopes) - } - }) -} - -// TestClaudeAIWorkaround tests the auto-registration workaround for Claude.ai -func TestClaudeAIWorkaround(t *testing.T) { - config := Config{ - Issuer: "https://test.example.com", - TokenTTL: time.Hour, - AllowedDomains: []string{"example.com"}, - GoogleClientID: "test-client-id", - GoogleClientSecret: "test-client-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-secret-32-bytes-long-for-testing", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - } - - store := storage.NewMemoryStorage() - server, err := NewServer(config, store) - if err != nil { - t.Fatalf("Failed to create OAuth server: %v", err) - } - - // Test cases for Claude.ai workaround - tests := []struct { - name string - clientID string - redirectURI string - shouldAutoRegister bool - }{ - { - name: "Claude.ai auth callback - should auto-register", - clientID: "claude-test-client-123", - redirectURI: "https://claude.ai/api/mcp/auth_callback", - shouldAutoRegister: true, - }, - { - name: "Claude.ai MCP endpoint - should auto-register", - clientID: "claude-test-client-456", - redirectURI: "https://claude.ai/api/mcp/something", - shouldAutoRegister: true, - }, - { - name: "Fake Claude domain - should NOT auto-register", - clientID: "fake-claude-client", - redirectURI: "https://myfakeclaude.ai/api/mcp/auth_callback", - shouldAutoRegister: false, - }, - { - name: "Non-Claude client - should NOT auto-register", - clientID: "other-client", - redirectURI: "https://example.com/callback", - shouldAutoRegister: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Ensure client doesn't exist - _, err := store.GetClient(context.Background(), tt.clientID) - if err == nil { - t.Fatalf("Client %s should not exist initially", tt.clientID) - } - - // Create request - req := httptest.NewRequest("GET", "/authorize", nil) - q := req.URL.Query() - q.Set("client_id", tt.clientID) - q.Set("redirect_uri", tt.redirectURI) - q.Set("response_type", "code") - q.Set("scope", "read write") - q.Set("state", "test-state-123") - q.Set("code_challenge", "test-challenge") - q.Set("code_challenge_method", "S256") - req.URL.RawQuery = q.Encode() - - w := httptest.NewRecorder() - - // Call the handler - server.AuthorizeHandler(w, req) - - // Check if client was auto-registered - _, err = store.GetClient(context.Background(), tt.clientID) - if tt.shouldAutoRegister { - if err != nil { - t.Errorf("Claude.ai client should have been auto-registered but wasn't") - } - // Verify the auto-registered client has correct redirect URI - client, _ := store.GetClient(context.Background(), tt.clientID) - if len(client.GetRedirectURIs()) != 1 || client.GetRedirectURIs()[0] != tt.redirectURI { - t.Errorf("Auto-registered client has wrong redirect URI: %v", client.GetRedirectURIs()) - } - } else { - if err == nil { - t.Errorf("Non-Claude.ai client should NOT have been auto-registered but was") - } - } - - // Clean up for next test - if tt.shouldAutoRegister { - // Remove the auto-registered client - delete(store.MemoryStore.Clients, tt.clientID) - } - }) - } - - // Test that existing Claude.ai clients are not re-created - t.Run("Existing Claude.ai client - should not recreate", func(t *testing.T) { - clientID := "existing-claude-client" - redirectURI := "https://claude.ai/api/mcp/auth_callback" - - // Pre-register client with specific scopes - originalScopes := []string{"read", "write", "admin"} - store.CreateClient(clientID, []string{redirectURI}, originalScopes, config.Issuer) - - // Create request - req := httptest.NewRequest("GET", "/authorize", nil) - q := req.URL.Query() - q.Set("client_id", clientID) - q.Set("redirect_uri", redirectURI) - q.Set("response_type", "code") - q.Set("scope", "read") // Different scope than registered - q.Set("state", "test-state-456") - q.Set("code_challenge", "test-challenge") - q.Set("code_challenge_method", "S256") - req.URL.RawQuery = q.Encode() - - w := httptest.NewRecorder() - - // Call the handler - server.AuthorizeHandler(w, req) - - // Verify client still has original scopes (not overwritten) - client, err := store.GetClient(context.Background(), clientID) - if err != nil { - t.Fatalf("Client should still exist: %v", err) - } - - if len(client.GetScopes()) != len(originalScopes) { - t.Errorf("Client scopes were modified. Expected %v, got %v", originalScopes, client.GetScopes()) - } - }) -} diff --git a/internal/oauth/provider.go b/internal/oauth/provider.go new file mode 100644 index 0000000..4291187 --- /dev/null +++ b/internal/oauth/provider.go @@ -0,0 +1,171 @@ +package oauth + +import ( + "context" + "crypto/rand" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/crypto" + "github.com/dgellow/mcp-front/internal/envutil" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/oauthsession" + "github.com/dgellow/mcp-front/internal/storage" + "github.com/ory/fosite" + "github.com/ory/fosite/compose" +) + +// userContextKey is the context key for user email +const userContextKey contextKey = "user_email" + +// GetUserFromContext extracts user email from context +func GetUserFromContext(ctx context.Context) (string, bool) { + email, ok := ctx.Value(userContextKey).(string) + return email, ok +} + +// GetUserContextKey returns the context key for user email (for testing) +func GetUserContextKey() contextKey { + return userContextKey +} + +// NewOAuthProvider creates a new OAuth 2.1 provider with clean dependency injection +func NewOAuthProvider(oauthConfig config.OAuthAuthConfig, store storage.Storage, jwtSecret []byte) (fosite.OAuth2Provider, error) { + // Use TTL duration from config + tokenTTL := oauthConfig.TokenTTL + if tokenTTL == 0 { + tokenTTL = time.Hour // Default 1 hour + } + // Validate JWT secret length for HMAC-SHA512/256 + if len(jwtSecret) < 32 { + return nil, fmt.Errorf("JWT secret must be at least 32 bytes long for security, got %d bytes", len(jwtSecret)) + } + + // Determine min parameter entropy based on environment + minEntropy := 8 // Production default - enforce secure state parameters (8+ chars) + log.Logf("OAuth provider initialization - MCP_FRONT_ENV=%s, isDevelopmentMode=%v", os.Getenv("MCP_FRONT_ENV"), envutil.IsDev()) + if envutil.IsDev() { + minEntropy = 0 // Development mode - allow empty state parameters + log.LogWarn("Development mode enabled - OAuth security checks relaxed (state parameter entropy: %d)", minEntropy) + } + + // Configure fosite + fositeConfig := &compose.Config{ + AccessTokenLifespan: tokenTTL, + RefreshTokenLifespan: tokenTTL * 2, + AuthorizeCodeLifespan: 10 * time.Minute, + TokenURL: oauthConfig.Issuer + "/token", + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + EnforcePKCEForPublicClients: true, + EnablePKCEPlainChallengeMethod: false, + MinParameterEntropy: minEntropy, + } + + // Create provider using compose with specific factories + provider := compose.Compose( + fositeConfig, + store, + &compose.CommonStrategy{ + CoreStrategy: compose.NewOAuth2HMACStrategy(fositeConfig, jwtSecret, nil), + }, + nil, // hasher + compose.OAuth2AuthorizeExplicitFactory, + compose.OAuth2ClientCredentialsGrantFactory, + compose.OAuth2PKCEFactory, + compose.OAuth2RefreshTokenGrantFactory, + compose.OAuth2TokenIntrospectionFactory, + ) + + return provider, nil +} + +// NewSessionEncryptor creates a new session encryptor for browser SSO +func NewSessionEncryptor(encryptionKey []byte) (crypto.Encryptor, error) { + sessionEncryptor, err := crypto.NewEncryptor(encryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to create session encryptor: %w", err) + } + log.Logf("Session encryptor initialized for browser SSO") + return sessionEncryptor, nil +} + +// GenerateJWTSecret generates a secure JWT secret if none is provided +func GenerateJWTSecret(providedSecret string) ([]byte, error) { + if providedSecret != "" { + secret := []byte(providedSecret) + // Validate JWT secret length for HMAC-SHA512/256 + if len(secret) < 32 { + return nil, fmt.Errorf("JWT secret must be at least 32 bytes long for security, got %d bytes", len(secret)) + } + return secret, nil + } + + // Generate a secure random secret + secret := make([]byte, 32) + if _, err := rand.Read(secret); err != nil { + return nil, fmt.Errorf("failed to generate JWT secret: %w", err) + } + log.LogWarn("Generated random JWT secret. Set JWT_SECRET env var for persistent tokens across restarts") + return secret, nil +} + +// NewValidateTokenMiddleware creates middleware that validates OAuth tokens using dependency injection +func NewValidateTokenMiddleware(provider fosite.OAuth2Provider) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Extract token from Authorization header + auth := r.Header.Get("Authorization") + if auth == "" { + http.Error(w, "Missing authorization header", http.StatusUnauthorized) + return + } + + parts := strings.Split(auth, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + http.Error(w, "Invalid authorization header format", http.StatusUnauthorized) + return + } + + token := parts[1] + + // Validate token and extract session + // IMPORTANT: Fosite's IntrospectToken behavior is non-intuitive: + // - The session parameter passed to IntrospectToken is NOT populated with data + // - This is documented fosite behavior, not a bug + // - The actual session data must be retrieved from the returned AccessRequester + // See: https://github.com/ory/fosite/issues/256 + session := &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}} + _, accessRequest, err := provider.IntrospectToken(ctx, token, fosite.AccessToken, session) + if err != nil { + http.Error(w, "Invalid or expired token", http.StatusUnauthorized) + return + } + + // Get the actual session from the access request (not the input session parameter) + // This is the correct way to retrieve session data after token introspection + var userEmail string + if accessRequest != nil { + if reqSession, ok := accessRequest.GetSession().(*oauthsession.Session); ok { + if reqSession.UserInfo.Email != "" { + userEmail = reqSession.UserInfo.Email + } + } + } + + // Pass user info through context + if userEmail != "" { + ctx = context.WithValue(ctx, userContextKey, userEmail) + r = r.WithContext(ctx) + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/oauth/types.go b/internal/oauth/types.go new file mode 100644 index 0000000..d9b9df3 --- /dev/null +++ b/internal/oauth/types.go @@ -0,0 +1,4 @@ +package oauth + +// contextKey is the type used for context keys to avoid collisions +type contextKey string diff --git a/internal/oauth/session.go b/internal/oauthsession/session.go similarity index 79% rename from internal/oauth/session.go rename to internal/oauthsession/session.go index f23b750..9caf225 100644 --- a/internal/oauth/session.go +++ b/internal/oauthsession/session.go @@ -1,27 +1,20 @@ -package oauth +package oauthsession import ( "time" + "github.com/dgellow/mcp-front/internal/googleauth" "github.com/ory/fosite" ) // Session extends DefaultSession with user information type Session struct { *fosite.DefaultSession - UserInfo *UserInfo `json:"user_info,omitempty"` -} - -// Clone implements fosite.Session -func (s *Session) Clone() fosite.Session { - return &Session{ - DefaultSession: s.DefaultSession.Clone().(*fosite.DefaultSession), - UserInfo: s.UserInfo, - } + UserInfo googleauth.UserInfo `json:"user_info,omitempty"` } // NewSession creates a new session with user info -func NewSession(userInfo *UserInfo) *Session { +func NewSession(userInfo googleauth.UserInfo) *Session { return &Session{ DefaultSession: &fosite.DefaultSession{ ExpiresAt: map[fosite.TokenType]time.Time{ @@ -34,3 +27,11 @@ func NewSession(userInfo *UserInfo) *Session { UserInfo: userInfo, } } + +// Clone implements fosite.Session +func (s *Session) Clone() fosite.Session { + return &Session{ + DefaultSession: s.DefaultSession.Clone().(*fosite.DefaultSession), + UserInfo: s.UserInfo, + } +} diff --git a/internal/server/admin_handlers.go b/internal/server/admin_handlers.go index 2986989..88ab8e3 100644 --- a/internal/server/admin_handlers.go +++ b/internal/server/admin_handlers.go @@ -8,7 +8,7 @@ import ( "strings" "time" - "github.com/dgellow/mcp-front/internal/auth" + "github.com/dgellow/mcp-front/internal/adminauth" "github.com/dgellow/mcp-front/internal/client" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" @@ -21,13 +21,13 @@ import ( // AdminHandlers handles the admin UI type AdminHandlers struct { storage storage.Storage - config *config.Config + config config.Config sessionManager *client.StdioSessionManager encryptionKey []byte // For HMAC-based CSRF tokens } // NewAdminHandlers creates a new admin handlers instance -func NewAdminHandlers(storage storage.Storage, config *config.Config, sessionManager *client.StdioSessionManager, encryptionKey string) *AdminHandlers { +func NewAdminHandlers(storage storage.Storage, config config.Config, sessionManager *client.StdioSessionManager, encryptionKey string) *AdminHandlers { return &AdminHandlers{ storage: storage, config: config, @@ -108,7 +108,7 @@ func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) } // Double-check admin status - if !auth.IsAdmin(r.Context(), userEmail, h.config.Proxy.Admin, h.storage) { + if !adminauth.IsAdmin(r.Context(), userEmail, h.config.Proxy.Admin, h.storage) { jsonwriter.WriteForbidden(w, "Forbidden") return } @@ -126,7 +126,7 @@ func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) // Load all data rawUsers, err := h.storage.GetAllUsers(r.Context()) if err != nil { - log.LogErrorWithFields("admin", "Failed to get users", map[string]interface{}{ + log.LogErrorWithFields("admin", "Failed to get users", map[string]any{ "error": err.Error(), }) rawUsers = []storage.UserInfo{} // Empty list on error @@ -137,13 +137,13 @@ func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) for i, user := range rawUsers { users[i] = UserInfoWithAdminType{ UserInfo: user, - IsConfigAdmin: auth.IsConfigAdmin(user.Email, h.config.Proxy.Admin), + IsConfigAdmin: adminauth.IsConfigAdmin(user.Email, h.config.Proxy.Admin), } } sessions, err := h.storage.GetActiveSessions(r.Context()) if err != nil { - log.LogErrorWithFields("admin", "Failed to get sessions", map[string]interface{}{ + log.LogErrorWithFields("admin", "Failed to get sessions", map[string]any{ "error": err.Error(), }) sessions = []storage.ActiveSession{} // Empty list on error @@ -154,7 +154,7 @@ func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) // Generate CSRF token csrfToken, err := h.generateCSRFToken() if err != nil { - log.LogErrorWithFields("admin", "Failed to generate CSRF token", map[string]interface{}{ + log.LogErrorWithFields("admin", "Failed to generate CSRF token", map[string]any{ "error": err.Error(), }) jsonwriter.WriteInternalServerError(w, "Internal server error") @@ -175,7 +175,7 @@ func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) w.Header().Set("Content-Type", "text/html; charset=utf-8") if err := adminPageTemplate.Execute(w, data); err != nil { - log.LogErrorWithFields("admin", "Failed to render admin page", map[string]interface{}{ + log.LogErrorWithFields("admin", "Failed to render admin page", map[string]any{ "error": err.Error(), }) jsonwriter.WriteInternalServerError(w, "Internal server error") @@ -197,7 +197,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request } // Double-check admin status - if !auth.IsAdmin(r.Context(), userEmail, h.config.Proxy.Admin, h.storage) { + if !adminauth.IsAdmin(r.Context(), userEmail, h.config.Proxy.Admin, h.storage) { jsonwriter.WriteForbidden(w, "Forbidden") return } @@ -248,7 +248,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request if currentEnabled { message = fmt.Sprintf("User %s disabled", targetEmail) // Audit log - log.LogInfoWithFields("admin", "User disabled", map[string]interface{}{ + log.LogInfoWithFields("admin", "User disabled", map[string]any{ "admin_email": userEmail, "target_email": targetEmail, "action": "disable", @@ -256,7 +256,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request } else { message = fmt.Sprintf("User %s enabled", targetEmail) // Audit log - log.LogInfoWithFields("admin", "User enabled", map[string]interface{}{ + log.LogInfoWithFields("admin", "User enabled", map[string]any{ "admin_email": userEmail, "target_email": targetEmail, "action": "enable", @@ -272,7 +272,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request } else { message = fmt.Sprintf("User %s deleted", targetEmail) // Audit log - log.LogInfoWithFields("admin", "User deleted", map[string]interface{}{ + log.LogInfoWithFields("admin", "User deleted", map[string]any{ "admin_email": userEmail, "target_email": targetEmail, "action": "delete", @@ -309,7 +309,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request } else { message = fmt.Sprintf("User %s promoted to admin", targetEmail) // Audit log - log.LogInfoWithFields("admin", "User promoted to admin", map[string]interface{}{ + log.LogInfoWithFields("admin", "User promoted to admin", map[string]any{ "admin_email": userEmail, "target_email": targetEmail, "action": "promote", @@ -323,7 +323,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request if targetEmail == userEmail { message = "Cannot demote yourself" messageType = "error" - } else if auth.IsConfigAdmin(targetEmail, h.config.Proxy.Admin) { + } else if adminauth.IsConfigAdmin(targetEmail, h.config.Proxy.Admin) { // Prevent demoting config admins message = "Cannot demote config-defined admins" messageType = "error" @@ -334,7 +334,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request } else { message = fmt.Sprintf("User %s demoted from admin", targetEmail) // Audit log - log.LogInfoWithFields("admin", "User demoted from admin", map[string]interface{}{ + log.LogInfoWithFields("admin", "User demoted from admin", map[string]any{ "admin_email": userEmail, "target_email": targetEmail, "action": "demote", @@ -367,7 +367,7 @@ func (h *AdminHandlers) SessionActionHandler(w http.ResponseWriter, r *http.Requ } // Double-check admin status - if !auth.IsAdmin(r.Context(), userEmail, h.config.Proxy.Admin, h.storage) { + if !adminauth.IsAdmin(r.Context(), userEmail, h.config.Proxy.Admin, h.storage) { jsonwriter.WriteForbidden(w, "Forbidden") return } @@ -421,7 +421,7 @@ func (h *AdminHandlers) SessionActionHandler(w http.ResponseWriter, r *http.Requ } else { message = "Session revoked" // Audit log - log.LogInfoWithFields("admin", "Session revoked", map[string]interface{}{ + log.LogInfoWithFields("admin", "Session revoked", map[string]any{ "admin_email": userEmail, "session_id": sessionID, "action": "revoke_session", @@ -453,7 +453,7 @@ func (h *AdminHandlers) LoggingActionHandler(w http.ResponseWriter, r *http.Requ } // Double-check admin status - if !auth.IsAdmin(r.Context(), userEmail, h.config.Proxy.Admin, h.storage) { + if !adminauth.IsAdmin(r.Context(), userEmail, h.config.Proxy.Admin, h.storage) { jsonwriter.WriteForbidden(w, "Forbidden") return } @@ -487,7 +487,7 @@ func (h *AdminHandlers) LoggingActionHandler(w http.ResponseWriter, r *http.Requ message = fmt.Sprintf("Log level changed to %s", logLevel) // Log the change at INFO level - log.LogInfoWithFields("admin", "Log level changed by admin", map[string]interface{}{ + log.LogInfoWithFields("admin", "Log level changed by admin", map[string]any{ "new_level": logLevel, "admin": userEmail, }) diff --git a/internal/server/admin_handlers_test.go b/internal/server/admin_handlers_test.go index b0ec9d0..3ed3aee 100644 --- a/internal/server/admin_handlers_test.go +++ b/internal/server/admin_handlers_test.go @@ -23,7 +23,7 @@ func TestAdminHandlers_CSRF(t *testing.T) { encryptionKey := "test-encryption-key-32-bytes-long" // Create admin handlers - handlers := NewAdminHandlers(storage, cfg, sessionManager, encryptionKey) + handlers := NewAdminHandlers(storage, *cfg, sessionManager, encryptionKey) t.Run("generate and validate CSRF token", func(t *testing.T) { // Generate token @@ -76,7 +76,7 @@ func TestAdminHandlers_CSRF(t *testing.T) { t.Run("different encryption keys", func(t *testing.T) { // Create handlers with different key - handlers2 := NewAdminHandlers(storage, cfg, sessionManager, "different-encryption-key-32bytes") + handlers2 := NewAdminHandlers(storage, *cfg, sessionManager, "different-encryption-key-32bytes") // Generate token with first handler token1, err := handlers.generateCSRFToken() diff --git a/internal/server/auth_handlers.go b/internal/server/auth_handlers.go new file mode 100644 index 0000000..8388b31 --- /dev/null +++ b/internal/server/auth_handlers.go @@ -0,0 +1,614 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/dgellow/mcp-front/internal/auth" + "github.com/dgellow/mcp-front/internal/browserauth" + "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/crypto" + "github.com/dgellow/mcp-front/internal/envutil" + "github.com/dgellow/mcp-front/internal/googleauth" + jsonwriter "github.com/dgellow/mcp-front/internal/json" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/oauthsession" + "github.com/dgellow/mcp-front/internal/storage" + "github.com/ory/fosite" +) + +// AuthHandlers provides OAuth HTTP handlers with dependency injection +type AuthHandlers struct { + oauthProvider fosite.OAuth2Provider + authConfig config.OAuthAuthConfig + storage storage.Storage + sessionEncryptor crypto.Encryptor + mcpServers map[string]*config.MCPClientConfig + oauthStateToken crypto.TokenSigner + serviceOAuthClient *auth.ServiceOAuthClient +} + +// UpstreamOAuthState stores OAuth state during upstream authentication flow (MCP host → mcp-front) +type UpstreamOAuthState struct { + UserInfo googleauth.UserInfo `json:"user_info"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + Scopes []string `json:"scopes"` + State string `json:"state"` + ResponseType string `json:"response_type"` +} + +// NewAuthHandlers creates new auth handlers with dependency injection +func NewAuthHandlers( + oauthProvider fosite.OAuth2Provider, + authConfig config.OAuthAuthConfig, + storage storage.Storage, + sessionEncryptor crypto.Encryptor, + mcpServers map[string]*config.MCPClientConfig, + serviceOAuthClient *auth.ServiceOAuthClient, +) *AuthHandlers { + return &AuthHandlers{ + oauthProvider: oauthProvider, + authConfig: authConfig, + storage: storage, + sessionEncryptor: sessionEncryptor, + mcpServers: mcpServers, + oauthStateToken: crypto.NewTokenSigner([]byte(authConfig.EncryptionKey), 10*time.Minute), + serviceOAuthClient: serviceOAuthClient, + } +} + +// WellKnownHandler serves OAuth 2.0 metadata +func (h *AuthHandlers) WellKnownHandler(w http.ResponseWriter, r *http.Request) { + log.Logf("Well-known handler called: %s %s", r.Method, r.URL.Path) + + metadata := map[string]any{ + "issuer": h.authConfig.Issuer, + "authorization_endpoint": fmt.Sprintf("%s/authorize", h.authConfig.Issuer), + "token_endpoint": fmt.Sprintf("%s/token", h.authConfig.Issuer), + "registration_endpoint": fmt.Sprintf("%s/register", h.authConfig.Issuer), + "response_types_supported": []string{ + "code", + }, + "grant_types_supported": []string{ + "authorization_code", + "refresh_token", + }, + "code_challenge_methods_supported": []string{ + "S256", + }, + "token_endpoint_auth_methods_supported": []string{ + "none", + "client_secret_post", + }, + "scopes_supported": []string{ + "openid", + "profile", + "email", + "offline_access", + }, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(metadata); err != nil { + log.LogError("Failed to encode well-known metadata: %v", err) + jsonwriter.WriteInternalServerError(w, "Internal server error") + } +} + +// AuthorizeHandler handles OAuth 2.0 authorization requests +func (h *AuthHandlers) AuthorizeHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log.Logf("Authorize handler called: %s %s", r.Method, r.URL.Path) + + // In development mode, generate a secure state parameter if missing + // This works around bugs in OAuth clients that don't send state + stateParam := r.URL.Query().Get("state") + if envutil.IsDev() && len(stateParam) == 0 { + generatedState := crypto.GenerateSecureToken() + log.LogWarn("Development mode: generating state parameter '%s' for buggy client", generatedState) + q := r.URL.Query() + q.Set("state", generatedState) + r.URL.RawQuery = q.Encode() + // Also update the form values + if r.Form == nil { + _ = r.ParseForm() + } + r.Form.Set("state", generatedState) + } + + // Parse the authorize request + ar, err := h.oauthProvider.NewAuthorizeRequest(ctx, r) + if err != nil { + log.LogError("Authorize request error: %v", err) + h.oauthProvider.WriteAuthorizeError(w, ar, err) + return + } + + state := ar.GetState() + h.storage.StoreAuthorizeRequest(state, ar) + + authURL := googleauth.GoogleAuthURL(h.authConfig, state) + http.Redirect(w, r, authURL, http.StatusFound) +} + +// GoogleCallbackHandler handles the callback from Google OAuth +func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + state := r.URL.Query().Get("state") + code := r.URL.Query().Get("code") + + if errMsg := r.URL.Query().Get("error"); errMsg != "" { + errDesc := r.URL.Query().Get("error_description") + log.LogError("Google OAuth error: %s - %s", errMsg, errDesc) + jsonwriter.WriteBadRequest(w, fmt.Sprintf("Authentication failed: %s", errMsg)) + return + } + + if state == "" || code == "" { + log.LogError("Missing state or code in callback") + jsonwriter.WriteBadRequest(w, "Invalid callback parameters") + return + } + + var ar fosite.AuthorizeRequester + var isBrowserFlow bool + var returnURL string + + if strings.HasPrefix(state, "browser:") { + isBrowserFlow = true + stateToken := strings.TrimPrefix(state, "browser:") + + var browserState browserauth.AuthorizationState + if err := h.oauthStateToken.Verify(stateToken, &browserState); err != nil { + log.LogError("Invalid browser state: %v", err) + jsonwriter.WriteBadRequest(w, "Invalid state parameter") + return + } + returnURL = browserState.ReturnURL + } else { + // OAuth client flow - retrieve stored authorize request + var found bool + ar, found = h.storage.GetAuthorizeRequest(state) + if !found { + log.LogError("Invalid or expired state: %s", state) + jsonwriter.WriteBadRequest(w, "Invalid or expired authorization request") + return + } + } + + // Exchange code for token with timeout + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + token, err := googleauth.ExchangeCodeForToken(ctx, h.authConfig, code) + if err != nil { + log.LogError("Failed to exchange code: %v", err) + if !isBrowserFlow && ar != nil { + h.oauthProvider.WriteAuthorizeError(w, ar, fosite.ErrServerError.WithHint("Failed to exchange authorization code")) + } else { + jsonwriter.WriteInternalServerError(w, "Authentication failed") + } + return + } + + // Validate user + userInfo, err := googleauth.ValidateUser(ctx, h.authConfig, token) + if err != nil { + log.LogError("User validation failed: %v", err) + if !isBrowserFlow && ar != nil { + h.oauthProvider.WriteAuthorizeError(w, ar, fosite.ErrAccessDenied.WithHint(err.Error())) + } else { + jsonwriter.WriteForbidden(w, "Access denied") + } + return + } + + log.Logf("User authenticated: %s", userInfo.Email) + + // Store user in database + if err := h.storage.UpsertUser(ctx, userInfo.Email); err != nil { + log.LogWarnWithFields("auth", "Failed to track user", map[string]any{ + "email": userInfo.Email, + "error": err.Error(), + }) + } + + if isBrowserFlow { + // Browser SSO flow - set encrypted session cookie + // Browser sessions should last longer than API tokens for better UX + sessionDuration := 24 * time.Hour + + sessionData := browserauth.SessionCookie{ + Email: userInfo.Email, + Expires: time.Now().Add(sessionDuration), + } + + // Marshal session data to JSON + jsonData, err := json.Marshal(sessionData) + if err != nil { + log.LogError("Failed to marshal session data: %v", err) + jsonwriter.WriteInternalServerError(w, "Failed to create session") + return + } + + // Encrypt session data + encryptedData, err := h.sessionEncryptor.Encrypt(string(jsonData)) + if err != nil { + log.LogError("Failed to encrypt session: %v", err) + jsonwriter.WriteInternalServerError(w, "Failed to create session") + return + } + + // Set secure session cookie + http.SetCookie(w, &http.Cookie{ + Name: "mcp_session", + Value: encryptedData, + Path: "/", + HttpOnly: true, + Secure: !envutil.IsDev(), + SameSite: http.SameSiteStrictMode, + MaxAge: int(sessionDuration.Seconds()), + }) + + log.LogInfoWithFields("auth", "Browser SSO session created", map[string]any{ + "user": userInfo.Email, + "duration": sessionDuration, + "returnURL": returnURL, + }) + + // Redirect to return URL + http.Redirect(w, r, returnURL, http.StatusFound) + return + } + + // OAuth client flow - check if any services need OAuth + needsServiceAuth := false + for _, serverConfig := range h.mcpServers { + if serverConfig.RequiresUserToken && + serverConfig.UserAuthentication != nil && + serverConfig.UserAuthentication.Type == config.UserAuthTypeOAuth { + needsServiceAuth = true + break + } + } + + if needsServiceAuth { + stateData, err := h.signUpstreamOAuthState(ar, userInfo) + if err != nil { + log.LogError("Failed to sign OAuth state: %v", err) + h.oauthProvider.WriteAuthorizeError(w, ar, fosite.ErrServerError.WithHint("Failed to prepare service authentication")) + return + } + + http.Redirect(w, r, fmt.Sprintf("/oauth/services?state=%s", url.QueryEscape(stateData)), http.StatusFound) + return + } + + session := &oauthsession.Session{ + DefaultSession: &fosite.DefaultSession{ + ExpiresAt: map[fosite.TokenType]time.Time{ + fosite.AccessToken: time.Now().Add(h.authConfig.TokenTTL), + fosite.RefreshToken: time.Now().Add(h.authConfig.TokenTTL * 2), + }, + }, + UserInfo: userInfo, + } + + // Accept the authorization request + response, err := h.oauthProvider.NewAuthorizeResponse(ctx, ar, session) + if err != nil { + log.LogError("Authorize response error: %v", err) + h.oauthProvider.WriteAuthorizeError(w, ar, err) + return + } + + h.oauthProvider.WriteAuthorizeResponse(w, ar, response) +} + +// TokenHandler handles OAuth 2.0 token requests +// +// TODO: feels messy, need to see if we can simplify that whole logic +func (h *AuthHandlers) TokenHandler(w http.ResponseWriter, r *http.Request) { + log.Logf("Token handler called: %s %s", r.Method, r.URL.Path) + ctx := r.Context() + + // Create session for the token exchange + // Note: We create our custom Session type here, and fosite will populate it + // with the session data from the authorization code during NewAccessRequest + session := &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}} + + // Handle token request - this retrieves the session from the authorization code + accessRequest, err := h.oauthProvider.NewAccessRequest(ctx, r, session) + if err != nil { + log.LogError("Access request error: %v", err) + h.oauthProvider.WriteAccessError(w, accessRequest, err) + return + } + + // At this point, accessRequest.GetSession() contains the session data from + // the authorization phase (including our custom UserInfo). Fosite handles + // the session propagation internally when creating the access token. + + // Generate tokens + response, err := h.oauthProvider.NewAccessResponse(ctx, accessRequest) + if err != nil { + log.LogError("Access response error: %v", err) + h.oauthProvider.WriteAccessError(w, accessRequest, err) + return + } + + h.oauthProvider.WriteAccessResponse(w, accessRequest, response) +} + +// buildClientRegistrationResponse creates the registration response for a client +func (h *AuthHandlers) buildClientRegistrationResponse(client *fosite.DefaultClient, tokenEndpointAuthMethod string, clientSecret string) map[string]any { + response := map[string]any{ + "client_id": client.GetID(), + "client_id_issued_at": time.Now().Unix(), + "redirect_uris": client.GetRedirectURIs(), + "grant_types": client.GetGrantTypes(), + "response_types": client.GetResponseTypes(), + "scope": strings.Join(client.GetScopes(), " "), // Space-separated string + "token_endpoint_auth_method": tokenEndpointAuthMethod, + } + + // Include client_secret only for confidential clients + if clientSecret != "" { + response["client_secret"] = clientSecret + } + + return response +} + +// RegisterHandler handles dynamic client registration (RFC 7591) +func (h *AuthHandlers) RegisterHandler(w http.ResponseWriter, r *http.Request) { + log.Logf("Register handler called: %s %s", r.Method, r.URL.Path) + + if r.Method != http.MethodPost { + jsonwriter.WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed") + return + } + + // Parse client metadata + var metadata map[string]any + if err := json.NewDecoder(r.Body).Decode(&metadata); err != nil { + jsonwriter.WriteBadRequest(w, "Invalid request body") + return + } + + // Parse client request + redirectURIs, scopes, err := googleauth.ParseClientRequest(metadata) + if err != nil { + log.LogError("Client request parsing error: %v", err) + jsonwriter.WriteBadRequest(w, err.Error()) + return + } + + // Check if client requests client_secret_post authentication + tokenEndpointAuthMethod := "none" + var client *fosite.DefaultClient + var plaintextSecret string + clientID := crypto.GenerateSecureToken() + + if authMethod, ok := metadata["token_endpoint_auth_method"].(string); ok && authMethod == "client_secret_post" { + // Create confidential client with a secret + plaintextSecret = crypto.GenerateSecureToken() + hashedSecret, err := crypto.HashClientSecret(plaintextSecret) + if err != nil { + log.LogError("Failed to hash client secret: %v", err) + jsonwriter.WriteInternalServerError(w, "Failed to create client") + return + } + client = h.storage.CreateConfidentialClient(clientID, hashedSecret, redirectURIs, scopes, h.authConfig.Issuer) + tokenEndpointAuthMethod = "client_secret_post" + log.Logf("Creating confidential client %s with client_secret_post authentication", clientID) + } else { + // Create public client (no secret) + client = h.storage.CreateClient(clientID, redirectURIs, scopes, h.authConfig.Issuer) + log.Logf("Creating public client %s with no authentication", clientID) + } + + response := h.buildClientRegistrationResponse(client, tokenEndpointAuthMethod, plaintextSecret) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(response); err != nil { + log.LogError("Failed to encode registration response: %v", err) + jsonwriter.WriteInternalServerError(w, "Failed to create client") + } +} + +// signUpstreamOAuthState signs upstream OAuth state for secure storage +func (h *AuthHandlers) signUpstreamOAuthState(ar fosite.AuthorizeRequester, userInfo googleauth.UserInfo) (string, error) { + state := UpstreamOAuthState{ + UserInfo: userInfo, + ClientID: ar.GetClient().GetID(), + RedirectURI: ar.GetRedirectURI().String(), + Scopes: ar.GetRequestedScopes(), + State: ar.GetState(), + ResponseType: ar.GetResponseTypes()[0], + } + + return h.oauthStateToken.Sign(state) +} + +// verifyUpstreamOAuthState verifies and validates upstream OAuth state +func (h *AuthHandlers) verifyUpstreamOAuthState(signedState string) (*UpstreamOAuthState, error) { + var state UpstreamOAuthState + if err := h.oauthStateToken.Verify(signedState, &state); err != nil { + return nil, err + } + return &state, nil +} + +// ServiceSelectionHandler shows the interstitial page for selecting services to connect +func (h *AuthHandlers) ServiceSelectionHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + jsonwriter.WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed") + return + } + + signedState := r.URL.Query().Get("state") + if signedState == "" { + jsonwriter.WriteBadRequest(w, "Missing state parameter") + return + } + + upstreamOAuthState, err := h.verifyUpstreamOAuthState(signedState) + if err != nil { + log.LogError("Failed to verify OAuth state: %v", err) + jsonwriter.WriteBadRequest(w, "Invalid or expired session") + return + } + + userEmail := upstreamOAuthState.UserInfo.Email + + // Prepare template data + returnURL := fmt.Sprintf("/oauth/services?state=%s", url.QueryEscape(signedState)) + + // Prepare service list + var services []ServiceSelectionData + for name, serverConfig := range h.mcpServers { + if serverConfig.RequiresUserToken && + serverConfig.UserAuthentication != nil && + serverConfig.UserAuthentication.Type == config.UserAuthTypeOAuth { + + // Check if user already has valid token + token, _ := h.storage.GetUserToken(r.Context(), userEmail, name) + status := "not_connected" + if token != nil { + status = "connected" + } + + displayName := name + if serverConfig.UserAuthentication.DisplayName != "" { + displayName = serverConfig.UserAuthentication.DisplayName + } + + // Check for error from callback + errorMsg := "" + if r.URL.Query().Get("error") != "" && r.URL.Query().Get("service") == name { + status = "error" + // Use error_msg if available, fallback to error_description + errorMsg = r.URL.Query().Get("error_msg") + if errorMsg == "" { + errorMsg = r.URL.Query().Get("error_description") + } + if errorMsg == "" { + errorMsg = "OAuth connection failed" + } + } + + // Generate OAuth connect URL if OAuth client is available + connectURL := "" + if h.serviceOAuthClient != nil { + connectURL = h.serviceOAuthClient.GetConnectURL(name, returnURL) + } + + services = append(services, ServiceSelectionData{ + Name: name, + DisplayName: displayName, + Status: status, + ErrorMsg: errorMsg, + ConnectURL: connectURL, + }) + } + } + + pageData := ServicesPageData{ + Services: services, + State: url.QueryEscape(signedState), + ReturnURL: url.QueryEscape(returnURL), + } + + // Render template + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := servicesPageTemplate.Execute(w, pageData); err != nil { + log.LogError("Failed to render services page: %v", err) + jsonwriter.WriteInternalServerError(w, "Failed to render page") + } +} + +// CompleteOAuthHandler completes the original OAuth flow after service selection +func (h *AuthHandlers) CompleteOAuthHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + jsonwriter.WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed") + return + } + + signedState := r.URL.Query().Get("state") + if signedState == "" { + jsonwriter.WriteBadRequest(w, "Missing state parameter") + return + } + + upstreamOAuthState, err := h.verifyUpstreamOAuthState(signedState) + if err != nil { + log.LogError("Failed to verify OAuth state: %v", err) + jsonwriter.WriteBadRequest(w, "Invalid or expired session") + return + } + + // Recreate the authorize request + ctx := r.Context() + client, err := h.storage.GetClient(ctx, upstreamOAuthState.ClientID) + if err != nil { + log.LogError("Failed to get client: %v", err) + jsonwriter.WriteInternalServerError(w, "Client not found") + return + } + + // Create a new authorize request with the stored parameters + ar := &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{upstreamOAuthState.ResponseType}, + RedirectURI: &url.URL{}, + State: upstreamOAuthState.State, + HandledResponseTypes: fosite.Arguments{}, + Request: fosite.Request{ + ID: crypto.GenerateSecureToken(), + RequestedAt: time.Now(), + Client: client, + RequestedScope: upstreamOAuthState.Scopes, + GrantedScope: upstreamOAuthState.Scopes, + Session: &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}}, + }, + } + + redirectURI, err := url.Parse(upstreamOAuthState.RedirectURI) + if err != nil { + log.LogError("Failed to parse redirect URI: %v", err) + jsonwriter.WriteInternalServerError(w, "Invalid redirect URI") + return + } + ar.RedirectURI = redirectURI + + // Create session with user info + session := &oauthsession.Session{ + DefaultSession: &fosite.DefaultSession{ + ExpiresAt: map[fosite.TokenType]time.Time{ + fosite.AccessToken: time.Now().Add(h.authConfig.TokenTTL), + fosite.RefreshToken: time.Now().Add(h.authConfig.TokenTTL * 2), + }, + }, + UserInfo: upstreamOAuthState.UserInfo, + } + ar.SetSession(session) + + // Accept the authorization request + response, err := h.oauthProvider.NewAuthorizeResponse(ctx, ar, session) + if err != nil { + log.LogError("Authorize response error: %v", err) + h.oauthProvider.WriteAuthorizeError(w, ar, err) + return + } + + // Write the response (redirects to Claude) + h.oauthProvider.WriteAuthorizeResponse(w, ar, response) +} diff --git a/internal/server/auth_test.go b/internal/server/auth_test.go index 64b7f30..c369df7 100644 --- a/internal/server/auth_test.go +++ b/internal/server/auth_test.go @@ -9,9 +9,10 @@ import ( "testing" "time" + "github.com/dgellow/mcp-front/internal/auth" + "github.com/dgellow/mcp-front/internal/browserauth" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" - "github.com/dgellow/mcp-front/internal/oauth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,11 +34,11 @@ func TestAuthenticationBoundaries(t *testing.T) { Auth: &config.OAuthAuthConfig{ Kind: config.AuthKindOAuth, GoogleClientID: "test-client-id", - GoogleClientSecret: "test-client-secret", + GoogleClientSecret: config.Secret("test-client-secret"), GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: strings.Repeat("a", 32), - EncryptionKey: strings.Repeat("b", 32), - TokenTTL: "1h", + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, Storage: "memory", }, }, @@ -54,11 +55,11 @@ func TestAuthenticationBoundaries(t *testing.T) { Auth: &config.OAuthAuthConfig{ Kind: config.AuthKindOAuth, GoogleClientID: "test-client-id", - GoogleClientSecret: "test-client-secret", + GoogleClientSecret: config.Secret("test-client-secret"), GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: strings.Repeat("a", 32), - EncryptionKey: strings.Repeat("b", 32), - TokenTTL: "1h", + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, Storage: "memory", }, }, @@ -75,11 +76,11 @@ func TestAuthenticationBoundaries(t *testing.T) { Auth: &config.OAuthAuthConfig{ Kind: config.AuthKindOAuth, GoogleClientID: "test-client-id", - GoogleClientSecret: "test-client-secret", + GoogleClientSecret: config.Secret("test-client-secret"), GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: strings.Repeat("a", 32), - EncryptionKey: strings.Repeat("b", 32), - TokenTTL: "1h", + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, Storage: "memory", }, }, @@ -130,7 +131,7 @@ func TestAuthenticationBoundaries(t *testing.T) { // Create a valid session cookie using the encryptor if oauthConfig, ok := tt.config.Proxy.Auth.(*config.OAuthAuthConfig); ok { // Create session data - sessionData := oauth.SessionData{ + sessionData := browserauth.SessionCookie{ Email: "test@example.com", Expires: time.Now().Add(24 * time.Hour), } @@ -192,11 +193,11 @@ func TestMCPAuthConfiguration(t *testing.T) { Auth: &config.OAuthAuthConfig{ Kind: config.AuthKindOAuth, GoogleClientID: "test-client-id", - GoogleClientSecret: "test-client-secret", + GoogleClientSecret: config.Secret("test-client-secret"), GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: strings.Repeat("a", 32), - EncryptionKey: strings.Repeat("b", 32), - TokenTTL: "1h", + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, Storage: "memory", }, }, diff --git a/internal/server/handler.go b/internal/server/handler.go deleted file mode 100644 index 7294e52..0000000 --- a/internal/server/handler.go +++ /dev/null @@ -1,506 +0,0 @@ -package server - -import ( - "context" - "fmt" - "net/http" - "net/url" - "time" - - "github.com/dgellow/mcp-front/internal/client" - "github.com/dgellow/mcp-front/internal/config" - "github.com/dgellow/mcp-front/internal/crypto" - "github.com/dgellow/mcp-front/internal/inline" - "github.com/dgellow/mcp-front/internal/log" - "github.com/dgellow/mcp-front/internal/oauth" - "github.com/dgellow/mcp-front/internal/storage" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" -) - -// Server represents the MCP proxy server -type Server struct { - mux *http.ServeMux - config *config.Config - oauthServer *oauth.Server - storage storage.Storage - sessionManager *client.StdioSessionManager - sseServers map[string]*server.SSEServer // serverName -> SSE server for stdio servers -} - -// NewServer creates a new MCP proxy server handler -func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) { - baseURL, err := url.Parse(cfg.Proxy.BaseURL) - if err != nil { - return nil, fmt.Errorf("invalid base URL: %w", err) - } - - mux := http.NewServeMux() - - // Create session manager for stdio servers with configurable timeouts - sessionTimeout := 5 * time.Minute - cleanupInterval := 1 * time.Minute - maxPerUser := 10 - - // Use config values if available - if cfg.Proxy.Sessions != nil { - if cfg.Proxy.Sessions.Timeout > 0 { - sessionTimeout = cfg.Proxy.Sessions.Timeout - log.LogInfoWithFields("server", "Using configured session timeout", map[string]interface{}{ - "timeout": sessionTimeout, - }) - } - if cfg.Proxy.Sessions.CleanupInterval > 0 { - cleanupInterval = cfg.Proxy.Sessions.CleanupInterval - log.LogInfoWithFields("server", "Using configured cleanup interval", map[string]interface{}{ - "interval": cleanupInterval, - }) - } - maxPerUser = cfg.Proxy.Sessions.MaxPerUser - } - - sessionManager := client.NewStdioSessionManager( - client.WithTimeout(sessionTimeout), - client.WithMaxPerUser(maxPerUser), - client.WithCleanupInterval(cleanupInterval), - ) - - s := &Server{ - mux: mux, - config: cfg, - sessionManager: sessionManager, - sseServers: make(map[string]*server.SSEServer), - } - - // Build list of allowed CORS origins - var allowedOrigins []string - if oauthAuth, ok := cfg.Proxy.Auth.(*config.OAuthAuthConfig); ok && oauthAuth != nil { - allowedOrigins = oauthAuth.AllowedOrigins - } - - info := mcp.Implementation{ - Name: cfg.Proxy.Name, - Version: "dev", - } - - // Initialize OAuth server if OAuth config is provided - if oauthAuth, ok := cfg.Proxy.Auth.(*config.OAuthAuthConfig); ok && oauthAuth != nil { - log.LogDebug("initializing OAuth 2.1 server") - - // Parse TTL duration - ttl, err := time.ParseDuration(oauthAuth.TokenTTL) - if err != nil { - return nil, fmt.Errorf("parsing OAuth token TTL: %w", err) - } - - // Create storage based on configuration - var store storage.Storage - if oauthAuth.Storage == "firestore" { - log.LogInfoWithFields("oauth", "Using Firestore storage", map[string]interface{}{ - "project": oauthAuth.GCPProject, - "database": oauthAuth.FirestoreDatabase, - "collection": oauthAuth.FirestoreCollection, - }) - // Create encryptor for Firestore storage - encryptor, err := crypto.NewEncryptor([]byte(oauthAuth.EncryptionKey)) - if err != nil { - return nil, fmt.Errorf("failed to create encryptor: %w", err) - } - firestoreStorage, err := storage.NewFirestoreStorage( - ctx, - oauthAuth.GCPProject, - oauthAuth.FirestoreDatabase, - oauthAuth.FirestoreCollection, - encryptor, - ) - if err != nil { - return nil, fmt.Errorf("failed to create Firestore storage: %w", err) - } - store = firestoreStorage - } else { - log.LogInfoWithFields("oauth", "Using in-memory storage", map[string]interface{}{}) - store = storage.NewMemoryStorage() - } - - oauthConfig := oauth.Config{ - Issuer: oauthAuth.Issuer, - TokenTTL: ttl, - AllowedDomains: oauthAuth.AllowedDomains, - AllowedOrigins: oauthAuth.AllowedOrigins, - GoogleClientID: oauthAuth.GoogleClientID, - GoogleClientSecret: oauthAuth.GoogleClientSecret, - GoogleRedirectURI: oauthAuth.GoogleRedirectURI, - JWTSecret: oauthAuth.JWTSecret, - EncryptionKey: oauthAuth.EncryptionKey, - StorageType: oauthAuth.Storage, - GCPProjectID: oauthAuth.GCPProject, - FirestoreDatabase: oauthAuth.FirestoreDatabase, - FirestoreCollection: oauthAuth.FirestoreCollection, - } - - s.oauthServer, err = oauth.NewServer(oauthConfig, store) - if err != nil { - return nil, fmt.Errorf("failed to create OAuth server: %w", err) - } - - s.storage = store - - // Initialize admin users if admin is enabled - if cfg.Proxy.Admin != nil && cfg.Proxy.Admin.Enabled { - for _, adminEmail := range cfg.Proxy.Admin.AdminEmails { - // Upsert admin user - if err := store.UpsertUser(ctx, adminEmail); err != nil { - log.LogWarnWithFields("server", "Failed to initialize admin user", map[string]interface{}{ - "email": adminEmail, - "error": err.Error(), - }) - continue - } - // Set as admin - if err := store.SetUserAdmin(ctx, adminEmail, true); err != nil { - log.LogWarnWithFields("server", "Failed to set user as admin", map[string]interface{}{ - "email": adminEmail, - "error": err.Error(), - }) - } - } - } - - // Register OAuth endpoints - oauthMiddlewares := []MiddlewareFunc{ - corsMiddleware(allowedOrigins), - loggerMiddleware("oauth"), - recoverMiddleware("mcp"), - } - - mux.Handle("/.well-known/oauth-authorization-server", chainMiddleware(http.HandlerFunc(s.oauthServer.WellKnownHandler), oauthMiddlewares...)) - mux.Handle("/authorize", chainMiddleware(http.HandlerFunc(s.oauthServer.AuthorizeHandler), oauthMiddlewares...)) - mux.Handle("/oauth/callback", chainMiddleware(http.HandlerFunc(s.oauthServer.GoogleCallbackHandler), oauthMiddlewares...)) - mux.Handle("/token", chainMiddleware(http.HandlerFunc(s.oauthServer.TokenHandler), oauthMiddlewares...)) - mux.Handle("/register", chainMiddleware(http.HandlerFunc(s.oauthServer.RegisterHandler), oauthMiddlewares...)) - - // Protected endpoints - require authentication - tokenHandlers := NewTokenHandlers(s.storage, cfg.MCPServers, s.oauthServer != nil) - tokenMiddlewares := []MiddlewareFunc{ - corsMiddleware(allowedOrigins), - loggerMiddleware("tokens"), - s.oauthServer.SSOMiddleware(), - recoverMiddleware("mcp"), - } - - // Token management UI endpoints - mux.Handle("/my/tokens", chainMiddleware(http.HandlerFunc(tokenHandlers.ListTokensHandler), tokenMiddlewares...)) - mux.Handle("/my/tokens/set", chainMiddleware(http.HandlerFunc(tokenHandlers.SetTokenHandler), tokenMiddlewares...)) - mux.Handle("/my/tokens/delete", chainMiddleware(http.HandlerFunc(tokenHandlers.DeleteTokenHandler), tokenMiddlewares...)) - } - - // Setup MCP server endpoints - for serverName, serverConfig := range cfg.MCPServers { - // Build path like /notion/sse - ssePathPrefix := "/" + serverName + "/sse" - - log.LogInfoWithFields("server", "Registering MCP server", map[string]interface{}{ - "name": serverName, - "sse_path": ssePathPrefix, - "transport_type": serverConfig.TransportType, - "requires_user_token": serverConfig.RequiresUserToken, - }) - - var handler http.Handler - - // For inline servers, create a custom handler - if serverConfig.TransportType == config.MCPClientTypeInline { - // Resolve inline config - inlineConfig, resolvedTools, err := inline.ResolveConfig(serverConfig.InlineConfig) - if err != nil { - return nil, fmt.Errorf("failed to resolve inline config for %s: %w", serverName, err) - } - - // Create inline server - inlineServer := inline.NewServer(serverName, inlineConfig, resolvedTools) - - // Create inline handler - handler = inline.NewHandler(serverName, inlineServer) - - log.LogInfoWithFields("server", "Created inline MCP server", map[string]interface{}{ - "name": serverName, - "tools": len(resolvedTools), - }) - } else { - - // For stdio servers, create a single shared MCP server - if isStdioServer(serverConfig) { - // Create the shared MCP server for this stdio server - // We need to create it first so we can reference it in the hooks - var mcpServer *server.MCPServer - - // Create hooks for session management - hooks := &server.Hooks{} - - // Store reference to server name for use in hooks - currentServerName := serverName - - // Setup hooks that will be called when sessions are created/destroyed - hooks.AddOnRegisterSession(func(sessionCtx context.Context, session server.ClientSession) { - // Extract handler from context - if handler, ok := sessionCtx.Value(sessionHandlerKey{}).(*sessionRequestHandler); ok { - // Pass the MCP server to the handler - handler.mcpServer = mcpServer - // Handle session registration - handleSessionRegistration(sessionCtx, session, handler, s.sessionManager) - } else { - log.LogErrorWithFields("server", "No session handler in context", map[string]interface{}{ - "sessionID": session.SessionID(), - "server": currentServerName, - }) - } - }) - - hooks.AddOnUnregisterSession(func(sessionCtx context.Context, session server.ClientSession) { - // Extract handler from context - if handler, ok := sessionCtx.Value(sessionHandlerKey{}).(*sessionRequestHandler); ok { - // Handle session cleanup - key := client.SessionKey{ - UserEmail: handler.userEmail, - ServerName: handler.h.serverName, - SessionID: session.SessionID(), - } - s.sessionManager.RemoveSession(key) - - if handler.h.storage != nil { - if err := handler.h.storage.RevokeSession(sessionCtx, session.SessionID()); err != nil { - log.LogWarnWithFields("server", "Failed to revoke session from storage", map[string]interface{}{ - "error": err.Error(), - "sessionID": session.SessionID(), - "user": handler.userEmail, - }) - } - } - - log.LogInfoWithFields("server", "Session unregistered and cleaned up", map[string]interface{}{ - "sessionID": session.SessionID(), - "server": currentServerName, - "user": handler.userEmail, - }) - } - }) - - // Now create the MCP server with the hooks - mcpServer = server.NewMCPServer(serverName, "1.0.0", - server.WithHooks(hooks), - server.WithPromptCapabilities(true), - server.WithResourceCapabilities(true, true), - server.WithToolCapabilities(true), - server.WithLogging(), - ) - - // Create the SSE server wrapper around the MCP server - sseServer := server.NewSSEServer(mcpServer, - server.WithStaticBasePath(serverName), - server.WithBaseURL(baseURL.String()), - ) - - s.sseServers[serverName] = sseServer - } - - // Create MCP handler for stdio/SSE servers - handler = NewMCPHandler( - serverName, - serverConfig, - s.storage, - baseURL.String(), - info, - s.sessionManager, - s.sseServers[serverName], // Pass the shared MCP server (nil for non-stdio) - ) - } - - // Setup middlewares - var middlewares []MiddlewareFunc - middlewares = append(middlewares, loggerMiddleware("mcp")) - middlewares = append(middlewares, corsMiddleware(allowedOrigins)) - - if s.oauthServer != nil { - log.LogTraceWithFields("server", "Adding OAuth middleware", map[string]interface{}{ - "server_name": serverName, - }) - middlewares = append(middlewares, s.oauthServer.ValidateTokenMiddleware()) - } - - if len(serverConfig.ServiceAuths) > 0 { - log.LogTraceWithFields("server", "Adding service auth middleware", map[string]interface{}{ - "server_name": serverName, - "auth_count": len(serverConfig.ServiceAuths), - }) - middlewares = append(middlewares, newServiceAuthMiddleware(serverConfig.ServiceAuths)) - } - - // important to be last, making it the outermost middleware, so it can recover from any middleware panic - middlewares = append(middlewares, recoverMiddleware("mcp")) - - // Register handler - SSE server needs to handle all paths under the server name - // It handles both /postgres/sse and /postgres/message endpoints - mux.Handle("/"+serverName+"/", chainMiddleware(handler, middlewares...)) - } - - // Admin routes - only if admin is enabled - if cfg.Proxy.Admin != nil && cfg.Proxy.Admin.Enabled { - log.LogInfoWithFields("server", "Admin UI enabled", map[string]interface{}{ - "admin_emails": cfg.Proxy.Admin.AdminEmails, - }) - - // Get encryption key from OAuth config - var encryptionKey string - if oauthAuth, ok := cfg.Proxy.Auth.(*config.OAuthAuthConfig); ok && oauthAuth != nil { - encryptionKey = oauthAuth.EncryptionKey - } - - adminHandlers := NewAdminHandlers(s.storage, cfg, s.sessionManager, encryptionKey) - adminMiddlewares := []MiddlewareFunc{ - corsMiddleware(allowedOrigins), - loggerMiddleware("admin"), - s.oauthServer.SSOMiddleware(), // Browser SSO - adminMiddleware(cfg.Proxy.Admin, s.storage), // Admin check - recoverMiddleware("mcp"), - } - - // Admin routes - all protected by admin middleware - mux.Handle("/admin", chainMiddleware(http.HandlerFunc(adminHandlers.DashboardHandler), adminMiddlewares...)) - mux.Handle("/admin/users", chainMiddleware(http.HandlerFunc(adminHandlers.UserActionHandler), adminMiddlewares...)) - mux.Handle("/admin/sessions", chainMiddleware(http.HandlerFunc(adminHandlers.SessionActionHandler), adminMiddlewares...)) - mux.Handle("/admin/logging", chainMiddleware(http.HandlerFunc(adminHandlers.LoggingActionHandler), adminMiddlewares...)) - } - - // Health check endpoint - mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{"status":"ok"}`)) - }) - - log.LogInfoWithFields("server", "MCP proxy server initialized", nil) - return s, nil -} - -// ServeHTTP implements http.Handler -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.mux.ServeHTTP(w, r) -} - -// Shutdown gracefully shuts down the server -func (s *Server) Shutdown() error { - if s.sessionManager != nil { - s.sessionManager.Shutdown() - } - return nil -} - -// isStdioServer checks if this is a stdio-based server -func isStdioServer(config *config.MCPClientConfig) bool { - return config.Command != "" -} - -// sessionHandlerKey is the context key for session handlers -type sessionHandlerKey struct{} - -// sessionRequestHandler handles session-specific logic for a request -type sessionRequestHandler struct { - h *MCPHandler - userEmail string - config *config.MCPClientConfig - mcpServer *server.MCPServer // The shared MCP server -} - -// handleSessionRegistration handles the registration of a new session -func handleSessionRegistration( - sessionCtx context.Context, - session server.ClientSession, - handler *sessionRequestHandler, - sessionManager *client.StdioSessionManager, -) { - // Create stdio process for this session - key := client.SessionKey{ - UserEmail: handler.userEmail, - ServerName: handler.h.serverName, - SessionID: session.SessionID(), - } - - log.LogDebugWithFields("server", "Registering session", map[string]interface{}{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) - - log.LogTraceWithFields("server", "Session registration started", map[string]interface{}{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - "requiresUserToken": handler.config.RequiresUserToken, - "transportType": handler.config.TransportType, - "command": handler.config.Command, - }) - - stdioSession, err := sessionManager.GetOrCreateSession( - sessionCtx, - key, - handler.config, - handler.h.info, - handler.h.setupBaseURL, - ) - if err != nil { - log.LogErrorWithFields("server", "Failed to create stdio session", map[string]interface{}{ - "error": err.Error(), - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) - return - } - - // Discover and register capabilities from the stdio process - if err := stdioSession.DiscoverAndRegisterCapabilities( - sessionCtx, - handler.mcpServer, - handler.userEmail, - handler.config.RequiresUserToken, - handler.h.storage, - handler.h.serverName, - handler.h.setupBaseURL, - handler.config.TokenSetup, - session, - ); err != nil { - log.LogErrorWithFields("server", "Failed to discover and register capabilities", map[string]interface{}{ - "error": err.Error(), - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) - sessionManager.RemoveSession(key) - return - } - - if handler.userEmail != "" { - if handler.h.storage != nil { - activeSession := storage.ActiveSession{ - SessionID: session.SessionID(), - UserEmail: handler.userEmail, - ServerName: handler.h.serverName, - Created: time.Now(), - LastActive: time.Now(), - } - if err := handler.h.storage.TrackSession(sessionCtx, activeSession); err != nil { - log.LogWarnWithFields("server", "Failed to track session", map[string]interface{}{ - "error": err.Error(), - "sessionID": session.SessionID(), - "user": handler.userEmail, - }) - } - } - } - - log.LogInfoWithFields("server", "Session successfully created and connected", map[string]interface{}{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) -} diff --git a/internal/server/http.go b/internal/server/http.go new file mode 100644 index 0000000..a61797a --- /dev/null +++ b/internal/server/http.go @@ -0,0 +1,313 @@ +package server + +import ( + "context" + "errors" + "net/http" + "strings" + "time" + + "github.com/dgellow/mcp-front/internal/auth" + "github.com/dgellow/mcp-front/internal/client" + "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/storage" + mcpserver "github.com/mark3labs/mcp-go/server" +) + +// UserTokenService handles user token retrieval and OAuth refresh +type UserTokenService struct { + storage storage.Storage + serviceOAuthClient *auth.ServiceOAuthClient +} + +// NewUserTokenService creates a new user token service +func NewUserTokenService(storage storage.Storage, serviceOAuthClient *auth.ServiceOAuthClient) *UserTokenService { + return &UserTokenService{ + storage: storage, + serviceOAuthClient: serviceOAuthClient, + } +} + +// GetUserToken retrieves and formats a user token for a service, handling OAuth refresh +func (uts *UserTokenService) GetUserToken(ctx context.Context, userEmail, serviceName string, serviceConfig *config.MCPClientConfig) (string, error) { + storedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) + if err != nil { + return "", err + } + + switch storedToken.Type { + case storage.TokenTypeManual: + // Token is already in storedToken.Value, formatUserToken will handle it + break + case storage.TokenTypeOAuth: + if storedToken.OAuthData != nil && uts.serviceOAuthClient != nil { + if err := uts.serviceOAuthClient.RefreshToken(ctx, userEmail, serviceName, serviceConfig); err != nil { + log.LogWarnWithFields("user_token", "Failed to refresh OAuth token", map[string]any{ + "service": serviceName, + "user": userEmail, + "error": err.Error(), + }) + // Continue with current token - the service will handle auth failure + } else { + // Re-fetch the updated token after refresh + refreshedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) + if err != nil { + log.LogErrorWithFields("user_token", "Failed to fetch token after successful refresh", map[string]any{ + "service": serviceName, + "user": userEmail, + "error": err.Error(), + }) + // Continue with original token - the service will handle auth failure + } else { + storedToken = refreshedToken + var expiresAt time.Time + if refreshedToken.OAuthData != nil { + expiresAt = refreshedToken.OAuthData.ExpiresAt + } + log.LogInfoWithFields("user_token", "OAuth token refreshed and updated", map[string]any{ + "service": serviceName, + "user": userEmail, + "expiresAt": expiresAt, + }) + } + } + } + } + + return formatUserToken(storedToken, serviceConfig.UserAuthentication), nil +} + +// HTTPServer manages the HTTP server lifecycle +type HTTPServer struct { + server *http.Server +} + +// NewHTTPServer creates a new HTTP server with the given handler and address +func NewHTTPServer(handler http.Handler, addr string) *HTTPServer { + return &HTTPServer{ + server: &http.Server{ + Addr: addr, + Handler: handler, + }, + } +} + +// Handler builders and mux assembly + +// HealthHandler handles health check requests +type HealthHandler struct{} + +// NewHealthHandler creates a new health handler +func NewHealthHandler() *HealthHandler { + return &HealthHandler{} +} + +// ServeHTTP implements http.Handler for health checks +func (h *HealthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ok"}`)) +} + +// Start starts the HTTP server +func (h *HTTPServer) Start() error { + log.LogInfoWithFields("http", "HTTP server starting", map[string]any{ + "addr": h.server.Addr, + }) + + if err := h.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil +} + +// Stop gracefully stops the HTTP server +func (h *HTTPServer) Stop(ctx context.Context) error { + log.LogInfoWithFields("http", "HTTP server stopping", map[string]any{ + "addr": h.server.Addr, + }) + + if err := h.server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err + } + + log.LogInfoWithFields("http", "HTTP server stopped", map[string]any{ + "addr": h.server.Addr, + }) + return nil +} + +// isStdioServer checks if this is a stdio-based server +func isStdioServer(config *config.MCPClientConfig) bool { + return config.Command != "" +} + +// formatUserToken formats a stored token according to the user authentication configuration +func formatUserToken(storedToken *storage.StoredToken, auth *config.UserAuthentication) string { + if storedToken == nil { + return "" + } + + if storedToken.Type == storage.TokenTypeOAuth && storedToken.OAuthData != nil { + token := storedToken.OAuthData.AccessToken + if auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { + return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) + } + return token + } + + token := storedToken.Value + if auth != nil && auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { + return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) + } + return token +} + +// SessionHandlerKey is the context key for session handlers +type SessionHandlerKey struct{} + +// SessionRequestHandler handles session-specific logic for a request +type SessionRequestHandler struct { + h *MCPHandler + userEmail string + config *config.MCPClientConfig + mcpServer *mcpserver.MCPServer // The shared MCP server +} + +// NewSessionRequestHandler creates a new session request handler with all dependencies +func NewSessionRequestHandler(h *MCPHandler, userEmail string, config *config.MCPClientConfig, mcpServer *mcpserver.MCPServer) *SessionRequestHandler { + return &SessionRequestHandler{ + h: h, + userEmail: userEmail, + config: config, + mcpServer: mcpServer, + } +} + +// GetUserEmail returns the user email for this session +func (s *SessionRequestHandler) GetUserEmail() string { + return s.userEmail +} + +// GetServerName returns the server name for this session +func (s *SessionRequestHandler) GetServerName() string { + return s.h.serverName +} + +// GetStorage returns the storage interface +func (s *SessionRequestHandler) GetStorage() storage.Storage { + return s.h.storage +} + +// HandleSessionRegistration handles the registration of a new session +func HandleSessionRegistration( + sessionCtx context.Context, + session mcpserver.ClientSession, + handler *SessionRequestHandler, + sessionManager *client.StdioSessionManager, +) { + // Create stdio process for this session + key := client.SessionKey{ + UserEmail: handler.userEmail, + ServerName: handler.h.serverName, + SessionID: session.SessionID(), + } + + log.LogDebugWithFields("server", "Registering session", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) + + log.LogTraceWithFields("server", "Session registration started", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + "requiresUserToken": handler.config.RequiresUserToken, + "transportType": handler.config.TransportType, + "command": handler.config.Command, + }) + + var userToken string + if handler.config.RequiresUserToken && handler.userEmail != "" && handler.h.storage != nil { + storedToken, err := handler.h.storage.GetUserToken(sessionCtx, handler.userEmail, handler.h.serverName) + if err != nil { + log.LogDebugWithFields("server", "No user token found", map[string]any{ + "server": handler.h.serverName, + "user": handler.userEmail, + }) + } else if storedToken != nil { + if handler.config.UserAuthentication != nil { + userToken = formatUserToken(storedToken, handler.config.UserAuthentication) + } else { + userToken = storedToken.Value + } + } + } + + stdioSession, err := sessionManager.GetOrCreateSession( + sessionCtx, + key, + handler.config, + handler.h.info, + handler.h.setupBaseURL, + userToken, + ) + if err != nil { + log.LogErrorWithFields("server", "Failed to create stdio session", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) + return + } + + // Discover and register capabilities from the stdio process + if err := stdioSession.DiscoverAndRegisterCapabilities( + sessionCtx, + handler.mcpServer, + handler.userEmail, + handler.config.RequiresUserToken, + handler.h.storage, + handler.h.serverName, + handler.h.setupBaseURL, + handler.config.UserAuthentication, + session, + ); err != nil { + log.LogErrorWithFields("server", "Failed to discover and register capabilities", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) + sessionManager.RemoveSession(key) + return + } + + if handler.userEmail != "" { + if handler.h.storage != nil { + activeSession := storage.ActiveSession{ + SessionID: session.SessionID(), + UserEmail: handler.userEmail, + ServerName: handler.h.serverName, + Created: time.Now(), + LastActive: time.Now(), + } + if err := handler.h.storage.TrackSession(sessionCtx, activeSession); err != nil { + log.LogWarnWithFields("server", "Failed to track session", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "user": handler.userEmail, + }) + } + } + } + + log.LogInfoWithFields("server", "Session successfully created and connected", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) +} diff --git a/internal/server/mcp_handler.go b/internal/server/mcp_handler.go index 69d7231..029cb5d 100644 --- a/internal/server/mcp_handler.go +++ b/internal/server/mcp_handler.go @@ -5,16 +5,17 @@ import ( "context" "fmt" "io" + "maps" "net/http" "strings" - "github.com/dgellow/mcp-front/internal/auth" "github.com/dgellow/mcp-front/internal/client" "github.com/dgellow/mcp-front/internal/config" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/jsonrpc" "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" + "github.com/dgellow/mcp-front/internal/servicecontext" "github.com/dgellow/mcp-front/internal/storage" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -23,11 +24,14 @@ import ( // SessionManager defines the interface for managing stdio sessions type SessionManager interface { GetSession(key client.SessionKey) (*client.StdioSession, bool) - GetOrCreateSession(ctx context.Context, key client.SessionKey, config *config.MCPClientConfig, info mcp.Implementation, setupBaseURL string) (*client.StdioSession, error) + GetOrCreateSession(ctx context.Context, key client.SessionKey, config *config.MCPClientConfig, info mcp.Implementation, setupBaseURL string, userToken string) (*client.StdioSession, error) RemoveSession(key client.SessionKey) Shutdown() } +// UserTokenFunc defines a function that retrieves a formatted user token for a service +type UserTokenFunc func(ctx context.Context, userEmail, serviceName string, serviceConfig *config.MCPClientConfig) (string, error) + // MCPHandler handles MCP requests with session management for stdio servers type MCPHandler struct { serverName string @@ -37,6 +41,8 @@ type MCPHandler struct { info mcp.Implementation sessionManager SessionManager sharedSSEServer *server.SSEServer // Shared SSE server for stdio servers + sharedMCPServer *server.MCPServer // Shared MCP server for stdio servers + getUserToken UserTokenFunc // Function to get formatted user tokens } // NewMCPHandler creates a new MCP handler with session management @@ -48,6 +54,8 @@ func NewMCPHandler( info mcp.Implementation, sessionManager SessionManager, sharedSSEServer *server.SSEServer, // Shared SSE server for stdio servers + sharedMCPServer *server.MCPServer, // Shared MCP server for stdio servers + getUserToken UserTokenFunc, ) *MCPHandler { return &MCPHandler{ serverName: serverName, @@ -57,6 +65,8 @@ func NewMCPHandler( info: info, sessionManager: sessionManager, sharedSSEServer: sharedSSEServer, + sharedMCPServer: sharedMCPServer, + getUserToken: getUserToken, } } @@ -67,7 +77,7 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { userEmail, _ := oauth.GetUserFromContext(ctx) if userEmail == "" { // Check for basic auth username - username, _ := auth.GetUser(ctx) + username, _ := servicecontext.GetUser(ctx) userEmail = username } @@ -85,7 +95,8 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if serverConfig.TransportType == config.MCPClientTypeStreamable { - if r.Method == http.MethodPost { + switch r.Method { + case http.MethodPost: log.LogInfoWithFields("mcp", "Handling streamable POST request", map[string]any{ "path": r.URL.Path, "server": h.serverName, @@ -94,7 +105,7 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { "contentLength": r.ContentLength, }) h.handleStreamablePost(ctx, w, r, userEmail, serverConfig) - } else if r.Method == http.MethodGet { + case http.MethodGet: log.LogInfoWithFields("mcp", "Handling streamable GET request", map[string]any{ "path": r.URL.Path, "server": h.serverName, @@ -103,7 +114,7 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { "userAgent": r.UserAgent(), }) h.handleStreamableGet(ctx, w, r, userEmail, serverConfig) - } else { + default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } } else { @@ -176,14 +187,10 @@ func (h *MCPHandler) handleSSERequest(ctx context.Context, w http.ResponseWriter // that will be called when sessions are registered/unregistered // We need to set up our session-specific handlers // Create a custom hook handler for this specific request - sessionHandler := &sessionRequestHandler{ - h: h, - userEmail: userEmail, - config: config, - } + sessionHandler := NewSessionRequestHandler(h, userEmail, config, h.sharedMCPServer) // Store the handler in context so hooks can access it - ctx = context.WithValue(ctx, sessionHandlerKey{}, sessionHandler) + ctx = context.WithValue(ctx, SessionHandlerKey{}, sessionHandler) r = r.WithContext(ctx) log.LogInfoWithFields("mcp", "Serving SSE request for stdio server", map[string]any{ "server": h.serverName, @@ -269,15 +276,15 @@ func (h *MCPHandler) getUserTokenIfAvailable(ctx context.Context, userEmail stri return "", fmt.Errorf("authentication required") } - log.LogTraceWithFields("mcp_handler", "Attempting to resolve user token", map[string]interface{}{ + log.LogTraceWithFields("mcp_handler", "Attempting to resolve user token", map[string]any{ "server_name": h.serverName, "user": userEmail, }) // Check for service auth first - services provide their own user tokens - if serviceAuth, ok := auth.GetServiceAuth(ctx); ok { + if serviceAuth, ok := servicecontext.GetAuthInfo(ctx); ok { if serviceAuth.UserToken != "" { - log.LogTraceWithFields("mcp_handler", "Found user token in service auth context", map[string]interface{}{ + log.LogTraceWithFields("mcp_handler", "Found user token in service auth context", map[string]any{ "server_name": h.serverName, "user": userEmail, }) @@ -285,7 +292,7 @@ func (h *MCPHandler) getUserTokenIfAvailable(ctx context.Context, userEmail stri } } - log.LogTraceWithFields("mcp_handler", "No user token in service auth context, falling back to storage lookup", map[string]interface{}{ + log.LogTraceWithFields("mcp_handler", "No user token in service auth context, falling back to storage lookup", map[string]any{ "server_name": h.serverName, "user": userEmail, }) @@ -295,22 +302,28 @@ func (h *MCPHandler) getUserTokenIfAvailable(ctx context.Context, userEmail stri return "", fmt.Errorf("storage not configured") } - token, err := h.storage.GetUserToken(ctx, userEmail, h.serverName) + storedToken, err := h.storage.GetUserToken(ctx, userEmail, h.serverName) if err != nil { return "", err } - // Validate token format if configured - if h.serverConfig.TokenSetup != nil && h.serverConfig.TokenSetup.CompiledRegex != nil { - if !h.serverConfig.TokenSetup.CompiledRegex.MatchString(token) { - log.LogWarnWithFields("mcp", "User token doesn't match expected format", map[string]any{ - "user": userEmail, - "service": h.serverName, - }) + // Use injected function to get formatted token with refresh handling + if h.getUserToken != nil { + return h.getUserToken(ctx, userEmail, h.serverName, h.serverConfig) + } + + // Fallback: extract raw token without refresh (for backwards compatibility) + var tokenString string + switch storedToken.Type { + case storage.TokenTypeManual: + tokenString = storedToken.Value + case storage.TokenTypeOAuth: + if storedToken.OAuthData != nil { + tokenString = storedToken.OAuthData.AccessToken } } - return token, nil + return tokenString, nil } func (h *MCPHandler) forwardMessageToBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, config *config.MCPClientConfig) { @@ -375,9 +388,7 @@ func (h *MCPHandler) forwardMessageToBackend(ctx context.Context, w http.Respons w.WriteHeader(resp.StatusCode) - for k, v := range resp.Header { - w.Header()[k] = v - } + maps.Copy(w.Header(), resp.Header) if _, err := io.Copy(w, resp.Body); err != nil { log.LogErrorWithFields("mcp", "Failed to copy response body", map[string]any{ diff --git a/internal/server/mcp_handler_test.go b/internal/server/mcp_handler_test.go index 2c0a7a0..e057d9f 100644 --- a/internal/server/mcp_handler_test.go +++ b/internal/server/mcp_handler_test.go @@ -31,11 +31,13 @@ type mockStorage struct { mock.Mock } -// Override only the methods we want to mock -func (m *mockStorage) GetUserToken(ctx context.Context, userEmail, service string) (string, error) { +func (m *mockStorage) GetUserToken(ctx context.Context, userEmail, service string) (*storage.StoredToken, error) { if m.Mock.ExpectedCalls != nil { args := m.Called(ctx, userEmail, service) - return args.String(0), args.Error(1) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*storage.StoredToken), args.Error(1) } return m.MemoryStorage.GetUserToken(ctx, userEmail, service) } @@ -56,8 +58,8 @@ func (m *mockSessionManager) GetSession(key client.SessionKey) (*client.StdioSes return args.Get(0).(*client.StdioSession), args.Bool(1) } -func (m *mockSessionManager) GetOrCreateSession(ctx context.Context, key client.SessionKey, config *config.MCPClientConfig, info mcp.Implementation, setupBaseURL string) (*client.StdioSession, error) { - args := m.Called(ctx, key, config, info, setupBaseURL) +func (m *mockSessionManager) GetOrCreateSession(ctx context.Context, key client.SessionKey, config *config.MCPClientConfig, info mcp.Implementation, setupBaseURL string, userToken string) (*client.StdioSession, error) { + args := m.Called(ctx, key, config, info, setupBaseURL, userToken) if args.Get(0) == nil { return nil, args.Error(1) } @@ -73,7 +75,7 @@ func (m *mockSessionManager) Shutdown() { } // Test helper to create MCPHandler for SSE tests -func createTestMCPHandler(serverName string, config *config.MCPClientConfig) *MCPHandler { +func createTestMCPHandler(serverName string, serverConfig *config.MCPClientConfig) *MCPHandler { mockStore := &mockStorage{ MemoryStorage: storage.NewMemoryStorage(), } @@ -82,12 +84,18 @@ func createTestMCPHandler(serverName string, config *config.MCPClientConfig) *MC return NewMCPHandler( serverName, - config, + serverConfig, mockStore, "http://localhost:8080", info, sessionManager, nil, + func(ctx context.Context, userEmail, serviceName string, serviceConfig *config.MCPClientConfig) (string, error) { + if userEmail == "" { + return "", fmt.Errorf("no user") + } + return "test-token-for-" + serviceName, nil + }, ) } @@ -576,7 +584,7 @@ func TestHandleStreamableGet(t *testing.T) { defer backend.Close() // Configure client - config := &config.MCPClientConfig{ + serverConfig := &config.MCPClientConfig{ URL: backend.URL, TransportType: config.MCPClientTypeStreamable, Headers: map[string]string{ @@ -585,7 +593,7 @@ func TestHandleStreamableGet(t *testing.T) { Timeout: 5 * time.Second, } - handler := createTestMCPHandler("test-streamable", config) + handler := createTestMCPHandler("test-streamable", serverConfig) // Create request with Accept header req := httptest.NewRequest(http.MethodGet, "/test-streamable", nil) @@ -593,7 +601,7 @@ func TestHandleStreamableGet(t *testing.T) { rec := httptest.NewRecorder() // Call the function - handler.handleStreamableGet(context.Background(), rec, req, "user@example.com", config) + handler.handleStreamableGet(context.Background(), rec, req, "user@example.com", serverConfig) // Verify response assert.Equal(t, http.StatusOK, rec.Code) @@ -602,18 +610,18 @@ func TestHandleStreamableGet(t *testing.T) { }) t.Run("missing Accept header", func(t *testing.T) { - config := &config.MCPClientConfig{ + serverConfig := &config.MCPClientConfig{ URL: "http://example.com", TransportType: config.MCPClientTypeStreamable, } - handler := createTestMCPHandler("test-streamable", config) + handler := createTestMCPHandler("test-streamable", serverConfig) // Create request without Accept header req := httptest.NewRequest(http.MethodGet, "/test-streamable", nil) rec := httptest.NewRecorder() - handler.handleStreamableGet(context.Background(), rec, req, "user@example.com", config) + handler.handleStreamableGet(context.Background(), rec, req, "user@example.com", serverConfig) // Should return 406 Not Acceptable assert.Equal(t, http.StatusNotAcceptable, rec.Code) @@ -621,19 +629,19 @@ func TestHandleStreamableGet(t *testing.T) { }) t.Run("wrong Accept header", func(t *testing.T) { - config := &config.MCPClientConfig{ + serverConfig := &config.MCPClientConfig{ URL: "http://example.com", TransportType: config.MCPClientTypeStreamable, } - handler := createTestMCPHandler("test-streamable", config) + handler := createTestMCPHandler("test-streamable", serverConfig) // Create request with wrong Accept header req := httptest.NewRequest(http.MethodGet, "/test-streamable", nil) req.Header.Set("Accept", "application/json") rec := httptest.NewRecorder() - handler.handleStreamableGet(context.Background(), rec, req, "user@example.com", config) + handler.handleStreamableGet(context.Background(), rec, req, "user@example.com", serverConfig) // Should return 406 Not Acceptable assert.Equal(t, http.StatusNotAcceptable, rec.Code) @@ -650,12 +658,12 @@ func TestStreamableTransportRouting(t *testing.T) { })) defer backend.Close() - config := &config.MCPClientConfig{ + serverConfig := &config.MCPClientConfig{ URL: backend.URL, TransportType: config.MCPClientTypeStreamable, } - handler := createTestMCPHandler("test-streamable", config) + handler := createTestMCPHandler("test-streamable", serverConfig) req := httptest.NewRequest(http.MethodPost, "/test-streamable", bytes.NewReader([]byte("{}"))) req.Header.Set("Content-Type", "application/json") @@ -675,12 +683,12 @@ func TestStreamableTransportRouting(t *testing.T) { })) defer backend.Close() - config := &config.MCPClientConfig{ + serverConfig := &config.MCPClientConfig{ URL: backend.URL, TransportType: config.MCPClientTypeStreamable, } - handler := createTestMCPHandler("test-streamable", config) + handler := createTestMCPHandler("test-streamable", serverConfig) req := httptest.NewRequest(http.MethodGet, "/test-streamable", nil) req.Header.Set("Accept", "text/event-stream") @@ -693,12 +701,12 @@ func TestStreamableTransportRouting(t *testing.T) { }) t.Run("unsupported method returns 405", func(t *testing.T) { - config := &config.MCPClientConfig{ + serverConfig := &config.MCPClientConfig{ URL: "http://example.com", TransportType: config.MCPClientTypeStreamable, } - handler := createTestMCPHandler("test-streamable", config) + handler := createTestMCPHandler("test-streamable", serverConfig) req := httptest.NewRequest(http.MethodPut, "/test-streamable", nil) rec := httptest.NewRecorder() diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 839e707..4729a46 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -2,15 +2,22 @@ package server import ( "encoding/base64" + "encoding/json" "net/http" + "slices" "strings" "time" - "github.com/dgellow/mcp-front/internal/auth" + "github.com/dgellow/mcp-front/internal/adminauth" + "github.com/dgellow/mcp-front/internal/browserauth" "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/cookie" + "github.com/dgellow/mcp-front/internal/crypto" + "github.com/dgellow/mcp-front/internal/googleauth" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" + "github.com/dgellow/mcp-front/internal/servicecontext" "github.com/dgellow/mcp-front/internal/storage" "golang.org/x/crypto/bcrypt" ) @@ -18,16 +25,16 @@ import ( // MiddlewareFunc is a function that wraps an http.Handler type MiddlewareFunc func(http.Handler) http.Handler -// chainMiddleware chains multiple middleware functions -func chainMiddleware(h http.Handler, middlewares ...MiddlewareFunc) http.Handler { +// ChainMiddleware chains multiple middleware functions +func ChainMiddleware(h http.Handler, middlewares ...MiddlewareFunc) http.Handler { for _, mw := range middlewares { h = mw(h) } return h } -// corsMiddleware adds CORS headers to responses -func corsMiddleware(allowedOrigins []string) MiddlewareFunc { +// NewCORSMiddleware adds CORS headers to responses +func NewCORSMiddleware(allowedOrigins []string) MiddlewareFunc { // Build a map for faster lookup allowedMap := make(map[string]bool) for _, origin := range allowedOrigins { @@ -124,7 +131,7 @@ var _ http.ResponseWriter = (*responseWriterDelegator)(nil) var _ http.Flusher = (*responseWriterDelegator)(nil) // loggerMiddleware adds request/response logging -func loggerMiddleware(prefix string) MiddlewareFunc { +func NewLoggerMiddleware(prefix string) MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() @@ -133,7 +140,7 @@ func loggerMiddleware(prefix string) MiddlewareFunc { next.ServeHTTP(wrapped, r) // Log request with response details - fields := map[string]interface{}{ + fields := map[string]any{ "method": r.Method, "path": r.URL.Path, "status": wrapped.Status(), @@ -152,8 +159,8 @@ func loggerMiddleware(prefix string) MiddlewareFunc { } } -// recoverMiddleware recovers from panics -func recoverMiddleware(prefix string) MiddlewareFunc { +// NewRecoverMiddleware recovers from panics +func NewRecoverMiddleware(prefix string) MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { @@ -175,7 +182,7 @@ func newServiceAuthMiddleware(serviceAuths []config.ServiceAuth) MiddlewareFunc // Check if user context is already set — OAuth succeeded, no need for further auth if userEmail, ok := oauth.GetUserFromContext(ctx); ok && userEmail != "" { - log.LogTraceWithFields("service_auth", "Skipping service auth, user already authenticated via OAuth", map[string]interface{}{ + log.LogTraceWithFields("service_auth", "Skipping service auth, user already authenticated via OAuth", map[string]any{ "user": userEmail, }) next.ServeHTTP(w, r) @@ -197,16 +204,14 @@ func newServiceAuthMiddleware(serviceAuths []config.ServiceAuth) MiddlewareFunc continue } - for _, allowedToken := range serviceAuth.Tokens { - if token == allowedToken { - // Auth succeeded - log.LogTraceWithFields("service_auth", "Bearer token service auth successful", map[string]interface{}{ - "service_name": "service", - }) - ctx := auth.WithServiceAuth(r.Context(), "service", serviceAuth.ResolvedUserToken) - next.ServeHTTP(w, r.WithContext(ctx)) - return - } + if slices.Contains(serviceAuth.Tokens, token) { + // Auth succeeded + log.LogTraceWithFields("service_auth", "Bearer token service auth successful", map[string]any{ + "service_name": "service", + }) + ctx := servicecontext.WithAuthInfo(r.Context(), "service", string(serviceAuth.UserToken)) + next.ServeHTTP(w, r.WithContext(ctx)) + return } } log.LogTraceWithFields("service_auth", "Bearer token service auth failed: invalid token", nil) @@ -217,7 +222,7 @@ func newServiceAuthMiddleware(serviceAuths []config.ServiceAuth) MiddlewareFunc log.LogTraceWithFields("service_auth", "Attempting basic service auth", nil) decoded, err := base64.StdEncoding.DecodeString(encoded) if err != nil { - log.LogTraceWithFields("service_auth", "Basic service auth failed: invalid base64 encoding", map[string]interface{}{ + log.LogTraceWithFields("service_auth", "Basic service auth failed: invalid base64 encoding", map[string]any{ "error": err.Error(), }) w.Header().Set("WWW-Authenticate", `Basic realm="mcp-front"`) @@ -243,12 +248,12 @@ func newServiceAuthMiddleware(serviceAuths []config.ServiceAuth) MiddlewareFunc } if username == serviceAuth.Username { - if err := bcrypt.CompareHashAndPassword([]byte(serviceAuth.HashedPassword), []byte(password)); err == nil { + if err := bcrypt.CompareHashAndPassword([]byte(string(serviceAuth.HashedPassword)), []byte(password)); err == nil { // Auth succeeded - log.LogTraceWithFields("service_auth", "Basic service auth successful", map[string]interface{}{ + log.LogTraceWithFields("service_auth", "Basic service auth successful", map[string]any{ "username": username, }) - ctx := auth.WithServiceAuth(r.Context(), serviceAuth.Username, serviceAuth.ResolvedUserToken) + ctx := servicecontext.WithAuthInfo(r.Context(), serviceAuth.Username, string(serviceAuth.UserToken)) next.ServeHTTP(w, r.WithContext(ctx)) return } @@ -272,7 +277,7 @@ func adminMiddleware(adminConfig *config.AdminConfig, store storage.Storage) Mid return } - if !auth.IsAdmin(r.Context(), userEmail, adminConfig, store) { + if !adminauth.IsAdmin(r.Context(), userEmail, adminConfig, store) { jsonwriter.WriteForbidden(w, "Forbidden - Admin access required") return } @@ -281,3 +286,90 @@ func adminMiddleware(adminConfig *config.AdminConfig, store storage.Storage) Mid }) } } + +// NewBrowserSSOMiddleware creates middleware for browser-based SSO authentication +func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor crypto.Encryptor, browserStateToken *crypto.TokenSigner) MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check for session cookie + sessionValue, err := cookie.GetSession(r) + if err != nil { + // No cookie, redirect directly to OAuth + state := generateBrowserState(browserStateToken, r.URL.String()) + if state == "" { + jsonwriter.WriteInternalServerError(w, "Failed to generate authentication state") + return + } + googleURL := googleauth.GoogleAuthURL(authConfig, state) + http.Redirect(w, r, googleURL, http.StatusFound) + return + } + + // Decrypt cookie + decrypted, err := sessionEncryptor.Decrypt(sessionValue) + if err != nil { + // Invalid cookie, redirect to OAuth + log.LogDebug("Invalid session cookie: %v", err) + cookie.ClearSession(w) // Clear bad cookie + state := generateBrowserState(browserStateToken, r.URL.String()) + googleURL := googleauth.GoogleAuthURL(authConfig, state) + http.Redirect(w, r, googleURL, http.StatusFound) + return + } + + // Parse session data + var sessionData browserauth.SessionCookie + if err := json.NewDecoder(strings.NewReader(decrypted)).Decode(&sessionData); err != nil { + // Invalid format + cookie.ClearSession(w) + jsonwriter.WriteUnauthorized(w, "Invalid session") + return + } + + // Check expiration + if time.Now().After(sessionData.Expires) { + log.LogDebug("Session expired for user %s", sessionData.Email) + cookie.ClearSession(w) + // Redirect directly to Google OAuth + state := generateBrowserState(browserStateToken, r.URL.String()) + if state == "" { + jsonwriter.WriteInternalServerError(w, "Failed to generate authentication state") + return + } + googleURL := googleauth.GoogleAuthURL(authConfig, state) + http.Redirect(w, r, googleURL, http.StatusFound) + return + } + + // Valid session, set user in context + ctx := servicecontext.WithUser(r.Context(), sessionData.Email) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// generateBrowserState creates a secure state parameter for browser SSO +func generateBrowserState(browserStateToken *crypto.TokenSigner, returnURL string) string { + state := browserauth.AuthorizationState{ + Nonce: crypto.GenerateSecureToken(), + ReturnURL: returnURL, + } + + token, err := browserStateToken.Sign(state) + if err != nil { + log.LogError("Failed to sign browser state: %v", err) + // Return empty string to trigger auth failure - middleware will handle it + return "" + } + return "browser:" + token +} + +// NewServiceAuthMiddleware creates middleware for service-to-service authentication +func NewServiceAuthMiddleware(serviceAuths []config.ServiceAuth) MiddlewareFunc { + return newServiceAuthMiddleware(serviceAuths) +} + +// NewAdminMiddleware creates middleware for admin access control +func NewAdminMiddleware(adminConfig *config.AdminConfig, store storage.Storage) MiddlewareFunc { + return adminMiddleware(adminConfig, store) +} diff --git a/internal/server/middleware_test.go b/internal/server/middleware_test.go index 6cad420..811107b 100644 --- a/internal/server/middleware_test.go +++ b/internal/server/middleware_test.go @@ -8,6 +8,7 @@ import ( "github.com/dgellow/mcp-front/internal/auth" "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/servicecontext" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" @@ -179,7 +180,7 @@ func TestServiceAuthMiddleware(t *testing.T) { { Type: config.ServiceAuthTypeBasic, Username: "user", - HashedPassword: string(hashedPassword), + HashedPassword: config.Secret(hashedPassword), }, } @@ -268,15 +269,15 @@ func TestServiceAuthMiddleware_Context(t *testing.T) { serviceAuths := []config.ServiceAuth{ { - Type: config.ServiceAuthTypeBearer, - Tokens: []string{"valid-token"}, - ResolvedUserToken: "bearer-user-token", + Type: config.ServiceAuthTypeBearer, + Tokens: []string{"valid-token"}, + UserToken: config.Secret("bearer-user-token"), }, { - Type: config.ServiceAuthTypeBasic, - Username: "user", - HashedPassword: string(hashedPassword), - ResolvedUserToken: "basic-user-token", + Type: config.ServiceAuthTypeBasic, + Username: "user", + HashedPassword: config.Secret(hashedPassword), + UserToken: config.Secret("basic-user-token"), }, } @@ -285,14 +286,14 @@ func TestServiceAuthMiddleware_Context(t *testing.T) { authHeader string expectStatus int expectServiceAuth bool - expectAuthInfo auth.ServiceAuthInfo + expectAuthInfo servicecontext.Info }{ { name: "bearer token sets context", authHeader: "Bearer valid-token", expectStatus: http.StatusOK, expectServiceAuth: true, - expectAuthInfo: auth.ServiceAuthInfo{ + expectAuthInfo: servicecontext.Info{ ServiceName: "service", UserToken: "bearer-user-token", }, @@ -302,7 +303,7 @@ func TestServiceAuthMiddleware_Context(t *testing.T) { authHeader: "Basic " + base64.StdEncoding.EncodeToString([]byte("user:password123")), expectStatus: http.StatusOK, expectServiceAuth: true, - expectAuthInfo: auth.ServiceAuthInfo{ + expectAuthInfo: servicecontext.Info{ ServiceName: "user", UserToken: "basic-user-token", }, @@ -317,11 +318,11 @@ func TestServiceAuthMiddleware_Context(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var actualAuthInfo auth.ServiceAuthInfo + var actualAuthInfo servicecontext.Info var hasAuthInfo bool handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - actualAuthInfo, hasAuthInfo = auth.GetServiceAuth(r.Context()) + actualAuthInfo, hasAuthInfo = servicecontext.GetAuthInfo(r.Context()) w.WriteHeader(http.StatusOK) }) diff --git a/internal/server/server.go b/internal/server/server.go deleted file mode 100644 index 96e1338..0000000 --- a/internal/server/server.go +++ /dev/null @@ -1,99 +0,0 @@ -package server - -import ( - "context" - "errors" - "fmt" - "net/http" - "os" - "os/signal" - "syscall" - "time" - - "github.com/dgellow/mcp-front/internal/config" - "github.com/dgellow/mcp-front/internal/log" -) - -// Run starts and runs the MCP proxy server -func Run(cfg *config.Config) error { - log.LogInfoWithFields("server", "Starting MCP proxy server", map[string]interface{}{ - "addr": cfg.Proxy.Addr, - "baseURL": cfg.Proxy.BaseURL, - "mcpServers": len(cfg.MCPServers), - }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Create the server handler - handler, err := NewServer(ctx, cfg) - if err != nil { - return fmt.Errorf("failed to create server: %w", err) - } - - httpServer := &http.Server{ - Addr: cfg.Proxy.Addr, - Handler: handler, - } - - // Channel to signal errors that should trigger shutdown - errChan := make(chan error, 1) - - // Start HTTP server - go func() { - log.LogInfoWithFields("server", "HTTP server starting", map[string]interface{}{ - "addr": cfg.Proxy.Addr, - }) - if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- fmt.Errorf("server error: %w", err) - } - }() - - // Handle graceful shutdown - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - var shutdownReason string - select { - case sig := <-sigChan: - shutdownReason = fmt.Sprintf("signal %v", sig) - log.LogInfoWithFields("server", "Received shutdown signal", map[string]interface{}{ - "signal": sig.String(), - }) - case err := <-errChan: - shutdownReason = fmt.Sprintf("error: %v", err) - log.LogErrorWithFields("server", "Shutting down due to error", map[string]interface{}{ - "error": err.Error(), - }) - case <-ctx.Done(): - shutdownReason = "context cancelled" - log.LogInfoWithFields("server", "Context cancelled, shutting down", nil) - } - - // Graceful shutdown - log.LogInfoWithFields("server", "Starting graceful shutdown", map[string]interface{}{ - "reason": shutdownReason, - "timeout": "30s", - }) - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) - defer shutdownCancel() - - if err := httpServer.Shutdown(shutdownCtx); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.LogErrorWithFields("server", "HTTP server shutdown error", map[string]interface{}{ - "error": err.Error(), - }) - return err - } - - // Shutdown the handler (which includes session manager) - if err := handler.Shutdown(); err != nil { - log.LogErrorWithFields("server", "Handler shutdown error", map[string]interface{}{ - "error": err.Error(), - }) - } - - log.LogInfoWithFields("server", "Server shutdown complete", map[string]interface{}{ - "reason": shutdownReason, - }) - return nil -} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index f30c96d..f221cf3 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/dgellow/mcp-front/internal/oauth" + "github.com/dgellow/mcp-front/internal/auth" "github.com/dgellow/mcp-front/internal/storage" ) @@ -48,7 +48,7 @@ func TestHealthEndpoint(t *testing.T) { } func TestOAuthEndpointsCORS(t *testing.T) { - oauthConfig := oauth.Config{ + authConfig := auth.Config{ Issuer: "https://test.example.com", TokenTTL: time.Hour, AllowedDomains: []string{"example.com"}, @@ -61,7 +61,7 @@ func TestOAuthEndpointsCORS(t *testing.T) { } store := storage.NewMemoryStorage() - server, err := oauth.NewServer(oauthConfig, store) + authServer, err := auth.NewServer(authConfig, store) if err != nil { t.Fatalf("Failed to create OAuth server: %v", err) } @@ -105,9 +105,11 @@ func TestOAuthEndpointsCORS(t *testing.T) { switch endpoint.path { case "/.well-known/oauth-authorization-server": - handler = corsHandler(http.HandlerFunc(server.WellKnownHandler)) + authHandlers := NewAuthHandlers(authServer, nil, nil) + handler = corsHandler(http.HandlerFunc(authHandlers.WellKnownHandler)) case "/register": - handler = corsHandler(http.HandlerFunc(server.RegisterHandler)) + authHandlers := NewAuthHandlers(authServer, nil, nil) + handler = corsHandler(http.HandlerFunc(authHandlers.RegisterHandler)) default: t.Fatalf("Unknown endpoint: %s", endpoint.path) } diff --git a/internal/server/service_auth_handlers.go b/internal/server/service_auth_handlers.go new file mode 100644 index 0000000..bab07b2 --- /dev/null +++ b/internal/server/service_auth_handlers.go @@ -0,0 +1,255 @@ +package server + +import ( + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/dgellow/mcp-front/internal/auth" + "github.com/dgellow/mcp-front/internal/config" + jsonwriter "github.com/dgellow/mcp-front/internal/json" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/oauth" + "github.com/dgellow/mcp-front/internal/storage" +) + +// ServiceAuthHandlers handles OAuth flows for external services +type ServiceAuthHandlers struct { + oauthClient *auth.ServiceOAuthClient + mcpServers map[string]*config.MCPClientConfig + storage storage.Storage +} + +// NewServiceAuthHandlers creates new service auth handlers +func NewServiceAuthHandlers(oauthClient *auth.ServiceOAuthClient, mcpServers map[string]*config.MCPClientConfig, storage storage.Storage) *ServiceAuthHandlers { + return &ServiceAuthHandlers{ + oauthClient: oauthClient, + mcpServers: mcpServers, + storage: storage, + } +} + +// ConnectHandler initiates OAuth flow for a service +func (h *ServiceAuthHandlers) ConnectHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + jsonwriter.WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed") + return + } + + // Get authenticated user + userEmail, ok := oauth.GetUserFromContext(r.Context()) + if !ok { + jsonwriter.WriteUnauthorized(w, "Authentication required") + return + } + + // Get service name from query + serviceName := r.URL.Query().Get("service") + if serviceName == "" { + jsonwriter.WriteBadRequest(w, "Service name is required") + return + } + + // Validate service exists and supports OAuth + serviceConfig, exists := h.mcpServers[serviceName] + if !exists { + jsonwriter.WriteNotFound(w, "Service not found") + return + } + + if !serviceConfig.RequiresUserToken || + serviceConfig.UserAuthentication == nil || + serviceConfig.UserAuthentication.Type != config.UserAuthTypeOAuth { + jsonwriter.WriteBadRequest(w, "Service does not support OAuth") + return + } + + // Start OAuth flow - service OAuth always returns to interstitial page + authURL, err := h.oauthClient.StartOAuthFlow( + r.Context(), + userEmail, + serviceName, + serviceConfig, + ) + if err != nil { + log.LogErrorWithFields("oauth_handlers", "Failed to start OAuth flow", map[string]any{ + "service": serviceName, + "user": userEmail, + "error": err.Error(), + }) + jsonwriter.WriteInternalServerError(w, "Failed to start OAuth flow") + return + } + + // Redirect to authorization URL + http.Redirect(w, r, authURL, http.StatusFound) +} + +// CallbackHandler handles OAuth callbacks from services +func (h *ServiceAuthHandlers) CallbackHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + jsonwriter.WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed") + return + } + + // Extract service name from path: /oauth/callback/{service} + pathParts := strings.Split(strings.TrimPrefix(r.URL.Path, "/"), "/") + if len(pathParts) < 3 { + jsonwriter.WriteBadRequest(w, "Invalid callback path") + return + } + serviceName := pathParts[2] + if serviceName == "" { + jsonwriter.WriteBadRequest(w, "Service name is required") + return + } + + // Get authorization code and state + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errorParam := r.URL.Query().Get("error") + + // Handle OAuth errors + if errorParam != "" { + errorDesc := r.URL.Query().Get("error_description") + log.LogWarnWithFields("oauth_handlers", "OAuth error from provider", map[string]any{ + "service": serviceName, + "error": errorParam, + "description": errorDesc, + }) + + // Service OAuth errors always redirect back to interstitial page + // This maintains user context in the upstream OAuth flow + errorMsg := fmt.Sprintf("OAuth authorization failed: %s", errorParam) + if errorDesc != "" { + errorMsg = fmt.Sprintf("%s - %s", errorMsg, errorDesc) + } + + // Redirect to interstitial page with error + errorURL := fmt.Sprintf("/oauth/services?error=%s&service=%s&error_msg=%s", + url.QueryEscape(errorParam), + url.QueryEscape(serviceName), + url.QueryEscape(errorMsg), + ) + http.Redirect(w, r, errorURL, http.StatusFound) + return + } + + if code == "" || state == "" { + jsonwriter.WriteBadRequest(w, "Missing code or state parameter") + return + } + + // Validate service configuration + serviceConfig, exists := h.mcpServers[serviceName] + if !exists { + jsonwriter.WriteNotFound(w, "Service not found") + return + } + + // Handle callback + userEmail, err := h.oauthClient.HandleCallback( + r.Context(), + serviceName, + code, + state, + serviceConfig, + ) + if err != nil { + log.LogErrorWithFields("oauth_handlers", "Failed to handle OAuth callback", map[string]any{ + "service": serviceName, + "error": err.Error(), + }) + + // User-friendly error message + message := "Failed to complete OAuth authorization" + if strings.Contains(err.Error(), "invalid state") { + message = "OAuth session expired. Please try again" + } + + // Service OAuth callback errors always redirect back to interstitial page + // This maintains user context in the upstream OAuth flow + errorURL := fmt.Sprintf("/oauth/services?error=callback_failed&service=%s&error_msg=%s", + url.QueryEscape(serviceName), + url.QueryEscape(message), + ) + http.Redirect(w, r, errorURL, http.StatusFound) + return + } + + // Log successful connection + log.LogInfoWithFields("oauth_handlers", "OAuth connection successful", map[string]any{ + "service": serviceName, + "user": userEmail, + }) + + // Display name for success message + displayName := serviceName + if serviceConfig.UserAuthentication != nil && serviceConfig.UserAuthentication.DisplayName != "" { + displayName = serviceConfig.UserAuthentication.DisplayName + } + + // Service OAuth success always redirects back to interstitial page + // This maintains user context in the upstream OAuth flow + successURL := fmt.Sprintf("/oauth/services?message=%s&type=success", + strings.ReplaceAll(fmt.Sprintf("Successfully connected to %s", displayName), " ", "+"), + ) + http.Redirect(w, r, successURL, http.StatusFound) +} + +// DisconnectHandler revokes OAuth access for a service +func (h *ServiceAuthHandlers) DisconnectHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + jsonwriter.WriteError(w, http.StatusMethodNotAllowed, "method_not_allowed", "Method not allowed") + return + } + + // Get authenticated user + userEmail, ok := oauth.GetUserFromContext(r.Context()) + if !ok { + jsonwriter.WriteUnauthorized(w, "Authentication required") + return + } + + // Parse form + if err := r.ParseForm(); err != nil { + jsonwriter.WriteBadRequest(w, "Bad request") + return + } + + serviceName := r.FormValue("service") + if serviceName == "" { + jsonwriter.WriteBadRequest(w, "Service name is required") + return + } + + // Note: We don't validate CSRF for disconnect as it's less critical + // and the user is already authenticated + + // Delete the token + if err := h.storage.DeleteUserToken(r.Context(), userEmail, serviceName); err != nil { + log.LogErrorWithFields("oauth_handlers", "Failed to delete OAuth token", map[string]any{ + "service": serviceName, + "user": userEmail, + "error": err.Error(), + }) + jsonwriter.WriteInternalServerError(w, "Failed to disconnect") + return + } + + log.LogInfoWithFields("oauth_handlers", "OAuth disconnection successful", map[string]any{ + "service": serviceName, + "user": userEmail, + }) + + // Get display name + displayName := serviceName + if serviceConfig, exists := h.mcpServers[serviceName]; exists { + if serviceConfig.UserAuthentication != nil && serviceConfig.UserAuthentication.DisplayName != "" { + displayName = serviceConfig.UserAuthentication.DisplayName + } + } + + redirectWithMessage(w, r, fmt.Sprintf("Disconnected from %s", displayName), "success") +} diff --git a/internal/server/streamable_proxy.go b/internal/server/streamable_proxy.go index a467ac9..69a38b5 100644 --- a/internal/server/streamable_proxy.go +++ b/internal/server/streamable_proxy.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "maps" "net/http" "strings" @@ -106,9 +107,7 @@ func forwardStreamablePostToBackend(ctx context.Context, w http.ResponseWriter, } else { w.WriteHeader(resp.StatusCode) - for k, v := range resp.Header { - w.Header()[k] = v - } + maps.Copy(w.Header(), resp.Header) if _, err := io.Copy(w, resp.Body); err != nil { log.LogErrorWithFields("streamable_proxy", "Failed to copy response body", map[string]any{ diff --git a/internal/server/templates.go b/internal/server/templates.go index 536372f..40462d4 100644 --- a/internal/server/templates.go +++ b/internal/server/templates.go @@ -11,8 +11,12 @@ var tokenPageTemplateHTML string //go:embed templates/admin.html var adminPageTemplateHTML string +//go:embed templates/services.html +var servicesPageTemplateHTML string + var tokenPageTemplate = template.Must(template.New("tokens").Parse(tokenPageTemplateHTML)) var adminPageTemplate = template.Must(template.New("admin").Parse(adminPageTemplateHTML)) +var servicesPageTemplate = template.Must(template.New("services").Parse(servicesPageTemplateHTML)) // TokenPageData represents the data for the token management page type TokenPageData struct { @@ -25,12 +29,31 @@ type TokenPageData struct { // ServiceTokenData represents a single service in the token page type ServiceTokenData struct { - Name string - DisplayName string - Instructions string - HelpURL string - TokenFormat string - HasToken bool - RequiresToken bool - AuthType string // "oauth", "bearer", or "none" + Name string + DisplayName string + Instructions string + HelpURL string + TokenFormat string + HasToken bool + RequiresToken bool + AuthType string // "oauth", "bearer", or "none" + SupportsOAuth bool // Whether this service supports OAuth authentication + IsOAuthConnected bool // Whether the user has connected OAuth for this service + ConnectURL string // Pre-generated OAuth connect URL +} + +// ServicesPageData represents the data for the service selection page +type ServicesPageData struct { + Services []ServiceSelectionData + State string + ReturnURL string +} + +// ServiceSelectionData represents a single service in the selection page +type ServiceSelectionData struct { + Name string + DisplayName string + Status string // "not_connected", "connected", "error" + ErrorMsg string + ConnectURL string // Pre-generated OAuth connect URL } diff --git a/internal/server/templates/services.html b/internal/server/templates/services.html new file mode 100644 index 0000000..b30de81 --- /dev/null +++ b/internal/server/templates/services.html @@ -0,0 +1,141 @@ + + + + + + Connect Services - MCP Front + + + +
+

Optional Service Connections

+

Some MCP servers require additional authentication. You can connect them now or later.

+ +
+ {{range .Services}} +
+ {{if eq .Status "connected"}} + {{.DisplayName}} + ✓ Connected + {{else if eq .Status "error"}} +
+ {{.DisplayName}} + {{if .ErrorMsg}}
{{.ErrorMsg}}
{{end}} +
+ Try Again + {{else}} + {{.DisplayName}} + Connect + {{end}} +
+ {{end}} +
+ +
+

+ You can connect to additional services later by visiting your + token management page. +

+ Continue +
+
+ + \ No newline at end of file diff --git a/internal/server/templates/tokens.html b/internal/server/templates/tokens.html index c670883..0626d44 100644 --- a/internal/server/templates/tokens.html +++ b/internal/server/templates/tokens.html @@ -143,6 +143,30 @@ background-color: #c82333; } + button.oauth { + background-color: #4285f4; + color: white; + } + + button.oauth:hover { + background-color: #357ae8; + } + + .oauth-status { + display: flex; + align-items: center; + gap: 10px; + margin-top: 15px; + } + + .oauth-connected { + color: #2e7d32; + font-size: 14px; + display: flex; + align-items: center; + gap: 5px; + } + .message { padding: 12px; margin-bottom: 20px; @@ -206,25 +230,43 @@

{{.DisplayName}}

{{end}} -
- - - - -
- - {{if .HasToken}} -
- - - -
+ {{if .SupportsOAuth}} + {{if .IsOAuthConnected}} +
+ ✓ Connected via OAuth +
+ + +
+
+ {{else}} + + {{end}} + {{else}} +
+ + + + +
+ + {{if .HasToken}} +
+ + + +
+ {{end}} {{end}} {{else}}
diff --git a/internal/server/token_handlers.go b/internal/server/token_handlers.go index a7e5363..b2cbdaa 100644 --- a/internal/server/token_handlers.go +++ b/internal/server/token_handlers.go @@ -5,7 +5,9 @@ import ( "net/http" "strings" "sync" + "time" + "github.com/dgellow/mcp-front/internal/auth" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" jsonwriter "github.com/dgellow/mcp-front/internal/json" @@ -16,18 +18,20 @@ import ( // TokenHandlers handles the web UI for token management type TokenHandlers struct { - tokenStore storage.UserTokenStore - mcpServers map[string]*config.MCPClientConfig - csrfTokens sync.Map // Thread-safe CSRF token storage - oauthEnabled bool + tokenStore storage.UserTokenStore + mcpServers map[string]*config.MCPClientConfig + csrfTokens sync.Map // Thread-safe CSRF token storage + oauthEnabled bool + serviceOAuthClient *auth.ServiceOAuthClient } // NewTokenHandlers creates a new token handlers instance -func NewTokenHandlers(tokenStore storage.UserTokenStore, mcpServers map[string]*config.MCPClientConfig, oauthEnabled bool) *TokenHandlers { +func NewTokenHandlers(tokenStore storage.UserTokenStore, mcpServers map[string]*config.MCPClientConfig, oauthEnabled bool, serviceOAuthClient *auth.ServiceOAuthClient) *TokenHandlers { return &TokenHandlers{ - tokenStore: tokenStore, - mcpServers: mcpServers, - oauthEnabled: oauthEnabled, + tokenStore: tokenStore, + mcpServers: mcpServers, + oauthEnabled: oauthEnabled, + serviceOAuthClient: serviceOAuthClient, } } @@ -70,35 +74,59 @@ func (h *TokenHandlers) ListTokensHandler(w http.ResponseWriter, r *http.Request // Build service list var services []ServiceTokenData - for name, config := range h.mcpServers { + for name, serverConfig := range h.mcpServers { service := ServiceTokenData{ Name: name, DisplayName: name, } // Determine authentication type - if config.RequiresUserToken { + if serverConfig.RequiresUserToken { service.RequiresToken = true service.Instructions = fmt.Sprintf("Please create a %s API token", name) - if config.TokenSetup != nil { - if config.TokenSetup.DisplayName != "" { - service.DisplayName = config.TokenSetup.DisplayName + if serverConfig.UserAuthentication != nil { + if serverConfig.UserAuthentication.DisplayName != "" { + service.DisplayName = serverConfig.UserAuthentication.DisplayName } - if config.TokenSetup.Instructions != "" { - service.Instructions = config.TokenSetup.Instructions + + // Check if this service supports OAuth + if serverConfig.UserAuthentication.Type == config.UserAuthTypeOAuth { + service.SupportsOAuth = true + service.Instructions = fmt.Sprintf("Connect your %s account via OAuth", service.DisplayName) + + // Generate OAuth connect URL if OAuth client is available + if h.serviceOAuthClient != nil { + service.ConnectURL = h.serviceOAuthClient.GetConnectURL(name, "/my/tokens") + } + + // Check if OAuth is already connected + storedToken, err := h.tokenStore.GetUserToken(r.Context(), userEmail, name) + if err == nil && storedToken.Type == storage.TokenTypeOAuth { + service.IsOAuthConnected = true + service.HasToken = true + } + } else if serverConfig.UserAuthentication.Type == config.UserAuthTypeManual { + if serverConfig.UserAuthentication.Instructions != "" { + service.Instructions = serverConfig.UserAuthentication.Instructions + } + service.HelpURL = serverConfig.UserAuthentication.HelpURL + + // Check if manual token exists + _, err := h.tokenStore.GetUserToken(r.Context(), userEmail, name) + service.HasToken = err == nil } - service.HelpURL = config.TokenSetup.HelpURL - service.TokenFormat = config.TokenSetup.TokenFormat + service.TokenFormat = serverConfig.UserAuthentication.TokenFormat + } else { + // No UserAuthentication config means manual token + _, err := h.tokenStore.GetUserToken(r.Context(), userEmail, name) + service.HasToken = err == nil } - - _, err := h.tokenStore.GetUserToken(r.Context(), userEmail, name) - service.HasToken = err == nil } else { // Determine if it's OAuth authenticated or uses bearer tokens if h.oauthEnabled { service.AuthType = "oauth" - } else if config.Options != nil && len(config.Options.AuthTokens) > 0 { + } else if serverConfig.Options != nil && len(serverConfig.Options.AuthTokens) > 0 { service.AuthType = "bearer" } else { service.AuthType = "none" @@ -111,7 +139,7 @@ func (h *TokenHandlers) ListTokensHandler(w http.ResponseWriter, r *http.Request // Generate CSRF token csrfToken, err := h.generateCSRFToken() if err != nil { - log.LogErrorWithFields("token", "Failed to generate CSRF token", map[string]interface{}{ + log.LogErrorWithFields("token", "Failed to generate CSRF token", map[string]any{ "error": err.Error(), "user": userEmail, }) @@ -130,7 +158,7 @@ func (h *TokenHandlers) ListTokensHandler(w http.ResponseWriter, r *http.Request w.Header().Set("Content-Type", "text/html; charset=utf-8") if err := tokenPageTemplate.Execute(w, data); err != nil { - log.LogErrorWithFields("token", "Failed to render token page", map[string]interface{}{ + log.LogErrorWithFields("token", "Failed to render token page", map[string]any{ "error": err.Error(), "user": userEmail, }) @@ -191,29 +219,32 @@ func (h *TokenHandlers) SetTokenHandler(w http.ResponseWriter, r *http.Request) return } - if serviceConfig.TokenSetup != nil && serviceConfig.TokenSetup.CompiledRegex != nil { - if !serviceConfig.TokenSetup.CompiledRegex.MatchString(token) { + if serviceConfig.UserAuthentication != nil && + serviceConfig.UserAuthentication.Type == config.UserAuthTypeManual && + serviceConfig.UserAuthentication.ValidationRegex != nil { + if !serviceConfig.UserAuthentication.ValidationRegex.MatchString(token) { var helpMsg string displayName := serviceName - if serviceConfig.TokenSetup.DisplayName != "" { - displayName = serviceConfig.TokenSetup.DisplayName + if serviceConfig.UserAuthentication.DisplayName != "" { + displayName = serviceConfig.UserAuthentication.DisplayName } // Provide specific error messages based on common token patterns + validation := serviceConfig.UserAuthentication.Validation switch { - case serviceConfig.TokenSetup.TokenFormat == "^[A-Za-z0-9_-]+$": + case validation == "^[A-Za-z0-9_-]+$": helpMsg = fmt.Sprintf("%s token must contain only letters, numbers, underscores, and hyphens", displayName) - case strings.Contains(serviceConfig.TokenSetup.TokenFormat, "^[A-Fa-f0-9]{64}$"): + case strings.Contains(validation, "^[A-Fa-f0-9]{64}$"): helpMsg = fmt.Sprintf("%s token must be a 64-character hexadecimal string", displayName) - case strings.Contains(serviceConfig.TokenSetup.TokenFormat, "Bearer "): + case strings.Contains(serviceConfig.UserAuthentication.TokenFormat, "Bearer "): helpMsg = fmt.Sprintf("%s token should not include 'Bearer' prefix - just the token value", displayName) default: - if serviceConfig.TokenSetup.HelpURL != "" { + if serviceConfig.UserAuthentication.HelpURL != "" { helpMsg = fmt.Sprintf("Invalid %s token format. Please check the required format at %s", - displayName, serviceConfig.TokenSetup.HelpURL) + displayName, serviceConfig.UserAuthentication.HelpURL) } else { helpMsg = fmt.Sprintf("Invalid %s token format. Expected pattern: %s", - displayName, serviceConfig.TokenSetup.TokenFormat) + displayName, validation) } } redirectWithMessage(w, r, helpMsg, "error") @@ -221,8 +252,15 @@ func (h *TokenHandlers) SetTokenHandler(w http.ResponseWriter, r *http.Request) } } - if err := h.tokenStore.SetUserToken(r.Context(), userEmail, serviceName, token); err != nil { - log.LogErrorWithFields("token", "Failed to store token", map[string]interface{}{ + // Create StoredToken for manual entry + storedToken := &storage.StoredToken{ + Type: storage.TokenTypeManual, + Value: token, + UpdatedAt: time.Now(), + } + + if err := h.tokenStore.SetUserToken(r.Context(), userEmail, serviceName, storedToken); err != nil { + log.LogErrorWithFields("token", "Failed to store token", map[string]any{ "error": err.Error(), "user": userEmail, "service": serviceName, @@ -232,11 +270,11 @@ func (h *TokenHandlers) SetTokenHandler(w http.ResponseWriter, r *http.Request) } displayName := serviceName - if serviceConfig.TokenSetup != nil && serviceConfig.TokenSetup.DisplayName != "" { - displayName = serviceConfig.TokenSetup.DisplayName + if serviceConfig.UserAuthentication != nil && serviceConfig.UserAuthentication.DisplayName != "" { + displayName = serviceConfig.UserAuthentication.DisplayName } - log.LogInfoWithFields("token", "User configured token", map[string]interface{}{ + log.LogInfoWithFields("token", "User configured token", map[string]any{ "user": userEmail, "service": serviceName, "action": "set_token", @@ -284,7 +322,7 @@ func (h *TokenHandlers) DeleteTokenHandler(w http.ResponseWriter, r *http.Reques } if err := h.tokenStore.DeleteUserToken(r.Context(), userEmail, serviceName); err != nil { - log.LogErrorWithFields("token", "Failed to delete token", map[string]interface{}{ + log.LogErrorWithFields("token", "Failed to delete token", map[string]any{ "error": err.Error(), "user": userEmail, "service": serviceName, @@ -294,11 +332,11 @@ func (h *TokenHandlers) DeleteTokenHandler(w http.ResponseWriter, r *http.Reques } displayName := serviceName - if serviceConfig.TokenSetup != nil && serviceConfig.TokenSetup.DisplayName != "" { - displayName = serviceConfig.TokenSetup.DisplayName + if serviceConfig.UserAuthentication != nil && serviceConfig.UserAuthentication.DisplayName != "" { + displayName = serviceConfig.UserAuthentication.DisplayName } - log.LogInfoWithFields("token", "User deleted token", map[string]interface{}{ + log.LogInfoWithFields("token", "User deleted token", map[string]any{ "user": userEmail, "service": serviceName, "action": "delete_token", diff --git a/internal/auth/context.go b/internal/servicecontext/context.go similarity index 60% rename from internal/auth/context.go rename to internal/servicecontext/context.go index 38aba4b..38e6239 100644 --- a/internal/auth/context.go +++ b/internal/servicecontext/context.go @@ -1,4 +1,4 @@ -package auth +package servicecontext import ( "context" @@ -11,8 +11,8 @@ const ( serviceAuthKey contextKey = "auth.service" ) -// ServiceAuthInfo contains service authentication details -type ServiceAuthInfo struct { +// Info contains service authentication details +type Info struct { ServiceName string UserToken string } @@ -28,23 +28,23 @@ func GetUser(ctx context.Context) (string, bool) { return username, ok } -// WithServiceAuth adds service authentication info to the context -func WithServiceAuth(ctx context.Context, serviceName, userToken string) context.Context { - return context.WithValue(ctx, serviceAuthKey, ServiceAuthInfo{ +// WithAuthInfo adds service authentication info to the context +func WithAuthInfo(ctx context.Context, serviceName, userToken string) context.Context { + return context.WithValue(ctx, serviceAuthKey, Info{ ServiceName: serviceName, UserToken: userToken, }) } -// GetServiceAuth retrieves service auth info from context -func GetServiceAuth(ctx context.Context) (ServiceAuthInfo, bool) { - info, ok := ctx.Value(serviceAuthKey).(ServiceAuthInfo) +// GetAuthInfo retrieves service auth info from context +func GetAuthInfo(ctx context.Context) (Info, bool) { + info, ok := ctx.Value(serviceAuthKey).(Info) return info, ok } // GetServiceName retrieves the service name from context func GetServiceName(ctx context.Context) (string, bool) { - info, ok := GetServiceAuth(ctx) + info, ok := GetAuthInfo(ctx) if !ok { return "", false } diff --git a/internal/sse/writer.go b/internal/sse/writer.go index e863c98..0d2773e 100644 --- a/internal/sse/writer.go +++ b/internal/sse/writer.go @@ -7,7 +7,7 @@ import ( ) // WriteMessage writes a SSE message to the response writer -func WriteMessage(w http.ResponseWriter, flusher http.Flusher, data interface{}) error { +func WriteMessage(w http.ResponseWriter, flusher http.Flusher, data any) error { jsonData, err := json.Marshal(data) if err != nil { return fmt.Errorf("failed to marshal data: %w", err) diff --git a/internal/storage/firestore.go b/internal/storage/firestore.go index 269240c..fc54c31 100644 --- a/internal/storage/firestore.go +++ b/internal/storage/firestore.go @@ -3,6 +3,7 @@ package storage import ( "context" "fmt" + "maps" "sync" "time" @@ -34,10 +35,12 @@ var _ fosite.Storage = (*FirestoreStorage)(nil) // UserTokenDoc represents a user token document in Firestore type UserTokenDoc struct { - UserEmail string `firestore:"user_email"` - Service string `firestore:"service"` - Token string `firestore:"token"` // Encrypted - UpdatedAt time.Time `firestore:"updated_at"` + UserEmail string `firestore:"user_email"` + Service string `firestore:"service"` + Type TokenType `firestore:"type"` + Value string `firestore:"value,omitempty"` // Encrypted manual token + OAuthData *OAuthTokenData `firestore:"oauth_data,omitempty"` // OAuth metadata (tokens encrypted) + UpdatedAt time.Time `firestore:"updated_at"` } // OAuthClientEntity represents the structure stored in Firestore @@ -337,9 +340,7 @@ func (s *FirestoreStorage) GetAllClients() map[string]fosite.Client { // Create a copy to avoid race conditions clients := make(map[string]fosite.Client, len(s.MemoryStore.Clients)) - for id, client := range s.MemoryStore.Clients { - clients[id] = client - } + maps.Copy(clients, s.MemoryStore.Clients) return clients } @@ -361,47 +362,121 @@ func (s *FirestoreStorage) makeUserTokenDocID(userEmail, service string) string } // GetUserToken retrieves a user's token for a specific service -func (s *FirestoreStorage) GetUserToken(ctx context.Context, userEmail, service string) (string, error) { +func (s *FirestoreStorage) GetUserToken(ctx context.Context, userEmail, service string) (*StoredToken, error) { docID := s.makeUserTokenDocID(userEmail, service) doc, err := s.client.Collection(s.tokenCollection).Doc(docID).Get(ctx) if err != nil { if status.Code(err) == codes.NotFound { - return "", ErrUserTokenNotFound + return nil, ErrUserTokenNotFound } - return "", fmt.Errorf("failed to get token from Firestore: %w", err) + return nil, fmt.Errorf("failed to get token from Firestore: %w", err) } var tokenDoc UserTokenDoc if err := doc.DataTo(&tokenDoc); err != nil { - return "", fmt.Errorf("failed to unmarshal token: %w", err) + return nil, fmt.Errorf("failed to unmarshal token: %w", err) } - // Decrypt the token - decrypted, err := s.encryptor.Decrypt(tokenDoc.Token) - if err != nil { - return "", fmt.Errorf("failed to decrypt token: %w", err) + // Build StoredToken + storedToken := &StoredToken{ + Type: tokenDoc.Type, + UpdatedAt: tokenDoc.UpdatedAt, + } + + // Decrypt based on type + switch tokenDoc.Type { + case TokenTypeManual: + if tokenDoc.Value != "" { + decrypted, err := s.encryptor.Decrypt(tokenDoc.Value) + if err != nil { + return nil, fmt.Errorf("failed to decrypt manual token: %w", err) + } + storedToken.Value = decrypted + } + case TokenTypeOAuth: + if tokenDoc.OAuthData != nil { + // Decrypt OAuth tokens + decryptedAccess, err := s.encryptor.Decrypt(tokenDoc.OAuthData.AccessToken) + if err != nil { + return nil, fmt.Errorf("failed to decrypt access token: %w", err) + } + + oauthData := &OAuthTokenData{ + AccessToken: decryptedAccess, + TokenType: tokenDoc.OAuthData.TokenType, + ExpiresAt: tokenDoc.OAuthData.ExpiresAt, + Scopes: tokenDoc.OAuthData.Scopes, + } + + if tokenDoc.OAuthData.RefreshToken != "" { + decryptedRefresh, err := s.encryptor.Decrypt(tokenDoc.OAuthData.RefreshToken) + if err != nil { + return nil, fmt.Errorf("failed to decrypt refresh token: %w", err) + } + oauthData.RefreshToken = decryptedRefresh + } + + storedToken.OAuthData = oauthData + } } - return decrypted, nil + return storedToken, nil } // SetUserToken stores or updates a user's token for a specific service -func (s *FirestoreStorage) SetUserToken(ctx context.Context, userEmail, service, token string) error { - // Encrypt the token before storing - encrypted, err := s.encryptor.Encrypt(token) - if err != nil { - return fmt.Errorf("failed to encrypt token: %w", err) +func (s *FirestoreStorage) SetUserToken(ctx context.Context, userEmail, service string, token *StoredToken) error { + if token == nil { + return fmt.Errorf("token cannot be nil") } docID := s.makeUserTokenDocID(userEmail, service) tokenDoc := UserTokenDoc{ UserEmail: userEmail, Service: service, - Token: encrypted, + Type: token.Type, UpdatedAt: time.Now(), } - _, err = s.client.Collection(s.tokenCollection).Doc(docID).Set(ctx, tokenDoc) + // Encrypt based on type + switch token.Type { + case TokenTypeManual: + if token.Value != "" { + encrypted, err := s.encryptor.Encrypt(token.Value) + if err != nil { + return fmt.Errorf("failed to encrypt manual token: %w", err) + } + tokenDoc.Value = encrypted + } + case TokenTypeOAuth: + if token.OAuthData != nil { + // Encrypt OAuth tokens + encryptedAccess, err := s.encryptor.Encrypt(token.OAuthData.AccessToken) + if err != nil { + return fmt.Errorf("failed to encrypt access token: %w", err) + } + + oauthData := &OAuthTokenData{ + AccessToken: encryptedAccess, + TokenType: token.OAuthData.TokenType, + ExpiresAt: token.OAuthData.ExpiresAt, + Scopes: token.OAuthData.Scopes, + } + + if token.OAuthData.RefreshToken != "" { + encryptedRefresh, err := s.encryptor.Encrypt(token.OAuthData.RefreshToken) + if err != nil { + return fmt.Errorf("failed to encrypt refresh token: %w", err) + } + oauthData.RefreshToken = encryptedRefresh + } + + tokenDoc.OAuthData = oauthData + } + default: + return fmt.Errorf("unknown token type: %s", token.Type) + } + + _, err := s.client.Collection(s.tokenCollection).Doc(docID).Set(ctx, tokenDoc) if err != nil { return fmt.Errorf("failed to store token in Firestore: %w", err) } diff --git a/internal/storage/memory.go b/internal/storage/memory.go index ce274e7..d01d2e2 100644 --- a/internal/storage/memory.go +++ b/internal/storage/memory.go @@ -2,6 +2,8 @@ package storage import ( "context" + "fmt" + "maps" "strings" "sync" "time" @@ -19,9 +21,9 @@ var _ fosite.Storage = (*MemoryStorage)(nil) // It extends the MemoryStore with thread-safe client management type MemoryStorage struct { *storage.MemoryStore - stateCache sync.Map // map[string]fosite.AuthorizeRequester - clientsMutex sync.RWMutex // For thread-safe client access - userTokens map[string]string // map["email:service"] = token + stateCache sync.Map // map[string]fosite.AuthorizeRequester + clientsMutex sync.RWMutex // For thread-safe client access + userTokens map[string]*StoredToken // map["email:service"] = token userTokensMutex sync.RWMutex users map[string]*UserInfo // map[email] = UserInfo usersMutex sync.RWMutex @@ -33,7 +35,7 @@ type MemoryStorage struct { func NewMemoryStorage() *MemoryStorage { return &MemoryStorage{ MemoryStore: storage.NewMemoryStore(), - userTokens: make(map[string]string), + userTokens: make(map[string]*StoredToken), users: make(map[string]*UserInfo), sessions: make(map[string]*ActiveSession), } @@ -121,9 +123,7 @@ func (s *MemoryStorage) GetAllClients() map[string]fosite.Client { defer s.clientsMutex.RUnlock() clients := make(map[string]fosite.Client, len(s.MemoryStore.Clients)) // Copy to avoid races - for id, client := range s.MemoryStore.Clients { - clients[id] = client - } + maps.Copy(clients, s.MemoryStore.Clients) return clients } @@ -140,20 +140,24 @@ func (s *MemoryStorage) makeUserTokenKey(userEmail, service string) string { } // GetUserToken retrieves a user's token for a specific service -func (s *MemoryStorage) GetUserToken(ctx context.Context, userEmail, service string) (string, error) { +func (s *MemoryStorage) GetUserToken(ctx context.Context, userEmail, service string) (*StoredToken, error) { s.userTokensMutex.RLock() defer s.userTokensMutex.RUnlock() key := s.makeUserTokenKey(userEmail, service) token, exists := s.userTokens[key] if !exists { - return "", ErrUserTokenNotFound + return nil, ErrUserTokenNotFound } return token, nil } // SetUserToken stores or updates a user's token for a specific service -func (s *MemoryStorage) SetUserToken(ctx context.Context, userEmail, service, token string) error { +func (s *MemoryStorage) SetUserToken(ctx context.Context, userEmail, service string, token *StoredToken) error { + if token == nil { + return fmt.Errorf("token cannot be nil") + } + s.userTokensMutex.Lock() defer s.userTokensMutex.Unlock() @@ -180,8 +184,8 @@ func (s *MemoryStorage) ListUserServices(ctx context.Context, userEmail string) var services []string prefix := userEmail + ":" for key := range s.userTokens { - if strings.HasPrefix(key, prefix) { - service := strings.TrimPrefix(key, prefix) + if after, ok := strings.CutPrefix(key, prefix); ok { + service := after services = append(services, service) } } diff --git a/internal/storage/storage.go b/internal/storage/storage.go index a03f739..2f241e4 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -27,6 +27,31 @@ type UserInfo struct { IsAdmin bool `json:"is_admin"` } +// TokenType represents the type of stored token +type TokenType string + +const ( + TokenTypeManual TokenType = "manual" + TokenTypeOAuth TokenType = "oauth" +) + +// OAuthTokenData represents OAuth token metadata +type OAuthTokenData struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + Scopes []string `json:"scopes,omitempty"` +} + +// StoredToken represents a token with its metadata +type StoredToken struct { + Type TokenType `json:"type"` + Value string `json:"value,omitempty"` // For manual tokens + OAuthData *OAuthTokenData `json:"oauth,omitempty"` // For OAuth tokens + UpdatedAt time.Time `json:"updated_at"` +} + // ActiveSession represents an active MCP session type ActiveSession struct { SessionID string `json:"session_id"` @@ -40,8 +65,8 @@ type ActiveSession struct { // This interface is used by handlers that need to access user-specific tokens // for external services (e.g., Notion, GitHub). type UserTokenStore interface { - GetUserToken(ctx context.Context, userEmail, service string) (string, error) - SetUserToken(ctx context.Context, userEmail, service, token string) error + GetUserToken(ctx context.Context, userEmail, service string) (*StoredToken, error) + SetUserToken(ctx context.Context, userEmail, service string, token *StoredToken) error DeleteUserToken(ctx context.Context, userEmail, service string) error ListUserServices(ctx context.Context, userEmail string) ([]string, error) } diff --git a/internal/testutil/mocks.go b/internal/testutil/mocks.go index 1c3f3f6..903582f 100644 --- a/internal/testutil/mocks.go +++ b/internal/testutil/mocks.go @@ -3,6 +3,7 @@ package testutil import ( "context" + "github.com/dgellow/mcp-front/internal/storage" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/mock" @@ -157,12 +158,15 @@ type MockUserTokenStore struct { mock.Mock } -func (m *MockUserTokenStore) GetUserToken(ctx context.Context, userEmail, serverName string) (string, error) { +func (m *MockUserTokenStore) GetUserToken(ctx context.Context, userEmail, serverName string) (*storage.StoredToken, error) { args := m.Called(ctx, userEmail, serverName) - return args.String(0), args.Error(1) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*storage.StoredToken), args.Error(1) } -func (m *MockUserTokenStore) SetUserToken(ctx context.Context, userEmail, serverName, token string) error { +func (m *MockUserTokenStore) SetUserToken(ctx context.Context, userEmail, serverName string, token *storage.StoredToken) error { args := m.Called(ctx, userEmail, serverName, token) return args.Error(0) } diff --git a/internal/utils/email.go b/internal/utils/email.go deleted file mode 100644 index 36b3ab5..0000000 --- a/internal/utils/email.go +++ /dev/null @@ -1,9 +0,0 @@ -package utils - -import "strings" - -// NormalizeEmail normalizes an email address for consistent comparison -// by converting to lowercase and trimming whitespace -func NormalizeEmail(email string) string { - return strings.ToLower(strings.TrimSpace(email)) -} diff --git a/oauth-user-auth-plan.md b/oauth-user-auth-plan.md new file mode 100644 index 0000000..5c52487 --- /dev/null +++ b/oauth-user-auth-plan.md @@ -0,0 +1,378 @@ +# OAuth User Authentication for MCP Servers + +## Overview + +Add OAuth authentication support for MCP servers with an interstitial service selection page, allowing users to connect multiple OAuth-based services after Google authentication before returning to Claude.ai. + +## User Flow + +``` +1. Claude.ai → User clicks "Connect" on MCP integration +2. → Redirected to mcp-front Google OAuth +3. → User completes Google OAuth +4. → mcp-front shows interstitial page listing OAuth-requiring services +5. → User optionally connects to services (e.g., Stainless, Linear) +6. → Each service OAuth completes and returns to interstitial +7. → User clicks "Continue to Claude" +8. → mcp-front redirects back to Claude.ai with original auth code +9. → Complete! User connected to mcp-front and optionally to services +``` + +## Config Structure + +### OAuth Authentication + +```json +{ + "mcpServers": { + "stainless": { + "transportType": "stdio", + "command": "stainless", + "args": ["mcp"], + "env": { + "STAINLESS_API_TOKEN": {"$userToken": "{{token}}"} + }, + "requiresUserToken": true, + "userAuthentication": { + "type": "oauth", + "displayName": "Stainless", + "clientId": {"$env": "STAINLESS_OAUTH_CLIENT_ID"}, + "clientSecret": {"$env": "STAINLESS_OAUTH_CLIENT_SECRET"}, + "authorizationUrl": "https://api.stainless.com/oauth/authorize", + "tokenUrl": "https://api.stainless.com/oauth/token", + "scopes": ["mcp:read", "mcp:write"], + "tokenFormat": "Bearer {{token}}" + } + } + } +} +``` + +### Manual Token Authentication + +```json +{ + "mcpServers": { + "notion": { + "transportType": "stdio", + "command": "notion-mcp", + "args": [], + "env": { + "NOTION_API_KEY": {"$userToken": "{{token}}"} + }, + "requiresUserToken": true, + "userAuthentication": { + "type": "manual", + "displayName": "Notion API Token", + "instructions": "Get your token from https://notion.so/my-integrations", + "helpUrl": "https://developers.notion.com/docs/authorization", + "tokenFormat": "{{token}}", + "validation": "^secret_[a-zA-Z0-9]{43}$" + } + } + } +} +``` + +## Implementation Details + +### 1. Interstitial Service Selection Page + +After Google OAuth completes, show a page allowing users to connect OAuth-requiring services: + +```go +// In GoogleCallbackHandler: +func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Request) { + // ... complete Google OAuth ... + + // For OAuth client flow, check if any services need OAuth + if needsServiceAuth := h.checkServicesNeedOAuth(userEmail); needsServiceAuth { + // Encrypt Claude's OAuth state for later + encryptedState := h.encryptOAuthState(OAuthState{ + AuthCode: authCode, + State: originalState, + RedirectURI: originalRedirectURI, + Timestamp: time.Now(), + }) + + // Redirect to service selection page + http.Redirect(w, r, fmt.Sprintf("/oauth/services?state=%s", encryptedState), http.StatusFound) + return + } + + // No services need OAuth, complete Claude flow + h.completeClaudeOAuth(w, r, authCode, originalState, originalRedirectURI) +} + +// Service selection page shows: +// - List of OAuth-requiring services from config +// - Current connection status for each +// - Connect/Reconnect buttons +// - "Continue to Claude" button +``` + +### 2. Service OAuth Endpoints + +```go +// GET /oauth/connect?service={serviceName}&return={encodedReturnURL} +func (h *ServiceAuthHandlers) ConnectHandler(w http.ResponseWriter, r *http.Request) { + serviceName := r.URL.Query().Get("service") + returnURL := r.URL.Query().Get("return") + + // Validate service supports OAuth + serviceConfig := h.mcpServers[serviceName] + if !serviceConfig.RequiresUserToken || + serviceConfig.UserAuthentication.Type != UserAuthTypeOAuth { + // Redirect back with error + return + } + + // Start OAuth flow with service + authURL, err := h.oauthClient.StartOAuthFlow( + ctx, userEmail, serviceName, serviceConfig, returnURL) + + http.Redirect(w, r, authURL, http.StatusFound) +} + +// GET /oauth/callback/{serviceName} +func (h *ServiceAuthHandlers) CallbackHandler(w http.ResponseWriter, r *http.Request) { + serviceName := getServiceFromPath(r.URL.Path) + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + + if errorParam := r.URL.Query().Get("error"); errorParam != "" { + // Service OAuth failed - redirect to interstitial with error + returnURL := h.getReturnURLFromState(state) + errorURL := fmt.Sprintf("%s&error=%s&service=%s", + returnURL, errorParam, serviceName) + http.Redirect(w, r, errorURL, http.StatusFound) + return + } + + // Exchange code for tokens + tokens, returnURL, err := h.oauthClient.HandleCallback( + ctx, serviceName, code, state, serviceConfig) + + // Store encrypted tokens + h.storage.SetUserToken(ctx, userEmail, serviceName, &StoredToken{ + Type: TokenTypeOAuth, + OAuthData: tokens, + UpdatedAt: time.Now(), + }) + + // Redirect back to interstitial page + http.Redirect(w, r, returnURL, http.StatusFound) +} +``` + +### 3. Enhanced Config Types + +```go +type UserAuthType string + +const ( + UserAuthTypeManual UserAuthType = "manual" + UserAuthTypeOAuth UserAuthType = "oauth" +) + +type UserAuthentication struct { + Type UserAuthType `json:"type"` + DisplayName string `json:"displayName"` + + // For OAuth + ClientID json.RawMessage `json:"clientId,omitempty"` + ClientSecret json.RawMessage `json:"clientSecret,omitempty"` + AuthorizationURL string `json:"authorizationUrl,omitempty"` + TokenURL string `json:"tokenUrl,omitempty"` + Scopes []string `json:"scopes,omitempty"` + + // For Manual + Instructions string `json:"instructions,omitempty"` + HelpURL string `json:"helpUrl,omitempty"` + Validation string `json:"validation,omitempty"` + + // Common + TokenFormat string `json:"tokenFormat,omitempty"` + + // Resolved values (not in JSON) + ResolvedClientID string `json:"-"` + ResolvedClientSecret string `json:"-"` + CompiledValidation *regexp.Regexp `json:"-"` +} +``` + +### 4. Token Storage with OAuth Metadata + +```go +type OAuthTokenData struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresAt time.Time `json:"expires_at"` + Scopes []string `json:"scopes"` +} + +type StoredUserToken struct { + Type UserAuthType `json:"type"` + Token string `json:"token,omitempty"` // For manual + OAuthData *OAuthTokenData `json:"oauth,omitempty"` // For OAuth + UpdatedAt time.Time `json:"updated_at"` +} +``` + +### 5. Automatic Token Refresh + +```go +func (h *MCPHandler) getUserTokenIfAvailable(ctx context.Context, userEmail string) (string, error) { + stored, err := h.storage.GetUserToken(ctx, userEmail, h.serverName) + if err != nil { + return "", err + } + + if stored.Type == UserAuthTypeOAuth && stored.OAuthData != nil { + // Check if token needs refresh + if time.Now().After(stored.OAuthData.ExpiresAt.Add(-5 * time.Minute)) { + // Refresh token + client := NewOAuthClient(h.serverConfig.UserAuthentication) + newTokens, err := client.RefreshToken(ctx, stored.OAuthData.RefreshToken) + if err != nil { + return "", fmt.Errorf("token refresh failed: %w", err) + } + + // Update storage + stored.OAuthData = newTokens + stored.UpdatedAt = time.Now() + h.storage.SetUserToken(ctx, userEmail, h.serverName, stored) + } + + // Format token + return formatToken(h.serverConfig.UserAuthentication.TokenFormat, + stored.OAuthData.AccessToken), nil + } + + // Manual token + return formatToken(h.serverConfig.UserAuthentication.TokenFormat, stored.Token), nil +} +``` + +### 6. State Management + +```go +// Encrypted state for preserving Claude OAuth while doing service auth +type OAuthState struct { + AuthCode string `json:"code"` + State string `json:"state"` + RedirectURI string `json:"redirect_uri"` + Timestamp time.Time `json:"timestamp"` +} + +func (h *AuthHandlers) encryptOAuthState(state OAuthState) string { + data, _ := json.Marshal(state) + encrypted, _ := h.encryptor.Encrypt(string(data)) + // Add HMAC for tamper protection + signature := crypto.SignData(encrypted, h.encryptionKey) + return base64.URLEncoding.EncodeToString( + []byte(fmt.Sprintf("%s.%s", encrypted, signature))) +} + +func (h *AuthHandlers) decryptOAuthState(encrypted string) (*OAuthState, error) { + // Verify HMAC signature + // Check timestamp (10 minute expiry) + // Decrypt and unmarshal +} +``` + +### 7. Interstitial Page UI + +```go +// GET /oauth/services?state={encryptedState} +func (h *AuthHandlers) ServiceSelectionHandler(w http.ResponseWriter, r *http.Request) { + encryptedState := r.URL.Query().Get("state") + + // Get OAuth-requiring services and their status + services := []ServiceStatus{} + for name, config := range h.mcpServers { + if config.RequiresUserToken && + config.UserAuthentication.Type == UserAuthTypeOAuth { + + // Check if user already has valid token + token, _ := h.storage.GetUserToken(ctx, userEmail, name) + status := "not_connected" + if token != nil && !token.IsExpired() { + status = "connected" + } + + services = append(services, ServiceStatus{ + Name: name, + DisplayName: config.UserAuthentication.DisplayName, + Status: status, + Error: r.URL.Query().Get("error") == name, + }) + } + } + + // Render template showing: + // - Service list with Connect/Connected/Failed states + // - "Skip for now" and "Continue to Claude" buttons + // - Clear messaging that this is optional +} +``` + +### 8. Routes Configuration + +```go +// OAuth endpoints +mux.Handle("/authorize", authHandlers.AuthorizeHandler) +mux.Handle("/oauth/callback", authHandlers.GoogleCallbackHandler) +mux.Handle("/token", authHandlers.TokenHandler) + +// Service OAuth endpoints +mux.Handle("/oauth/services", authHandlers.ServiceSelectionHandler) +mux.Handle("/oauth/connect", serviceAuthHandlers.ConnectHandler) +mux.Handle("/oauth/callback/", serviceAuthHandlers.CallbackHandler) +mux.Handle("/oauth/complete", authHandlers.CompleteClaudeOAuthHandler) + +// Token management UI +mux.Handle("/my/tokens", tokenHandlers.ListTokensHandler) +``` + +## Benefits + +1. **Clear User Intent**: Users explicitly choose which services to connect +2. **No Protocol Violations**: Standard OAuth flow without custom parameters +3. **Progressive Disclosure**: Only OAuth-requiring services shown +4. **Flexible**: Connect some services now, others later +5. **Error Recovery**: Service OAuth failures don't block Claude connection +6. **Automatic Refresh**: OAuth tokens refreshed transparently + +## Security Considerations + +1. **State parameter**: HMAC-signed to prevent tampering +2. **OAuth secrets**: Encrypted at rest in storage +3. **Token refresh**: Automatic refresh 5 minutes before expiry +4. **Isolation**: Each service has separate OAuth configuration +5. **Time limits**: Encrypted state expires after 10 minutes +6. **Error handling**: Service OAuth failures don't compromise main flow + +## Example: Complete Flow + +1. User in Claude.ai clicks "Connect MCP" +2. Claude.ai redirects to: `https://mcp-front.com/authorize?client_id=claude&redirect_uri=https://claude.ai/callback&state=abc123` +3. User completes Google OAuth +4. mcp-front shows interstitial page: + ``` + Some MCP servers require additional authentication: + + Stainless [Connect] + Linear [Connected ✓] + + [Skip for now] [Continue to Claude] + ``` +5. User clicks "Connect" for Stainless +6. Redirected to: `https://api.stainless.com/oauth/authorize?...` +7. User approves Stainless access +8. Returns to interstitial showing: Stainless [Connected ✓] +9. User clicks "Continue to Claude" +10. mcp-front redirects to: `https://claude.ai/callback?code=...&state=abc123` +11. Complete! User connected to both mcp-front and Stainless +