Skip to content

Commit dfcecd5

Browse files
authored
feat(auth): configurable Kubernetes API token validation (#252)
Signed-off-by: Marc Nuri <[email protected]>
1 parent 7b11c16 commit dfcecd5

File tree

5 files changed

+102
-84
lines changed

5 files changed

+102
-84
lines changed

pkg/config/config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@ type StaticConfig struct {
2323
EnabledTools []string `toml:"enabled_tools,omitempty"`
2424
DisabledTools []string `toml:"disabled_tools,omitempty"`
2525
RequireOAuth bool `toml:"require_oauth,omitempty"`
26+
27+
//Authorization related fields
2628
// OAuthAudience is the valid audience for the OAuth tokens, used for offline JWT claim validation.
2729
OAuthAudience string `toml:"oauth_audience,omitempty"`
30+
// ValidateToken indicates whether the server should validate the token against the Kubernetes API Server using TokenReview.
31+
ValidateToken bool `toml:"validate_token,omitempty"`
2832
// AuthorizationURL is the URL of the OIDC authorization server.
2933
// It is used for token validation and for STS token exchange.
3034
AuthorizationURL string `toml:"authorization_url,omitempty"`

pkg/http/authorization.go

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,15 @@ import (
66
"net/http"
77
"strings"
88

9+
"github.com/containers/kubernetes-mcp-server/pkg/mcp"
910
"github.com/coreos/go-oidc/v3/oidc"
1011
"github.com/go-jose/go-jose/v4"
1112
"github.com/go-jose/go-jose/v4/jwt"
1213
authenticationapiv1 "k8s.io/api/authentication/v1"
1314
"k8s.io/klog/v2"
1415
"k8s.io/utils/strings/slices"
1516

16-
"github.com/containers/kubernetes-mcp-server/pkg/mcp"
17-
)
18-
19-
const (
20-
Audience = "mcp-server"
17+
"github.com/containers/kubernetes-mcp-server/pkg/config"
2118
)
2219

2320
type KubernetesApiTokenVerifier interface {
@@ -42,30 +39,31 @@ type KubernetesApiTokenVerifier interface {
4239
//
4340
// 2.1. Raw Token Validation (oidcProvider is nil):
4441
// - The token is validated offline for basic sanity checks (expiration).
45-
// - If audience is set, the token is validated against the audience.
46-
// - The token is then used against the Kubernetes API Server for TokenReview.
42+
// - If OAuthAudience is set, the token is validated against the audience.
43+
// - If ValidateToken is set, the token is then used against the Kubernetes API Server for TokenReview.
4744
//
4845
// 2.2. OIDC Provider Validation (oidcProvider is not nil):
4946
// - The token is validated offline for basic sanity checks (audience and expiration).
47+
// - If OAuthAudience is set, the token is validated against the audience.
5048
// - The token is then validated against the OIDC Provider.
51-
// - The token is then used against the Kubernetes API Server for TokenReview.
49+
// - If ValidateToken is set, the token is then used against the Kubernetes API Server for TokenReview.
5250
//
5351
// 2.3. OIDC Token Exchange (oidcProvider is not nil and xxx):
54-
func AuthorizationMiddleware(requireOAuth bool, audience string, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
52+
func AuthorizationMiddleware(staticConfig *config.StaticConfig, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
5553
return func(next http.Handler) http.Handler {
5654
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5755
if r.URL.Path == healthEndpoint || slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) {
5856
next.ServeHTTP(w, r)
5957
return
6058
}
61-
if !requireOAuth {
59+
if !staticConfig.RequireOAuth {
6260
next.ServeHTTP(w, r)
6361
return
6462
}
6563

6664
wwwAuthenticateHeader := "Bearer realm=\"Kubernetes MCP Server\""
67-
if audience != "" {
68-
wwwAuthenticateHeader += fmt.Sprintf(`, audience="%s"`, audience)
65+
if staticConfig.OAuthAudience != "" {
66+
wwwAuthenticateHeader += fmt.Sprintf(`, audience="%s"`, staticConfig.OAuthAudience)
6967
}
7068

7169
authHeader := r.Header.Get("Authorization")
@@ -80,11 +78,27 @@ func AuthorizationMiddleware(requireOAuth bool, audience string, oidcProvider *o
8078
token := strings.TrimPrefix(authHeader, "Bearer ")
8179

8280
claims, err := ParseJWTClaims(token)
83-
if err == nil && claims != nil {
84-
err = claims.ValidateOffline(audience)
81+
if err == nil && claims == nil {
82+
// Impossible case, but just in case
83+
err = fmt.Errorf("failed to parse JWT claims from token")
84+
}
85+
// Offline validation
86+
if err == nil {
87+
err = claims.ValidateOffline(staticConfig.OAuthAudience)
8588
}
86-
if err == nil && claims != nil {
87-
err = claims.ValidateWithProvider(r.Context(), audience, oidcProvider)
89+
// Online OIDC provider validation
90+
if err == nil {
91+
err = claims.ValidateWithProvider(r.Context(), staticConfig.OAuthAudience, oidcProvider)
92+
}
93+
// Scopes propagation, they are likely to be used for authorization.
94+
if err == nil {
95+
scopes := claims.GetScopes()
96+
klog.V(2).Infof("JWT token validated - Scopes: %v", scopes)
97+
r = r.WithContext(context.WithValue(r.Context(), mcp.TokenScopesContextKey, scopes))
98+
}
99+
// Kubernetes API Server TokenReview validation
100+
if err == nil && staticConfig.ValidateToken {
101+
err = claims.ValidateWithKubernetesApi(r.Context(), staticConfig.OAuthAudience, verifier)
88102
}
89103
if err != nil {
90104
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
@@ -94,32 +108,6 @@ func AuthorizationMiddleware(requireOAuth bool, audience string, oidcProvider *o
94108
return
95109
}
96110

97-
// Scopes are likely to be used for authorization.
98-
scopes := claims.GetScopes()
99-
klog.V(2).Infof("JWT token validated - Scopes: %v", scopes)
100-
r = r.WithContext(context.WithValue(r.Context(), mcp.TokenScopesContextKey, scopes))
101-
102-
// Now, there are a couple of options:
103-
// 1. If there is no authorization url configured for this MCP Server,
104-
// that means this token will be used against the Kubernetes API Server.
105-
// So that we need to validate the token using Kubernetes TokenReview API beforehand.
106-
// 2. If there is an authorization url configured for this MCP Server,
107-
// that means up to this point, the token is validated against the OIDC Provider already.
108-
// 2. a. If this is the only token in the headers, this validated token
109-
// is supposed to be used against the Kubernetes API Server as well. Therefore,
110-
// TokenReview request must succeed.
111-
// 2. b. If this is not the only token in the headers, the token in here is used
112-
// only for authentication and authorization. Therefore, we need to send TokenReview request
113-
// with the other token in the headers (TODO: still need to validate aud and exp of this token separately).
114-
_, _, err = verifier.KubernetesApiVerifyToken(r.Context(), token, audience)
115-
if err != nil {
116-
klog.V(1).Infof("Authentication failed - API Server token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
117-
118-
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
119-
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
120-
return
121-
}
122-
123111
next.ServeHTTP(w, r)
124112
})
125113
}
@@ -180,6 +168,16 @@ func (c *JWTClaims) ValidateWithProvider(ctx context.Context, audience string, p
180168
return nil
181169
}
182170

171+
func (c *JWTClaims) ValidateWithKubernetesApi(ctx context.Context, audience string, verifier KubernetesApiTokenVerifier) error {
172+
if verifier != nil {
173+
_, _, err := verifier.KubernetesApiVerifyToken(ctx, c.Token, audience)
174+
if err != nil {
175+
return fmt.Errorf("kubernetes API token validation error: %v", err)
176+
}
177+
}
178+
return nil
179+
}
180+
183181
func ParseJWTClaims(token string) (*JWTClaims, error) {
184182
tkn, err := jwt.ParseSigned(token, allSignatureAlgorithms)
185183
if err != nil {

pkg/http/http.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat
2828
mux := http.NewServeMux()
2929

3030
wrappedMux := RequestMiddleware(
31-
AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.OAuthAudience, oidcProvider, mcpServer)(mux),
31+
AuthorizationMiddleware(staticConfig, oidcProvider, mcpServer)(mux),
3232
)
3333

3434
httpServer := &http.Server{

pkg/http/http_test.go

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ func TestHealthCheck(t *testing.T) {
284284
})
285285
})
286286
// Health exposed even when require Authorization
287-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
287+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
288288
resp, err := http.Get(fmt.Sprintf("http://%s/healthz", ctx.HttpAddress))
289289
if err != nil {
290290
t.Fatalf("Failed to get health check endpoint with OAuth: %v", err)
@@ -305,7 +305,7 @@ func TestWellKnownReverseProxy(t *testing.T) {
305305
".well-known/openid-configuration",
306306
}
307307
// With No Authorization URL configured
308-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
308+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
309309
for _, path := range cases {
310310
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
311311
t.Cleanup(func() { _ = resp.Body.Close() })
@@ -329,7 +329,7 @@ func TestWellKnownReverseProxy(t *testing.T) {
329329
_, _ = w.Write([]byte(`{"issuer": "https://example.com","scopes_supported":["mcp-server"]}`))
330330
}))
331331
t.Cleanup(testServer.Close)
332-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true}}, func(ctx *httpContext) {
332+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{AuthorizationURL: testServer.URL, RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
333333
for _, path := range cases {
334334
resp, err := http.Get(fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path))
335335
t.Cleanup(func() { _ = resp.Body.Close() })
@@ -377,7 +377,7 @@ func TestMiddlewareLogging(t *testing.T) {
377377

378378
func TestAuthorizationUnauthorized(t *testing.T) {
379379
// Missing Authorization header
380-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
380+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
381381
resp, err := http.Get(fmt.Sprintf("http://%s/mcp", ctx.HttpAddress))
382382
if err != nil {
383383
t.Fatalf("Failed to get protected endpoint: %v", err)
@@ -402,7 +402,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
402402
})
403403
})
404404
// Authorization header without Bearer prefix
405-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
405+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
406406
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
407407
if err != nil {
408408
t.Fatalf("Failed to create request: %v", err)
@@ -427,7 +427,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
427427
})
428428
})
429429
// Invalid Authorization header
430-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
430+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
431431
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
432432
if err != nil {
433433
t.Fatalf("Failed to create request: %v", err)
@@ -458,7 +458,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
458458
})
459459
})
460460
// Expired Authorization Bearer token
461-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
461+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, ValidateToken: true}}, func(ctx *httpContext) {
462462
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
463463
if err != nil {
464464
t.Fatalf("Failed to create request: %v", err)
@@ -489,7 +489,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
489489
})
490490
})
491491
// Invalid audience claim Bearer token
492-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience"}}, func(ctx *httpContext) {
492+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience", ValidateToken: true}}, func(ctx *httpContext) {
493493
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
494494
if err != nil {
495495
t.Fatalf("Failed to create request: %v", err)
@@ -522,7 +522,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
522522
// Failed OIDC validation
523523
key, oidcProvider, httpServer := NewOidcTestServer(t)
524524
t.Cleanup(httpServer.Close)
525-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server"}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
525+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
526526
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
527527
if err != nil {
528528
t.Fatalf("Failed to create request: %v", err)
@@ -559,7 +559,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
559559
"aud": "mcp-server"
560560
}`
561561
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
562-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server"}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
562+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
563563
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
564564
if err != nil {
565565
t.Fatalf("Failed to create request: %v", err)
@@ -583,7 +583,8 @@ func TestAuthorizationUnauthorized(t *testing.T) {
583583
}
584584
})
585585
t.Run("Protected resource with INVALID KUBERNETES Authorization header logs error", func(t *testing.T) {
586-
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - API Server token validation error") {
586+
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
587+
!strings.Contains(ctx.LogBuffer.String(), "kubernetes API token validation error: failed to create token review") {
587588
t.Errorf("Expected log entry for Kubernetes TokenReview error, got: %s", ctx.LogBuffer.String())
588589
}
589590
})
@@ -607,12 +608,17 @@ func TestAuthorizationRequireOAuthFalse(t *testing.T) {
607608
}
608609

609610
func TestAuthorizationRawToken(t *testing.T) {
610-
cases := []string{
611-
"",
612-
"mcp-server",
611+
cases := []struct {
612+
audience string
613+
validateToken bool
614+
}{
615+
{"", false}, // No audience, no validation
616+
{"", true}, // No audience, validation enabled
617+
{"mcp-server", false}, // Audience set, no validation
618+
{"mcp-server", true}, // Audience set, validation enabled
613619
}
614-
for _, audience := range cases {
615-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: audience}}, func(ctx *httpContext) {
620+
for _, c := range cases {
621+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: c.audience, ValidateToken: c.validateToken}}, func(ctx *httpContext) {
616622
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
617623
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
618624
w.Header().Set("Content-Type", "application/json")
@@ -630,7 +636,7 @@ func TestAuthorizationRawToken(t *testing.T) {
630636
t.Fatalf("Failed to get protected endpoint: %v", err)
631637
}
632638
t.Cleanup(func() { _ = resp.Body.Close() })
633-
t.Run("Protected resource with audience = '"+audience+"' with VALID Authorization header returns 200 - OK", func(t *testing.T) {
639+
t.Run(fmt.Sprintf("Protected resource with audience = '%s' and validate-token = '%t', with VALID Authorization header returns 200 - OK", c.audience, c.validateToken), func(t *testing.T) {
634640
if resp.StatusCode != http.StatusOK {
635641
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
636642
}
@@ -649,28 +655,32 @@ func TestAuthorizationOidcToken(t *testing.T) {
649655
"aud": "mcp-server"
650656
}`
651657
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
652-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server"}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
653-
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
654-
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
655-
w.Header().Set("Content-Type", "application/json")
656-
_, _ = w.Write([]byte(tokenReviewSuccessful))
657-
return
658-
}
659-
}))
660-
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
661-
if err != nil {
662-
t.Fatalf("Failed to create request: %v", err)
663-
}
664-
req.Header.Set("Authorization", "Bearer "+validOidcToken)
665-
resp, err := http.DefaultClient.Do(req)
666-
if err != nil {
667-
t.Fatalf("Failed to get protected endpoint: %v", err)
668-
}
669-
t.Cleanup(func() { _ = resp.Body.Close() })
670-
t.Run("Protected resource with VALID OIDC Authorization header returns 200 - OK", func(t *testing.T) {
671-
if resp.StatusCode != http.StatusOK {
672-
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
658+
cases := []bool{false, true}
659+
for _, validateToken := range cases {
660+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server", ValidateToken: validateToken}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
661+
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
662+
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
663+
w.Header().Set("Content-Type", "application/json")
664+
_, _ = w.Write([]byte(tokenReviewSuccessful))
665+
return
666+
}
667+
}))
668+
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
669+
if err != nil {
670+
t.Fatalf("Failed to create request: %v", err)
671+
}
672+
req.Header.Set("Authorization", "Bearer "+validOidcToken)
673+
resp, err := http.DefaultClient.Do(req)
674+
if err != nil {
675+
t.Fatalf("Failed to get protected endpoint: %v", err)
673676
}
677+
t.Cleanup(func() { _ = resp.Body.Close() })
678+
t.Run(fmt.Sprintf("Protected resource with validate-token='%t' with VALID OIDC Authorization header returns 200 - OK", validateToken), func(t *testing.T) {
679+
if resp.StatusCode != http.StatusOK {
680+
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
681+
}
682+
})
674683
})
675-
})
684+
685+
}
676686
}

0 commit comments

Comments
 (0)