|  | 
|  | 1 | +package http | 
|  | 2 | + | 
|  | 3 | +import ( | 
|  | 4 | +	"encoding/base64" | 
|  | 5 | +	"encoding/json" | 
|  | 6 | +	"fmt" | 
|  | 7 | +	"net/http" | 
|  | 8 | +	"slices" | 
|  | 9 | +	"strings" | 
|  | 10 | +	"time" | 
|  | 11 | + | 
|  | 12 | +	"k8s.io/klog/v2" | 
|  | 13 | + | 
|  | 14 | +	"github.com/manusa/kubernetes-mcp-server/pkg/mcp" | 
|  | 15 | +) | 
|  | 16 | + | 
|  | 17 | +const ( | 
|  | 18 | +	Audience = "kubernetes-mcp-server" | 
|  | 19 | +) | 
|  | 20 | + | 
|  | 21 | +// AuthorizationMiddleware validates the OAuth flow using Kubernetes TokenReview API | 
|  | 22 | +func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp.Server) func(http.Handler) http.Handler { | 
|  | 23 | +	return func(next http.Handler) http.Handler { | 
|  | 24 | +		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 
|  | 25 | +			if r.URL.Path == "/healthz" || r.URL.Path == "/.well-known/oauth-protected-resource" { | 
|  | 26 | +				next.ServeHTTP(w, r) | 
|  | 27 | +				return | 
|  | 28 | +			} | 
|  | 29 | +			if !requireOAuth { | 
|  | 30 | +				next.ServeHTTP(w, r) | 
|  | 31 | +				return | 
|  | 32 | +			} | 
|  | 33 | + | 
|  | 34 | +			authHeader := r.Header.Get("Authorization") | 
|  | 35 | +			if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { | 
|  | 36 | +				klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr) | 
|  | 37 | + | 
|  | 38 | +				w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience)) | 
|  | 39 | +				http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized) | 
|  | 40 | +				return | 
|  | 41 | +			} | 
|  | 42 | + | 
|  | 43 | +			token := strings.TrimPrefix(authHeader, "Bearer ") | 
|  | 44 | + | 
|  | 45 | +			audience := Audience | 
|  | 46 | +			if serverURL != "" { | 
|  | 47 | +				audience = serverURL | 
|  | 48 | +			} | 
|  | 49 | + | 
|  | 50 | +			err := validateJWTToken(token, audience) | 
|  | 51 | +			if err != nil { | 
|  | 52 | +				klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err) | 
|  | 53 | + | 
|  | 54 | +				w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience)) | 
|  | 55 | +				http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized) | 
|  | 56 | +				return | 
|  | 57 | +			} | 
|  | 58 | + | 
|  | 59 | +			// Validate token using Kubernetes TokenReview API | 
|  | 60 | +			_, _, err = mcpServer.VerifyToken(r.Context(), token, Audience) | 
|  | 61 | +			if err != nil { | 
|  | 62 | +				klog.V(1).Infof("Authentication failed - token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err) | 
|  | 63 | + | 
|  | 64 | +				w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience)) | 
|  | 65 | +				http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized) | 
|  | 66 | +				return | 
|  | 67 | +			} | 
|  | 68 | + | 
|  | 69 | +			next.ServeHTTP(w, r) | 
|  | 70 | +		}) | 
|  | 71 | +	} | 
|  | 72 | +} | 
|  | 73 | + | 
|  | 74 | +type JWTClaims struct { | 
|  | 75 | +	Issuer    string   `json:"iss"` | 
|  | 76 | +	Audience  []string `json:"aud"` | 
|  | 77 | +	ExpiresAt int64    `json:"exp"` | 
|  | 78 | +} | 
|  | 79 | + | 
|  | 80 | +// validateJWTToken validates basic JWT claims without signature verification | 
|  | 81 | +func validateJWTToken(token, audience string) error { | 
|  | 82 | +	parts := strings.Split(token, ".") | 
|  | 83 | +	if len(parts) != 3 { | 
|  | 84 | +		return fmt.Errorf("invalid JWT token format") | 
|  | 85 | +	} | 
|  | 86 | + | 
|  | 87 | +	claims, err := parseJWTClaims(parts[1]) | 
|  | 88 | +	if err != nil { | 
|  | 89 | +		return fmt.Errorf("failed to parse JWT claims: %v", err) | 
|  | 90 | +	} | 
|  | 91 | + | 
|  | 92 | +	if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt { | 
|  | 93 | +		return fmt.Errorf("token expired") | 
|  | 94 | +	} | 
|  | 95 | + | 
|  | 96 | +	if !slices.Contains(claims.Audience, audience) { | 
|  | 97 | +		return fmt.Errorf("token audience mismatch: %v", claims.Audience) | 
|  | 98 | +	} | 
|  | 99 | + | 
|  | 100 | +	return nil | 
|  | 101 | +} | 
|  | 102 | + | 
|  | 103 | +func parseJWTClaims(payload string) (*JWTClaims, error) { | 
|  | 104 | +	// Add padding if needed | 
|  | 105 | +	if len(payload)%4 != 0 { | 
|  | 106 | +		payload += strings.Repeat("=", 4-len(payload)%4) | 
|  | 107 | +	} | 
|  | 108 | + | 
|  | 109 | +	decoded, err := base64.URLEncoding.DecodeString(payload) | 
|  | 110 | +	if err != nil { | 
|  | 111 | +		return nil, fmt.Errorf("failed to decode JWT payload: %v", err) | 
|  | 112 | +	} | 
|  | 113 | + | 
|  | 114 | +	var claims JWTClaims | 
|  | 115 | +	if err := json.Unmarshal(decoded, &claims); err != nil { | 
|  | 116 | +		return nil, fmt.Errorf("failed to unmarshal JWT claims: %v", err) | 
|  | 117 | +	} | 
|  | 118 | + | 
|  | 119 | +	return &claims, nil | 
|  | 120 | +} | 
0 commit comments