Skip to content

Commit 4dcede1

Browse files
authored
refactor(auth): consolidate JWT validation into single method (containers#238)
Signed-off-by: Marc Nuri <[email protected]>
1 parent 4302a43 commit 4dcede1

File tree

3 files changed

+24
-38
lines changed

3 files changed

+24
-38
lines changed

pkg/http/authorization.go

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
5656
// rejected already.
5757
claims, err := ParseJWTClaims(token)
5858
if err == nil && claims != nil {
59-
err = claims.Validate(audience)
59+
err = claims.Validate(r.Context(), audience, oidcProvider)
6060
}
6161
if err != nil {
6262
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
@@ -70,21 +70,6 @@ func AuthorizationMiddleware(requireOAuth bool, serverURL string, oidcProvider *
7070
return
7171
}
7272

73-
if oidcProvider != nil {
74-
// If OIDC Provider is configured, this token must be validated against it.
75-
if err := validateTokenWithOIDC(r.Context(), oidcProvider, token, audience); err != nil {
76-
klog.V(1).Infof("Authentication failed - OIDC token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
77-
78-
if serverURL == "" {
79-
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s", error="invalid_token"`, audience))
80-
} else {
81-
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience="%s"", resource_metadata="%s%s", error="invalid_token"`, audience, serverURL, oauthProtectedResourceEndpoint))
82-
}
83-
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
84-
return
85-
}
86-
}
87-
8873
// Scopes are likely to be used for authorization.
8974
scopes := claims.GetScopes()
9075
klog.V(2).Infof("JWT token validated - Scopes: %v", scopes)
@@ -138,6 +123,7 @@ var allSignatureAlgorithms = []jose.SignatureAlgorithm{
138123

139124
type JWTClaims struct {
140125
jwt.Claims
126+
Token string `json:"-"`
141127
Scope string `json:"scope,omitempty"`
142128
}
143129

@@ -149,10 +135,21 @@ func (c *JWTClaims) GetScopes() []string {
149135
}
150136

151137
// Validate Checks if the JWT claims are valid and if the audience matches the expected one.
152-
func (c *JWTClaims) Validate(audience string) error {
153-
return c.Claims.Validate(jwt.Expected{
154-
AnyAudience: jwt.Audience{audience},
155-
})
138+
func (c *JWTClaims) Validate(ctx context.Context, audience string, provider *oidc.Provider) error {
139+
if err := c.Claims.Validate(jwt.Expected{AnyAudience: jwt.Audience{audience}}); err != nil {
140+
return fmt.Errorf("JWT token validation error: %v", err)
141+
}
142+
if provider != nil {
143+
verifier := provider.Verifier(&oidc.Config{
144+
ClientID: audience,
145+
})
146+
147+
_, err := verifier.Verify(ctx, c.Token)
148+
if err != nil {
149+
return fmt.Errorf("OIDC token validation error: %v", err)
150+
}
151+
}
152+
return nil
156153
}
157154

158155
func ParseJWTClaims(token string) (*JWTClaims, error) {
@@ -162,18 +159,6 @@ func ParseJWTClaims(token string) (*JWTClaims, error) {
162159
}
163160
claims := &JWTClaims{}
164161
err = tkn.UnsafeClaimsWithoutVerification(claims)
162+
claims.Token = token
165163
return claims, err
166164
}
167-
168-
func validateTokenWithOIDC(ctx context.Context, provider *oidc.Provider, token, audience string) error {
169-
verifier := provider.Verifier(&oidc.Config{
170-
ClientID: audience,
171-
})
172-
173-
_, err := verifier.Verify(ctx, token)
174-
if err != nil {
175-
return fmt.Errorf("JWT token verification failed: %v", err)
176-
}
177-
178-
return nil
179-
}

pkg/http/authorization_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ func TestJWTTokenValidate(t *testing.T) {
111111
t.Fatalf("expected no error for expired token parsing, got %v", err)
112112
}
113113

114-
err = claims.Validate("kubernetes-mcp-server")
114+
err = claims.Validate(t.Context(), "kubernetes-mcp-server", nil)
115115
if err == nil {
116116
t.Fatalf("expected error for expired token, got nil")
117117
}
@@ -130,7 +130,7 @@ func TestJWTTokenValidate(t *testing.T) {
130130
t.Fatalf("expected claims to be returned, got nil")
131131
}
132132

133-
err = claims.Validate("kubernetes-mcp-server")
133+
err = claims.Validate(t.Context(), "kubernetes-mcp-server", nil)
134134
if err != nil {
135135
t.Fatalf("expected no error for valid audience, got %v", err)
136136
}
@@ -145,7 +145,7 @@ func TestJWTTokenValidate(t *testing.T) {
145145
t.Fatalf("expected claims to be returned, got nil")
146146
}
147147

148-
err = claims.Validate("missing-audience")
148+
err = claims.Validate(t.Context(), "missing-audience", nil)
149149
if err == nil {
150150
t.Fatalf("expected error for token with wrong audience, got nil")
151151
}

pkg/http/http_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ func NewOidcTestServer(t *testing.T) (privateKey *rsa.PrivateKey, oidcProvider *
127127
t.Fatalf("failed to generate private key for oidc: %v", err)
128128
}
129129
oidcServer := &oidctest.Server{
130+
Algorithms: []string{oidc.RS256, oidc.ES256},
130131
PublicKeys: []oidctest.PublicKey{
131132
{
132133
PublicKey: privateKey.Public(),
@@ -470,8 +471,8 @@ func TestAuthorizationUnauthorized(t *testing.T) {
470471
}
471472
})
472473
t.Run("Protected resource with INVALID OIDC Authorization header logs error", func(t *testing.T) {
473-
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - OIDC token validation error") &&
474-
!strings.Contains(ctx.LogBuffer.String(), "JWT token verification failed: oidc: id token issued by a different provider") {
474+
if !strings.Contains(ctx.LogBuffer.String(), "Authentication failed - JWT validation error") ||
475+
!strings.Contains(ctx.LogBuffer.String(), "OIDC token validation error: failed to verify signature") {
475476
t.Errorf("Expected log entry for OIDC validation error, got: %s", ctx.LogBuffer.String())
476477
}
477478
})

0 commit comments

Comments
 (0)