Skip to content

Commit 73e9e84

Browse files
authored
refactor(auth): carry oidc provider directly instead of mcpServer
1 parent cb9f296 commit 73e9e84

File tree

6 files changed

+13
-23
lines changed

6 files changed

+13
-23
lines changed

pkg/http/authorization.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ const (
2020
)
2121

2222
// AuthorizationMiddleware validates the OAuth flow using Kubernetes TokenReview API
23-
func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp.Server) func(http.Handler) http.Handler {
23+
func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *oidc.Provider, mcpServer *mcp.Server) func(http.Handler) http.Handler {
2424
return func(next http.Handler) http.Handler {
2525
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2626
if r.URL.Path == healthEndpoint || r.URL.Path == oauthProtectedResourceEndpoint {
@@ -67,8 +67,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp
6767
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
6868
return
6969
}
70-
71-
oidcProvider := mcpServer.GetOIDCProvider()
70+
7271
if oidcProvider != nil {
7372
// If OIDC Provider is configured, this token must be validated against it.
7473
if err := validateTokenWithOIDC(r.Context(), oidcProvider, token, audience); err != nil {

pkg/http/authorization_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ func TestAuthorizationMiddleware(t *testing.T) {
198198
handlerCalled = false
199199

200200
// Create middleware with OAuth disabled
201-
middleware := AuthorizationMiddleware(false, "", nil)
201+
middleware := AuthorizationMiddleware(false, "", nil, nil)
202202
wrappedHandler := middleware(handler)
203203

204204
// Create request without authorization header
@@ -219,7 +219,7 @@ func TestAuthorizationMiddleware(t *testing.T) {
219219
handlerCalled = false
220220

221221
// Create middleware with OAuth enabled
222-
middleware := AuthorizationMiddleware(true, "", nil)
222+
middleware := AuthorizationMiddleware(true, "", nil, nil)
223223
wrappedHandler := middleware(handler)
224224

225225
// Create request to healthz endpoint
@@ -240,7 +240,7 @@ func TestAuthorizationMiddleware(t *testing.T) {
240240
handlerCalled = false
241241

242242
// Create middleware with OAuth enabled
243-
middleware := AuthorizationMiddleware(true, "", nil)
243+
middleware := AuthorizationMiddleware(true, "", nil, nil)
244244
wrappedHandler := middleware(handler)
245245

246246
// Create request without authorization header
@@ -264,7 +264,7 @@ func TestAuthorizationMiddleware(t *testing.T) {
264264
handlerCalled = false
265265

266266
// Create middleware with OAuth enabled
267-
middleware := AuthorizationMiddleware(true, "", nil)
267+
middleware := AuthorizationMiddleware(true, "", nil, nil)
268268
wrappedHandler := middleware(handler)
269269

270270
// Create request with invalid bearer token

pkg/http/http.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7+
"github.com/coreos/go-oidc/v3/oidc"
78
"net/http"
89
"os"
910
"os/signal"
@@ -24,11 +25,11 @@ const (
2425
sseMessageEndpoint = "/message"
2526
)
2627

27-
func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig) error {
28+
func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.StaticConfig, oidcProvider *oidc.Provider) error {
2829
mux := http.NewServeMux()
2930

3031
wrappedMux := RequestMiddleware(
31-
AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.ServerURL, mcpServer)(mux),
32+
AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.ServerURL, oidcProvider, mcpServer)(mux),
3233
)
3334

3435
httpServer := &http.Server{

pkg/http/http_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (c *httpContext) beforeEach() {
6969
timeoutCtx, c.timeoutCancel = context.WithTimeout(c.t.Context(), 10*time.Second)
7070
group, gc := errgroup.WithContext(timeoutCtx)
7171
cancelCtx, c.stopServer = context.WithCancel(gc)
72-
group.Go(func() error { return Serve(cancelCtx, mcpServer, staticConfig) })
72+
group.Go(func() error { return Serve(cancelCtx, mcpServer, staticConfig, nil) })
7373
c.waitForShutdown = group.Wait
7474
// Wait for HTTP server to start (using net)
7575
for i := 0; i < 10; i++ {

pkg/kubernetes-mcp-server/cmd/root.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,6 @@ func (m *MCPServerOptions) Run() error {
259259
Profile: profile,
260260
ListOutput: listOutput,
261261
StaticConfig: m.StaticConfig,
262-
OIDCProvider: oidcProvider,
263262
})
264263
if err != nil {
265264
return fmt.Errorf("Failed to initialize MCP server: %w\n", err)
@@ -268,7 +267,7 @@ func (m *MCPServerOptions) Run() error {
268267

269268
if m.StaticConfig.Port != "" {
270269
ctx := context.Background()
271-
return internalhttp.Serve(ctx, mcpServer, m.StaticConfig)
270+
return internalhttp.Serve(ctx, mcpServer, m.StaticConfig, oidcProvider)
272271
}
273272

274273
if err := mcpServer.ServeStdio(); err != nil && !errors.Is(err, context.Canceled) {

pkg/mcp/mcp.go

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"net/http"
77
"slices"
88

9-
"github.com/coreos/go-oidc/v3/oidc"
109
"github.com/mark3labs/mcp-go/mcp"
1110
"github.com/mark3labs/mcp-go/server"
1211
authenticationapiv1 "k8s.io/api/authentication/v1"
@@ -19,9 +18,8 @@ import (
1918
)
2019

2120
type Configuration struct {
22-
Profile Profile
23-
ListOutput output.Output
24-
OIDCProvider *oidc.Provider
21+
Profile Profile
22+
ListOutput output.Output
2523

2624
StaticConfig *config.StaticConfig
2725
}
@@ -124,13 +122,6 @@ func (s *Server) GetKubernetesAPIServerHost() string {
124122
return s.k.GetAPIServerHost()
125123
}
126124

127-
func (s *Server) GetOIDCProvider() *oidc.Provider {
128-
if s.configuration.OIDCProvider == nil {
129-
return nil
130-
}
131-
return s.configuration.OIDCProvider
132-
}
133-
134125
func (s *Server) Close() {
135126
if s.k != nil {
136127
s.k.Close()

0 commit comments

Comments
 (0)