diff --git a/pkg/http/authorization.go b/pkg/http/authorization.go index c517cb72..099cf268 100644 --- a/pkg/http/authorization.go +++ b/pkg/http/authorization.go @@ -20,7 +20,7 @@ const ( ) // AuthorizationMiddleware validates the OAuth flow using Kubernetes TokenReview API -func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp.Server) func(http.Handler) http.Handler { +func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *oidc.Provider, mcpServer *mcp.Server) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/healthz" || r.URL.Path == "/.well-known/oauth-protected-resource" { @@ -67,8 +67,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized) return } - - oidcProvider := mcpServer.GetOIDCProvider() + if oidcProvider != nil { // If OIDC Provider is configured, this token must be validated against it. if err := validateTokenWithOIDC(r.Context(), oidcProvider, token, audience); err != nil { diff --git a/pkg/http/authorization_test.go b/pkg/http/authorization_test.go index 0ffe816e..021df27f 100644 --- a/pkg/http/authorization_test.go +++ b/pkg/http/authorization_test.go @@ -198,7 +198,7 @@ func TestAuthorizationMiddleware(t *testing.T) { handlerCalled = false // Create middleware with OAuth disabled - middleware := AuthorizationMiddleware(false, "", nil) + middleware := AuthorizationMiddleware(false, "", nil, nil) wrappedHandler := middleware(handler) // Create request without authorization header @@ -219,7 +219,7 @@ func TestAuthorizationMiddleware(t *testing.T) { handlerCalled = false // Create middleware with OAuth enabled - middleware := AuthorizationMiddleware(true, "", nil) + middleware := AuthorizationMiddleware(true, "", nil, nil) wrappedHandler := middleware(handler) // Create request to healthz endpoint @@ -240,7 +240,7 @@ func TestAuthorizationMiddleware(t *testing.T) { handlerCalled = false // Create middleware with OAuth enabled - middleware := AuthorizationMiddleware(true, "", nil) + middleware := AuthorizationMiddleware(true, "", nil, nil) wrappedHandler := middleware(handler) // Create request without authorization header @@ -264,7 +264,7 @@ func TestAuthorizationMiddleware(t *testing.T) { handlerCalled = false // Create middleware with OAuth enabled - middleware := AuthorizationMiddleware(true, "", nil) + middleware := AuthorizationMiddleware(true, "", nil, nil) wrappedHandler := middleware(handler) // Create request with invalid bearer token diff --git a/pkg/http/http.go b/pkg/http/http.go index 02f0f060..a12d840d 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "github.com/coreos/go-oidc/v3/oidc" "net/http" "os" "os/signal" @@ -18,11 +19,11 @@ import ( const oauthProtectedResourceEndpoint = "/.well-known/oauth-protected-resource" -func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig) error { +func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oidcProvider *oidc.Provider) error { mux := http.NewServeMux() wrappedMux := RequestMiddleware( - AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.ServerURL, mcpServer)(mux), + AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.ServerURL, oidcProvider, mcpServer)(mux), ) httpServer := &http.Server{ diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index c81594a7..fd76a955 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -69,7 +69,7 @@ func (c *httpContext) beforeEach() { timeoutCtx, c.timeoutCancel = context.WithTimeout(c.t.Context(), 10*time.Second) group, gc := errgroup.WithContext(timeoutCtx) cancelCtx, c.stopServer = context.WithCancel(gc) - group.Go(func() error { return Serve(cancelCtx, mcpServer, staticConfig) }) + group.Go(func() error { return Serve(cancelCtx, mcpServer, staticConfig, nil) }) c.waitForShutdown = group.Wait // Wait for HTTP server to start (using net) for i := 0; i < 10; i++ { diff --git a/pkg/kubernetes-mcp-server/cmd/root.go b/pkg/kubernetes-mcp-server/cmd/root.go index dfc260a3..f39ab38b 100644 --- a/pkg/kubernetes-mcp-server/cmd/root.go +++ b/pkg/kubernetes-mcp-server/cmd/root.go @@ -259,7 +259,6 @@ func (m *MCPServerOptions) Run() error { Profile: profile, ListOutput: listOutput, StaticConfig: m.StaticConfig, - OIDCProvider: oidcProvider, }) if err != nil { return fmt.Errorf("Failed to initialize MCP server: %w\n", err) @@ -268,7 +267,7 @@ func (m *MCPServerOptions) Run() error { if m.StaticConfig.Port != "" { ctx := context.Background() - return internalhttp.Serve(ctx, mcpServer, m.StaticConfig) + return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider) } if err := mcpServer.ServeStdio(); err != nil && !errors.Is(err, context.Canceled) { diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 91646172..d3f1ca50 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -6,7 +6,6 @@ import ( "net/http" "slices" - "github.com/coreos/go-oidc/v3/oidc" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" authenticationapiv1 "k8s.io/api/authentication/v1" @@ -19,9 +18,8 @@ import ( ) type Configuration struct { - Profile Profile - ListOutput output.Output - OIDCProvider *oidc.Provider + Profile Profile + ListOutput output.Output StaticConfig *config.StaticConfig } @@ -124,13 +122,6 @@ func (s *Server) GetKubernetesAPIServerHost() string { return s.k.GetAPIServerHost() } -func (s *Server) GetOIDCProvider() *oidc.Provider { - if s.configuration.OIDCProvider == nil { - return nil - } - return s.configuration.OIDCProvider -} - func (s *Server) Close() { if s.k != nil { s.k.Close()