Skip to content

Commit 16edf44

Browse files
committed
Wire server-url to audience, if it is set
1 parent a74f33c commit 16edf44

File tree

3 files changed

+18
-13
lines changed

3 files changed

+18
-13
lines changed

pkg/http/authorization.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ const (
1919
)
2020

2121
// AuthorizationMiddleware validates the OAuth flow using Kubernetes TokenReview API
22-
func AuthorizationMiddleware(requireOAuth bool, mcpServer *mcp.Server) func(http.Handler) http.Handler {
22+
func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp.Server) func(http.Handler) http.Handler {
2323
return func(next http.Handler) http.Handler {
2424
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2525
if r.URL.Path == "/healthz" || r.URL.Path == "/.well-known/oauth-protected-resource" {
@@ -42,7 +42,12 @@ func AuthorizationMiddleware(requireOAuth bool, mcpServer *mcp.Server) func(http
4242

4343
token := strings.TrimPrefix(authHeader, "Bearer ")
4444

45-
err := validateJWTToken(token)
45+
audience := Audience
46+
if serverURL != "" {
47+
audience = serverURL
48+
}
49+
50+
err := validateJWTToken(token, audience)
4651
if err != nil {
4752
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
4853

@@ -73,7 +78,7 @@ type JWTClaims struct {
7378
}
7479

7580
// validateJWTToken validates basic JWT claims without signature verification
76-
func validateJWTToken(token string) error {
81+
func validateJWTToken(token, audience string) error {
7782
parts := strings.Split(token, ".")
7883
if len(parts) != 3 {
7984
return fmt.Errorf("invalid JWT token format")
@@ -88,7 +93,7 @@ func validateJWTToken(token string) error {
8893
return fmt.Errorf("token expired")
8994
}
9095

91-
if !slices.Contains(claims.Audience, Audience) {
96+
if !slices.Contains(claims.Audience, audience) {
9297
return fmt.Errorf("token audience mismatch: %v", claims.Audience)
9398
}
9499

pkg/http/authorization_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func TestValidateJWTToken(t *testing.T) {
9898
t.Run("invalid token format - not enough parts", func(t *testing.T) {
9999
invalidToken := "header.payload"
100100

101-
err := validateJWTToken(invalidToken)
101+
err := validateJWTToken(invalidToken, "test")
102102
if err == nil {
103103
t.Error("expected error for invalid token format, got nil")
104104
}
@@ -120,7 +120,7 @@ func TestValidateJWTToken(t *testing.T) {
120120
payload := base64.URLEncoding.EncodeToString(jsonBytes)
121121
expiredToken := "header." + payload + ".signature"
122122

123-
err := validateJWTToken(expiredToken)
123+
err := validateJWTToken(expiredToken, "kubernetes-mcp-server")
124124
if err == nil {
125125
t.Error("expected error for expired token, got nil")
126126
}
@@ -142,7 +142,7 @@ func TestValidateJWTToken(t *testing.T) {
142142
payload := base64.URLEncoding.EncodeToString(jsonBytes)
143143
multiAudToken := "header." + payload + ".signature"
144144

145-
err := validateJWTToken(multiAudToken)
145+
err := validateJWTToken(multiAudToken, "kubernetes-mcp-server")
146146
if err != nil {
147147
t.Errorf("expected no error for token with multiple audiences, got %v", err)
148148
}
@@ -160,7 +160,7 @@ func TestValidateJWTToken(t *testing.T) {
160160
payload := base64.URLEncoding.EncodeToString(jsonBytes)
161161
wrongAudToken := "header." + payload + ".signature"
162162

163-
err := validateJWTToken(wrongAudToken)
163+
err := validateJWTToken(wrongAudToken, "audience")
164164
if err == nil {
165165
t.Error("expected error for token with wrong audience, got nil")
166166
}
@@ -183,7 +183,7 @@ func TestAuthorizationMiddleware(t *testing.T) {
183183
handlerCalled = false
184184

185185
// Create middleware with OAuth disabled
186-
middleware := AuthorizationMiddleware(false, nil)
186+
middleware := AuthorizationMiddleware(false, "", nil)
187187
wrappedHandler := middleware(handler)
188188

189189
// Create request without authorization header
@@ -204,7 +204,7 @@ func TestAuthorizationMiddleware(t *testing.T) {
204204
handlerCalled = false
205205

206206
// Create middleware with OAuth enabled
207-
middleware := AuthorizationMiddleware(true, nil)
207+
middleware := AuthorizationMiddleware(true, "", nil)
208208
wrappedHandler := middleware(handler)
209209

210210
// Create request to healthz endpoint
@@ -225,7 +225,7 @@ func TestAuthorizationMiddleware(t *testing.T) {
225225
handlerCalled = false
226226

227227
// Create middleware with OAuth enabled
228-
middleware := AuthorizationMiddleware(true, nil)
228+
middleware := AuthorizationMiddleware(true, "", nil)
229229
wrappedHandler := middleware(handler)
230230

231231
// Create request without authorization header
@@ -249,7 +249,7 @@ func TestAuthorizationMiddleware(t *testing.T) {
249249
handlerCalled = false
250250

251251
// Create middleware with OAuth enabled
252-
middleware := AuthorizationMiddleware(true, nil)
252+
middleware := AuthorizationMiddleware(true, "", nil)
253253
wrappedHandler := middleware(handler)
254254

255255
// Create request with invalid bearer token

pkg/http/http.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func Serve(ctx context.Context, mcpServer *mcp.Server, staticConfig *config.Stat
2020
mux := http.NewServeMux()
2121

2222
wrappedMux := RequestMiddleware(
23-
AuthorizationMiddleware(staticConfig.RequireOAuth, mcpServer)(mux),
23+
AuthorizationMiddleware(staticConfig.RequireOAuth, staticConfig.ServerURL, mcpServer)(mux),
2424
)
2525

2626
httpServer := &http.Server{

0 commit comments

Comments
 (0)