From 1b6ec5539c646e4732023db106f529fe6d8511c5 Mon Sep 17 00:00:00 2001 From: Samuel El-Borai Date: Sun, 22 Jun 2025 16:52:24 +0200 Subject: [PATCH] feat: implement MCP Standard 2025-06-18 simple compliance changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add WWW-Authenticate headers to 401 responses with configurable realm - Add OAuth Protected Resource Metadata endpoint at /.well-known/oauth-protected-resource - Use proxy name from config as realm instead of hardcoded value - Add urlutil package for proper URL path joining - Include comprehensive tests for all new functionality No breaking changes - fully backward compatible with existing clients. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- internal/json/writer.go | 12 +++++ internal/json/writer_test.go | 81 ++++++++++++++++++++++++++++++ internal/oauth/metadata.go | 33 +++++++++++++ internal/oauth/metadata_test.go | 68 ++++++++++++++++++++++++++ internal/oauth/oauth.go | 25 +++++++--- internal/server/handler.go | 21 +++++++- internal/server/middleware.go | 14 ++++-- internal/urlutil/join.go | 35 +++++++++++++ internal/urlutil/join_test.go | 87 +++++++++++++++++++++++++++++++++ 9 files changed, 364 insertions(+), 12 deletions(-) create mode 100644 internal/json/writer_test.go create mode 100644 internal/oauth/metadata.go create mode 100644 internal/oauth/metadata_test.go create mode 100644 internal/urlutil/join.go create mode 100644 internal/urlutil/join_test.go diff --git a/internal/json/writer.go b/internal/json/writer.go index de1f20d..6fa3a18 100644 --- a/internal/json/writer.go +++ b/internal/json/writer.go @@ -2,6 +2,7 @@ package json import ( "encoding/json" + "fmt" "net/http" "github.com/dgellow/mcp-front/internal" @@ -48,6 +49,17 @@ func WriteUnauthorized(w http.ResponseWriter, message string) { WriteError(w, http.StatusUnauthorized, "unauthorized", message) } +// WriteUnauthorizedWithChallenge writes a 401 Unauthorized response with WWW-Authenticate header +// This is used for MCP Standard 2025-06-18 compliance +func WriteUnauthorizedWithChallenge(w http.ResponseWriter, message string, realm string, resourceMetadataURI string) { + if resourceMetadataURI != "" && realm != "" { + // Format: Bearer realm="example", as_uri="https://example.com/.well-known/oauth-protected-resource" + challenge := fmt.Sprintf(`Bearer realm="%s", as_uri="%s"`, realm, resourceMetadataURI) + w.Header().Set("WWW-Authenticate", challenge) + } + WriteError(w, http.StatusUnauthorized, "unauthorized", message) +} + func WriteInternalServerError(w http.ResponseWriter, message string) { WriteError(w, http.StatusInternalServerError, "internal_server_error", message) } diff --git a/internal/json/writer_test.go b/internal/json/writer_test.go new file mode 100644 index 0000000..d98146f --- /dev/null +++ b/internal/json/writer_test.go @@ -0,0 +1,81 @@ +package json + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestWriteUnauthorizedWithChallenge(t *testing.T) { + tests := []struct { + name string + message string + realm string + resourceMetadataURI string + wantHeader string + wantStatus int + }{ + { + name: "with resource metadata URI and realm", + message: "Invalid token", + realm: "TestProxy", + resourceMetadataURI: "https://example.com/.well-known/oauth-protected-resource", + wantHeader: `Bearer realm="TestProxy", as_uri="https://example.com/.well-known/oauth-protected-resource"`, + wantStatus: http.StatusUnauthorized, + }, + { + name: "without resource metadata URI", + message: "Invalid token", + realm: "TestProxy", + resourceMetadataURI: "", + wantHeader: "", + wantStatus: http.StatusUnauthorized, + }, + { + name: "without realm", + message: "Invalid token", + realm: "", + resourceMetadataURI: "https://example.com/.well-known/oauth-protected-resource", + wantHeader: "", + wantStatus: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + + WriteUnauthorizedWithChallenge(w, tt.message, tt.realm, tt.resourceMetadataURI) + + if w.Code != tt.wantStatus { + t.Errorf("status = %v, want %v", w.Code, tt.wantStatus) + } + + gotHeader := w.Header().Get("WWW-Authenticate") + if gotHeader != tt.wantHeader { + t.Errorf("WWW-Authenticate header = %q, want %q", gotHeader, tt.wantHeader) + } + + // Check that response contains error message + body := w.Body.String() + if body == "" { + t.Error("expected non-empty response body") + } + }) + } +} + +func TestWriteUnauthorized(t *testing.T) { + w := httptest.NewRecorder() + + WriteUnauthorized(w, "Test error") + + if w.Code != http.StatusUnauthorized { + t.Errorf("status = %v, want %v", w.Code, http.StatusUnauthorized) + } + + // Should not have WWW-Authenticate header + if header := w.Header().Get("WWW-Authenticate"); header != "" { + t.Errorf("unexpected WWW-Authenticate header: %q", header) + } +} \ No newline at end of file diff --git a/internal/oauth/metadata.go b/internal/oauth/metadata.go new file mode 100644 index 0000000..2090de4 --- /dev/null +++ b/internal/oauth/metadata.go @@ -0,0 +1,33 @@ +package oauth + +import ( + "net/http" + + "github.com/dgellow/mcp-front/internal" + jsonwriter "github.com/dgellow/mcp-front/internal/json" + "github.com/dgellow/mcp-front/internal/urlutil" +) + +// ProtectedResourceMetadataHandler serves OAuth 2.0 Protected Resource Metadata (RFC 9728) +// This endpoint helps clients discover which authorization servers this resource server trusts +func (s *Server) ProtectedResourceMetadataHandler(w http.ResponseWriter, r *http.Request) { + // Build the metadata response + metadata := map[string]interface{}{ + "resource": s.config.Issuer, // The canonical URI of this resource server + "authorization_servers": []string{ + s.config.Issuer, // We are our own authorization server + }, + "_links": map[string]interface{}{ + "oauth-authorization-server": map[string]string{ + "href": urlutil.MustJoinPath(s.config.Issuer, ".well-known", "oauth-authorization-server"), + }, + }, + } + + // Write the response + if err := jsonwriter.Write(w, metadata); err != nil { + internal.LogErrorWithFields("oauth", "Failed to encode protected resource metadata", map[string]interface{}{ + "error": err.Error(), + }) + } +} \ No newline at end of file diff --git a/internal/oauth/metadata_test.go b/internal/oauth/metadata_test.go new file mode 100644 index 0000000..1af7b7a --- /dev/null +++ b/internal/oauth/metadata_test.go @@ -0,0 +1,68 @@ +package oauth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/dgellow/mcp-front/internal/storage" +) + +func TestProtectedResourceMetadataHandler(t *testing.T) { + // Create a test server + config := Config{ + Issuer: "https://example.com", + ProxyName: "test-proxy", + EncryptionKey: "test-encryption-key-32-bytes----", // Exactly 32 bytes for AES-256 + JWTSecret: "test-jwt-secret-at-least-32-bytes-long", + } + + store := storage.NewMemoryStorage() + server, err := NewServer(config, store) + if err != nil { + t.Fatalf("Failed to create OAuth server: %v", err) + } + + // Create test request + req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil) + w := httptest.NewRecorder() + + // Call the handler + server.ProtectedResourceMetadataHandler(w, req) + + // Check response + if w.Code != http.StatusOK { + t.Errorf("status = %v, want %v", w.Code, http.StatusOK) + } + + // Parse response + var metadata map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &metadata); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + // Check required fields + if resource, ok := metadata["resource"].(string); !ok || resource != "https://example.com" { + t.Errorf("resource = %v, want %v", metadata["resource"], "https://example.com") + } + + if authServers, ok := metadata["authorization_servers"].([]interface{}); !ok || len(authServers) != 1 { + t.Errorf("authorization_servers = %v, want array with 1 element", metadata["authorization_servers"]) + } else if authServers[0] != "https://example.com" { + t.Errorf("authorization_servers[0] = %v, want %v", authServers[0], "https://example.com") + } + + // Check _links + if links, ok := metadata["_links"].(map[string]interface{}); !ok { + t.Error("_links field missing or not an object") + } else { + if authServerLink, ok := links["oauth-authorization-server"].(map[string]interface{}); !ok { + t.Error("oauth-authorization-server link missing") + } else { + if href, ok := authServerLink["href"].(string); !ok || href != "https://example.com/.well-known/oauth-authorization-server" { + t.Errorf("oauth-authorization-server href = %v, want %v", href, "https://example.com/.well-known/oauth-authorization-server") + } + } + } +} \ No newline at end of file diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 8e9c44b..9acfa99 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -14,6 +14,7 @@ import ( "github.com/dgellow/mcp-front/internal/crypto" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/storage" + "github.com/dgellow/mcp-front/internal/urlutil" "github.com/ory/fosite" "github.com/ory/fosite/compose" ) @@ -47,6 +48,7 @@ type Server struct { // Config holds OAuth server configuration type Config struct { Issuer string + ProxyName string // Name of the proxy for WWW-Authenticate realm TokenTTL time.Duration SessionDuration time.Duration // Duration for browser session cookies (default: 24h) AllowedDomains []string @@ -149,9 +151,9 @@ func NewServer(config Config, store storage.Storage) (*Server, error) { func (s *Server) WellKnownHandler(w http.ResponseWriter, r *http.Request) { metadata := map[string]interface{}{ "issuer": s.config.Issuer, - "authorization_endpoint": s.config.Issuer + "/authorize", - "token_endpoint": s.config.Issuer + "/token", - "registration_endpoint": s.config.Issuer + "/register", + "authorization_endpoint": urlutil.MustJoinPath(s.config.Issuer, "authorize"), + "token_endpoint": urlutil.MustJoinPath(s.config.Issuer, "token"), + "registration_endpoint": urlutil.MustJoinPath(s.config.Issuer, "register"), "scopes_supported": []string{ "read", "write", @@ -170,8 +172,8 @@ func (s *Server) WellKnownHandler(w http.ResponseWriter, r *http.Request) { "none", "client_secret_post", }, - "revocation_endpoint": s.config.Issuer + "/revoke", - "introspection_endpoint": s.config.Issuer + "/introspect", + "revocation_endpoint": urlutil.MustJoinPath(s.config.Issuer, "revoke"), + "introspection_endpoint": urlutil.MustJoinPath(s.config.Issuer, "introspect"), } if err := jsonwriter.Write(w, metadata); err != nil { @@ -520,6 +522,13 @@ func (s *Server) DebugClientsHandler(w http.ResponseWriter, r *http.Request) { // ValidateTokenMiddleware creates middleware that validates OAuth tokens func (s *Server) ValidateTokenMiddleware() func(http.Handler) http.Handler { + // Build the resource metadata URI once + resourceMetadataURI := urlutil.MustJoinPath(s.config.Issuer, ".well-known", "oauth-protected-resource") + realm := s.config.ProxyName + if realm == "" { + realm = "OAuth" // Fallback if proxy name not configured + } + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -527,13 +536,13 @@ func (s *Server) ValidateTokenMiddleware() func(http.Handler) http.Handler { // Extract token from Authorization header auth := r.Header.Get("Authorization") if auth == "" { - jsonwriter.WriteUnauthorized(w, "Missing authorization header") + jsonwriter.WriteUnauthorizedWithChallenge(w, "Missing authorization header", realm, resourceMetadataURI) return } parts := strings.Split(auth, " ") if len(parts) != 2 || parts[0] != "Bearer" { - jsonwriter.WriteUnauthorized(w, "Invalid authorization header format") + jsonwriter.WriteUnauthorizedWithChallenge(w, "Invalid authorization header format", realm, resourceMetadataURI) return } @@ -548,7 +557,7 @@ func (s *Server) ValidateTokenMiddleware() func(http.Handler) http.Handler { session := &Session{DefaultSession: &fosite.DefaultSession{}} _, accessRequest, err := s.provider.IntrospectToken(ctx, token, fosite.AccessToken, session) if err != nil { - jsonwriter.WriteUnauthorized(w, "Invalid or expired token") + jsonwriter.WriteUnauthorizedWithChallenge(w, "Invalid or expired token", realm, resourceMetadataURI) return } diff --git a/internal/server/handler.go b/internal/server/handler.go index 17fd99a..174a884 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -2,6 +2,7 @@ package server import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -134,6 +135,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) { oauthConfig := oauth.Config{ Issuer: oauthAuth.Issuer, + ProxyName: cfg.Proxy.Name, TokenTTL: ttl, AllowedDomains: oauthAuth.AllowedDomains, AllowedOrigins: oauthAuth.AllowedOrigins, @@ -184,6 +186,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) { } mux.Handle("/.well-known/oauth-authorization-server", chainMiddleware(http.HandlerFunc(s.oauthServer.WellKnownHandler), oauthMiddlewares...)) + mux.Handle("/.well-known/oauth-protected-resource", chainMiddleware(http.HandlerFunc(s.oauthServer.ProtectedResourceMetadataHandler), oauthMiddlewares...)) mux.Handle("/authorize", chainMiddleware(http.HandlerFunc(s.oauthServer.AuthorizeHandler), oauthMiddlewares...)) mux.Handle("/oauth/callback", chainMiddleware(http.HandlerFunc(s.oauthServer.GoogleCallbackHandler), oauthMiddlewares...)) mux.Handle("/token", chainMiddleware(http.HandlerFunc(s.oauthServer.TokenHandler), oauthMiddlewares...)) @@ -354,7 +357,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) { middlewares = append(middlewares, s.oauthServer.ValidateTokenMiddleware()) } else if serverConfig.Options != nil && len(serverConfig.Options.AuthTokens) > 0 { // Bearer token authentication - request must include valid bearer token - middlewares = append(middlewares, newAuthMiddleware(serverConfig.Options.AuthTokens)) + middlewares = append(middlewares, newAuthMiddleware(serverConfig.Options.AuthTokens, baseURL.String(), cfg.Proxy.Name)) } // else: no auth required for this endpoint @@ -402,6 +405,22 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"status":"ok"}`)) }) + + // If OAuth is not configured, still provide the protected resource metadata endpoint + // This ensures compliance with MCP Standard 2025-06-18 even for bearer token auth + if s.oauthServer == nil { + mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) { + metadata := map[string]interface{}{ + "resource": baseURL.String(), + "authorization_servers": []string{}, // No authorization servers for bearer token auth + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(metadata); err != nil { + internal.LogError("Failed to encode protected resource metadata: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + }) + } internal.LogInfoWithFields("server", "MCP proxy server initialized", nil) return s, nil diff --git a/internal/server/middleware.go b/internal/server/middleware.go index b34e26d..96cfa19 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -11,6 +11,7 @@ import ( jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/oauth" "github.com/dgellow/mcp-front/internal/storage" + "github.com/dgellow/mcp-front/internal/urlutil" ) // MiddlewareFunc is a function that wraps an http.Handler @@ -166,24 +167,31 @@ func recoverMiddleware(prefix string) MiddlewareFunc { } // newAuthMiddleware creates middleware for bearer token authentication -func newAuthMiddleware(tokens []string) MiddlewareFunc { +func newAuthMiddleware(tokens []string, baseURL string, realm string) MiddlewareFunc { tokenSet := make(map[string]struct{}, len(tokens)) for _, token := range tokens { tokenSet[token] = struct{}{} } + + // Build the resource metadata URI + resourceMetadataURI := urlutil.MustJoinPath(baseURL, ".well-known", "oauth-protected-resource") + if realm == "" { + realm = "Bearer" // Fallback if realm not provided + } + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if len(tokens) != 0 { authHeader := r.Header.Get("Authorization") if !strings.HasPrefix(authHeader, "Bearer ") { - jsonwriter.WriteUnauthorized(w, "Unauthorized") + jsonwriter.WriteUnauthorizedWithChallenge(w, "Unauthorized", realm, resourceMetadataURI) return } token := authHeader[7:] // Extract the actual token if _, ok := tokenSet[token]; !ok { - jsonwriter.WriteUnauthorized(w, "Unauthorized") + jsonwriter.WriteUnauthorizedWithChallenge(w, "Unauthorized", realm, resourceMetadataURI) return } } diff --git a/internal/urlutil/join.go b/internal/urlutil/join.go new file mode 100644 index 0000000..10d03d7 --- /dev/null +++ b/internal/urlutil/join.go @@ -0,0 +1,35 @@ +package urlutil + +import ( + "net/url" + "path" + "strings" +) + +// JoinPath safely joins URL paths, handling trailing and leading slashes correctly +func JoinPath(base string, paths ...string) (string, error) { + u, err := url.Parse(base) + if err != nil { + return "", err + } + + // Join paths, ensuring proper slash handling + allPaths := append([]string{u.Path}, paths...) + u.Path = path.Join(allPaths...) + + // Preserve trailing slash if the last path component had one + if len(paths) > 0 && strings.HasSuffix(paths[len(paths)-1], "/") { + u.Path += "/" + } + + return u.String(), nil +} + +// MustJoinPath is like JoinPath but panics on error (for use with known-good URLs) +func MustJoinPath(base string, paths ...string) string { + result, err := JoinPath(base, paths...) + if err != nil { + panic(err) + } + return result +} \ No newline at end of file diff --git a/internal/urlutil/join_test.go b/internal/urlutil/join_test.go new file mode 100644 index 0000000..41b49ff --- /dev/null +++ b/internal/urlutil/join_test.go @@ -0,0 +1,87 @@ +package urlutil + +import ( + "testing" +) + +func TestJoinPath(t *testing.T) { + tests := []struct { + name string + base string + paths []string + want string + wantErr bool + }{ + { + name: "simple join", + base: "https://example.com", + paths: []string{"api", "v1"}, + want: "https://example.com/api/v1", + }, + { + name: "base with path", + base: "https://example.com/base", + paths: []string{"api", "v1"}, + want: "https://example.com/base/api/v1", + }, + { + name: "trailing slash preserved", + base: "https://example.com", + paths: []string{"api", "v1/"}, + want: "https://example.com/api/v1/", + }, + { + name: "well-known path", + base: "https://example.com", + paths: []string{".well-known", "oauth-protected-resource"}, + want: "https://example.com/.well-known/oauth-protected-resource", + }, + { + name: "empty paths", + base: "https://example.com", + paths: []string{}, + want: "https://example.com", + }, + { + name: "base with trailing slash", + base: "https://example.com/", + paths: []string{"api"}, + want: "https://example.com/api", + }, + { + name: "invalid base URL", + base: "://invalid", + paths: []string{"api"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := JoinPath(tt.base, tt.paths...) + if (err != nil) != tt.wantErr { + t.Errorf("JoinPath() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.want { + t.Errorf("JoinPath() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMustJoinPath(t *testing.T) { + // Test normal operation + result := MustJoinPath("https://example.com", "api", "v1") + if result != "https://example.com/api/v1" { + t.Errorf("MustJoinPath() = %v, want %v", result, "https://example.com/api/v1") + } + + // Test panic on invalid URL + defer func() { + if r := recover(); r == nil { + t.Errorf("MustJoinPath() should have panicked") + } + }() + MustJoinPath("://invalid", "api") +} \ No newline at end of file