Skip to content

Commit 7b11c16

Browse files
authored
feat(auth): configurable audience validation (#251)
Signed-off-by: Marc Nuri <[email protected]>
1 parent b0da9fb commit 7b11c16

File tree

6 files changed

+117
-54
lines changed

6 files changed

+117
-54
lines changed

pkg/config/config.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,17 @@ type StaticConfig struct {
1919
// When true, expose only tools annotated with readOnlyHint=true
2020
ReadOnly bool `toml:"read_only,omitempty"`
2121
// When true, disable tools annotated with destructiveHint=true
22-
DisableDestructive bool `toml:"disable_destructive,omitempty"`
23-
EnabledTools []string `toml:"enabled_tools,omitempty"`
24-
DisabledTools []string `toml:"disabled_tools,omitempty"`
25-
RequireOAuth bool `toml:"require_oauth,omitempty"`
26-
AuthorizationURL string `toml:"authorization_url,omitempty"`
27-
CertificateAuthority string `toml:"certificate_authority,omitempty"`
28-
ServerURL string `toml:"server_url,omitempty"`
22+
DisableDestructive bool `toml:"disable_destructive,omitempty"`
23+
EnabledTools []string `toml:"enabled_tools,omitempty"`
24+
DisabledTools []string `toml:"disabled_tools,omitempty"`
25+
RequireOAuth bool `toml:"require_oauth,omitempty"`
26+
// OAuthAudience is the valid audience for the OAuth tokens, used for offline JWT claim validation.
27+
OAuthAudience string `toml:"oauth_audience,omitempty"`
28+
// AuthorizationURL is the URL of the OIDC authorization server.
29+
// It is used for token validation and for STS token exchange.
30+
AuthorizationURL string `toml:"authorization_url,omitempty"`
31+
CertificateAuthority string `toml:"certificate_authority,omitempty"`
32+
ServerURL string `toml:"server_url,omitempty"`
2933
}
3034

3135
type GroupVersionKind struct {

pkg/http/authorization.go

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,17 @@ type KubernetesApiTokenVerifier interface {
4141
// 2. requireOAuth is set to true, server is protected:
4242
//
4343
// 2.1. Raw Token Validation (oidcProvider is nil):
44-
// - The token is validated offline for basic sanity checks (audience and expiration).
45-
// - The token is then used against the Kubernetes API Server for TokenReview.
44+
// - The token is validated offline for basic sanity checks (expiration).
45+
// - If audience is set, the token is validated against the audience.
46+
// - The token is then used against the Kubernetes API Server for TokenReview.
4647
//
4748
// 2.2. OIDC Provider Validation (oidcProvider is not nil):
4849
// - The token is validated offline for basic sanity checks (audience and expiration).
4950
// - The token is then validated against the OIDC Provider.
5051
// - The token is then used against the Kubernetes API Server for TokenReview.
5152
//
5253
// 2.3. OIDC Token Exchange (oidcProvider is not nil and xxx):
53-
func AuthorizationMiddleware(requireOAuth bool, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
54+
func AuthorizationMiddleware(requireOAuth bool, audience string, oidcProvider *oidc.Provider, verifier KubernetesApiTokenVerifier) func(http.Handler) http.Handler {
5455
return func(next http.Handler) http.Handler {
5556
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5657
if r.URL.Path == healthEndpoint || slices.Contains(WellKnownEndpoints, r.URL.EscapedPath()) {
@@ -62,13 +63,16 @@ func AuthorizationMiddleware(requireOAuth bool, oidcProvider *oidc.Provider, ver
6263
return
6364
}
6465

65-
audience := Audience
66+
wwwAuthenticateHeader := "Bearer realm=\"Kubernetes MCP Server\""
67+
if audience != "" {
68+
wwwAuthenticateHeader += fmt.Sprintf(`, audience="%s"`, audience)
69+
}
6670

6771
authHeader := r.Header.Get("Authorization")
6872
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
6973
klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
7074

71-
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="missing_token"`, audience))
75+
w.Header().Set("WWW-Authenticate", wwwAuthenticateHeader+", error=\"missing_token\"")
7276
http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized)
7377
return
7478
}
@@ -77,12 +81,15 @@ func AuthorizationMiddleware(requireOAuth bool, oidcProvider *oidc.Provider, ver
7781

7882
claims, err := ParseJWTClaims(token)
7983
if err == nil && claims != nil {
80-
err = claims.Validate(r.Context(), audience, oidcProvider)
84+
err = claims.ValidateOffline(audience)
85+
}
86+
if err == nil && claims != nil {
87+
err = claims.ValidateWithProvider(r.Context(), audience, oidcProvider)
8188
}
8289
if err != nil {
8390
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
8491

85-
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
92+
w.Header().Set("WWW-Authenticate", wwwAuthenticateHeader+", error=\"invalid_token\"")
8693
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
8794
return
8895
}
@@ -147,16 +154,24 @@ func (c *JWTClaims) GetScopes() []string {
147154
return strings.Fields(c.Scope)
148155
}
149156

150-
// Validate Checks if the JWT claims are valid and if the audience matches the expected one.
151-
func (c *JWTClaims) Validate(ctx context.Context, audience string, provider *oidc.Provider) error {
152-
if err := c.Claims.Validate(jwt.Expected{AnyAudience: jwt.Audience{audience}}); err != nil {
157+
// ValidateOffline Checks if the JWT claims are valid and if the audience matches the expected one.
158+
func (c *JWTClaims) ValidateOffline(audience string) error {
159+
expected := jwt.Expected{}
160+
if audience != "" {
161+
expected.AnyAudience = jwt.Audience{audience}
162+
}
163+
if err := c.Validate(expected); err != nil {
153164
return fmt.Errorf("JWT token validation error: %v", err)
154165
}
166+
return nil
167+
}
168+
169+
// ValidateWithProvider validates the JWT claims against the OIDC provider.
170+
func (c *JWTClaims) ValidateWithProvider(ctx context.Context, audience string, provider *oidc.Provider) error {
155171
if provider != nil {
156172
verifier := provider.Verifier(&oidc.Config{
157173
ClientID: audience,
158174
})
159-
160175
_, err := verifier.Verify(ctx, c.Token)
161176
if err != nil {
162177
return fmt.Errorf("OIDC token validation error: %v", err)

pkg/http/authorization_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,14 @@ func TestParseJWTClaimsPayloadInvalid(t *testing.T) {
104104
})
105105
}
106106

107-
func TestJWTTokenValidate(t *testing.T) {
107+
func TestJWTTokenValidateOffline(t *testing.T) {
108108
t.Run("expired token returns error", func(t *testing.T) {
109109
claims, err := ParseJWTClaims(tokenBasicExpired)
110110
if err != nil {
111111
t.Fatalf("expected no error for expired token parsing, got %v", err)
112112
}
113113

114-
err = claims.Validate(t.Context(), "mcp-server", nil)
114+
err = claims.ValidateOffline("mcp-server")
115115
if err == nil {
116116
t.Fatalf("expected error for expired token, got nil")
117117
}
@@ -130,7 +130,7 @@ func TestJWTTokenValidate(t *testing.T) {
130130
t.Fatalf("expected claims to be returned, got nil")
131131
}
132132

133-
err = claims.Validate(t.Context(), "mcp-server", nil)
133+
err = claims.ValidateOffline("mcp-server")
134134
if err != nil {
135135
t.Fatalf("expected no error for valid audience, got %v", err)
136136
}
@@ -145,7 +145,7 @@ func TestJWTTokenValidate(t *testing.T) {
145145
t.Fatalf("expected claims to be returned, got nil")
146146
}
147147

148-
err = claims.Validate(t.Context(), "missing-audience", nil)
148+
err = claims.ValidateOffline("missing-audience")
149149
if err == nil {
150150
t.Fatalf("expected error for token with wrong audience, got nil")
151151
}

pkg/http/http.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat
2828
mux := http.NewServeMux()
2929

3030
wrappedMux := RequestMiddleware(
31-
AuthorizationMiddleware(staticConfig.RequireOAuth, oidcProvider, mcpServer)(mux),
31+
AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.OAuthAudience, oidcProvider, mcpServer)(mux),
3232
)
3333

3434
httpServer := &http.Server{

pkg/http/http_test.go

Lines changed: 69 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
390390
})
391391
t.Run("Protected resource with MISSING Authorization header returns WWW-Authenticate header", func(t *testing.T) {
392392
authHeader := resp.Header.Get("WWW-Authenticate")
393-
expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="missing_token"`
393+
expected := `Bearer realm="Kubernetes MCP Server", error="missing_token"`
394394
if authHeader != expected {
395395
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
396396
}
@@ -415,7 +415,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
415415
t.Cleanup(func() { _ = resp.Body.Close })
416416
t.Run("Protected resource with INCOMPATIBLE Authorization header returns WWW-Authenticate header", func(t *testing.T) {
417417
authHeader := resp.Header.Get("WWW-Authenticate")
418-
expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="missing_token"`
418+
expected := `Bearer realm="Kubernetes MCP Server", error="missing_token"`
419419
if authHeader != expected {
420420
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
421421
}
@@ -432,7 +432,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
432432
if err != nil {
433433
t.Fatalf("Failed to create request: %v", err)
434434
}
435-
req.Header.Set("Authorization", "Bearer invalid_base64"+tokenBasicNotExpired)
435+
req.Header.Set("Authorization", "Bearer "+strings.ReplaceAll(tokenBasicNotExpired, ".", ".invalid"))
436436
resp, err := http.DefaultClient.Do(req)
437437
if err != nil {
438438
t.Fatalf("Failed to get protected endpoint: %v", err)
@@ -445,13 +445,13 @@ func TestAuthorizationUnauthorized(t *testing.T) {
445445
})
446446
t.Run("Protected resource with INVALID Authorization header returns WWW-Authenticate header", func(t *testing.T) {
447447
authHeader := resp.Header.Get("WWW-Authenticate")
448-
expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="invalid_token"`
448+
expected := `Bearer realm="Kubernetes MCP Server", error="invalid_token"`
449449
if authHeader != expected {
450450
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
451451
}
452452
})
453453
t.Run("Protected resource with INVALID Authorization header logs error", func(t *testing.T) {
454-
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") &&
454+
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
455455
!strings.Contains(ctx.LogBuffer.String(), "error: failed to parse JWT token: illegal base64 data") {
456456
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
457457
}
@@ -476,22 +476,53 @@ func TestAuthorizationUnauthorized(t *testing.T) {
476476
})
477477
t.Run("Protected resource with EXPIRED Authorization header returns WWW-Authenticate header", func(t *testing.T) {
478478
authHeader := resp.Header.Get("WWW-Authenticate")
479-
expected := `Bearer realm="Kubernetes MCP Server", audience="mcp-server", error="invalid_token"`
479+
expected := `Bearer realm="Kubernetes MCP Server", error="invalid_token"`
480480
if authHeader != expected {
481481
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
482482
}
483483
})
484484
t.Run("Protected resource with EXPIRED Authorization header logs error", func(t *testing.T) {
485-
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") &&
485+
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
486486
!strings.Contains(ctx.LogBuffer.String(), "validation failed, token is expired (exp)") {
487487
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
488488
}
489489
})
490490
})
491+
// Invalid audience claim Bearer token
492+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "expected-audience"}}, func(ctx *httpContext) {
493+
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
494+
if err != nil {
495+
t.Fatalf("Failed to create request: %v", err)
496+
}
497+
req.Header.Set("Authorization", "Bearer "+tokenBasicExpired)
498+
resp, err := http.DefaultClient.Do(req)
499+
if err != nil {
500+
t.Fatalf("Failed to get protected endpoint: %v", err)
501+
}
502+
t.Cleanup(func() { _ = resp.Body.Close })
503+
t.Run("Protected resource with INVALID AUDIENCE Authorization header returns 401 - Unauthorized", func(t *testing.T) {
504+
if resp.StatusCode != 401 {
505+
t.Errorf("Expected HTTP 401, got %d", resp.StatusCode)
506+
}
507+
})
508+
t.Run("Protected resource with INVALID AUDIENCE Authorization header returns WWW-Authenticate header", func(t *testing.T) {
509+
authHeader := resp.Header.Get("WWW-Authenticate")
510+
expected := `Bearer realm="Kubernetes MCP Server", audience="expected-audience", error="invalid_token"`
511+
if authHeader != expected {
512+
t.Errorf("Expected WWW-Authenticate header to be %q, got %q", expected, authHeader)
513+
}
514+
})
515+
t.Run("Protected resource with INVALID AUDIENCE Authorization header logs error", func(t *testing.T) {
516+
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
517+
!strings.Contains(ctx.LogBuffer.String(), "invalid audience claim (aud)") {
518+
t.Errorf("Expected log entry for JWT validation error, got: %s", ctx.LogBuffer.String())
519+
}
520+
})
521+
})
491522
// Failed OIDC validation
492523
key, oidcProvider, httpServer := NewOidcTestServer(t)
493524
t.Cleanup(httpServer.Close)
494-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
525+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server"}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
495526
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
496527
if err != nil {
497528
t.Fatalf("Failed to create request: %v", err)
@@ -528,7 +559,7 @@ func TestAuthorizationUnauthorized(t *testing.T) {
528559
"aud": "mcp-server"
529560
}`
530561
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
531-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
562+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server"}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
532563
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
533564
if err != nil {
534565
t.Fatalf("Failed to create request: %v", err)
@@ -576,30 +607,37 @@ func TestAuthorizationRequireOAuthFalse(t *testing.T) {
576607
}
577608

578609
func TestAuthorizationRawToken(t *testing.T) {
579-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}}, func(ctx *httpContext) {
580-
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
581-
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
582-
w.Header().Set("Content-Type", "application/json")
583-
_, _ = w.Write([]byte(tokenReviewSuccessful))
584-
return
610+
cases := []string{
611+
"",
612+
"mcp-server",
613+
}
614+
for _, audience := range cases {
615+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: audience}}, func(ctx *httpContext) {
616+
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
617+
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
618+
w.Header().Set("Content-Type", "application/json")
619+
_, _ = w.Write([]byte(tokenReviewSuccessful))
620+
return
621+
}
622+
}))
623+
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
624+
if err != nil {
625+
t.Fatalf("Failed to create request: %v", err)
585626
}
586-
}))
587-
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/mcp", ctx.HttpAddress), nil)
588-
if err != nil {
589-
t.Fatalf("Failed to create request: %v", err)
590-
}
591-
req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired)
592-
resp, err := http.DefaultClient.Do(req)
593-
if err != nil {
594-
t.Fatalf("Failed to get protected endpoint: %v", err)
595-
}
596-
t.Cleanup(func() { _ = resp.Body.Close() })
597-
t.Run("Protected resource with VALID Authorization header returns 200 - OK", func(t *testing.T) {
598-
if resp.StatusCode != http.StatusOK {
599-
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
627+
req.Header.Set("Authorization", "Bearer "+tokenBasicNotExpired)
628+
resp, err := http.DefaultClient.Do(req)
629+
if err != nil {
630+
t.Fatalf("Failed to get protected endpoint: %v", err)
600631
}
632+
t.Cleanup(func() { _ = resp.Body.Close() })
633+
t.Run("Protected resource with audience = '"+audience+"' with VALID Authorization header returns 200 - OK", func(t *testing.T) {
634+
if resp.StatusCode != http.StatusOK {
635+
t.Errorf("Expected HTTP 200 OK, got %d", resp.StatusCode)
636+
}
637+
})
601638
})
602-
})
639+
}
640+
603641
}
604642

605643
func TestAuthorizationOidcToken(t *testing.T) {
@@ -611,7 +649,7 @@ func TestAuthorizationOidcToken(t *testing.T) {
611649
"aud": "mcp-server"
612650
}`
613651
validOidcToken := oidctest.SignIDToken(key, "test-oidc-key-id", oidc.RS256, rawClaims)
614-
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
652+
testCaseWithContext(t, &httpContext{StaticConfig: &config.StaticConfig{RequireOAuth: true, OAuthAudience: "mcp-server"}, OidcProvider: oidcProvider}, func(ctx *httpContext) {
615653
ctx.mockServer.Handle(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
616654
if req.URL.EscapedPath() == "/apis/authentication.k8s.io/v1/tokenreviews" {
617655
w.Header().Set("Content-Type", "application/json")

pkg/kubernetes-mcp-server/cmd/root.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ type MCPServerOptions struct {
6262
ReadOnly bool
6363
DisableDestructive bool
6464
RequireOAuth bool
65+
OAuthAudience string
6566
AuthorizationURL string
6667
CertificateAuthority string
6768
ServerURL string
@@ -119,6 +120,8 @@ func NewMCPServer(streams genericiooptions.IOStreams) *cobra.Command {
119120
cmd.Flags().BoolVar(&o.DisableDestructive, "disable-destructive", o.DisableDestructive, "If true, tools annotated with destructiveHint=true are disabled")
120121
cmd.Flags().BoolVar(&o.RequireOAuth, "require-oauth", o.RequireOAuth, "If true, requires OAuth authorization as defined in the Model Context Protocol (MCP) specification. This flag is ignored if transport type is stdio")
121122
_ = cmd.Flags().MarkHidden("require-oauth")
123+
cmd.Flags().StringVar(&o.OAuthAudience, "oauth-audience", o.OAuthAudience, "OAuth audience for token claims validation. Optional. If not set, the audience is not validated. Only valid if require-oauth is enabled.")
124+
_ = cmd.Flags().MarkHidden("oauth-audience")
122125
cmd.Flags().StringVar(&o.AuthorizationURL, "authorization-url", o.AuthorizationURL, "OAuth authorization server URL for protected resource endpoint. If not provided, the Kubernetes API server host will be used. Only valid if require-oauth is enabled.")
123126
_ = cmd.Flags().MarkHidden("authorization-url")
124127
cmd.Flags().StringVar(&o.ServerURL, "server-url", o.ServerURL, "Server URL of this application. Optional. If set, this url will be served in protected resource metadata endpoint and tokens will be validated with this audience. If not set, expected audience is kubernetes-mcp-server. Only valid if require-oauth is enabled.")
@@ -179,6 +182,9 @@ func (m *MCPServerOptions) loadFlags(cmd *cobra.Command) {
179182
if cmd.Flag("require-oauth").Changed {
180183
m.StaticConfig.RequireOAuth = m.RequireOAuth
181184
}
185+
if cmd.Flag("oauth-audience").Changed {
186+
m.StaticConfig.OAuthAudience = m.OAuthAudience
187+
}
182188
if cmd.Flag("authorization-url").Changed {
183189
m.StaticConfig.AuthorizationURL = m.AuthorizationURL
184190
}

0 commit comments

Comments
 (0)