From b1e07ab09d64fc9835b9f46dbf8d688923535879 Mon Sep 17 00:00:00 2001 From: Ujjwal-Bajpayee Date: Wed, 8 Oct 2025 20:07:51 +0530 Subject: [PATCH] fix(auth): add token issuer validation per MCP spec (closes #1442) --- src/mcp/client/auth.py | 56 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 376036e8c..205d3f891 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -111,6 +111,8 @@ class OAuthContext: # Discovery state for fallback support discovery_base_url: str | None = None discovery_pathname: str | None = None + # Optional expected issuer for access tokens (JWT iss claim) + expected_issuer: str | None = None def get_authorization_base_url(self, server_url: str) -> str: """Extract base URL by removing path component.""" @@ -126,12 +128,64 @@ def update_token_expiry(self, token: OAuthToken) -> None: def is_token_valid(self) -> bool: """Check if current token is valid.""" - return bool( + # Basic existence and expiry checks + basic_valid = bool( self.current_tokens and self.current_tokens.access_token and (not self.token_expiry_time or time.time() <= self.token_expiry_time) ) + if not basic_valid: + return False + + # If no expected issuer is configured, behave as before + if not getattr(self, "expected_issuer", None): + return True + + # If expected_issuer is set, ensure token issuer matches + try: + return self._token_issuer_matches(self.current_tokens.access_token) + except Exception: + # On any parsing issue, treat token as invalid + logger.exception("Failed to validate token issuer") + return False + + def _token_issuer_matches(self, token: str) -> bool: + """Decode a JWT access token (no signature verification) and compare its 'iss' claim. + + This performs a safe, minimal check: split the token, base64-decode the payload, + parse JSON, and compare the 'iss' field to self.expected_issuer. Returns False + if the token is malformed or the claim is missing/mismatched. + """ + # JWTs are in the form header.payload.signature + parts = token.split(".") + if len(parts) < 2: + return False + + payload_b64 = parts[1] + + # Add padding for base64 if necessary + padding = "=" * (-len(payload_b64) % 4) + payload_b64 += padding + + try: + payload_bytes = base64.urlsafe_b64decode(payload_b64.encode()) + except Exception: + return False + + try: + import json + + payload = json.loads(payload_bytes) + except Exception: + return False + + iss = payload.get("iss") + if not iss: + return False + + return iss == self.expected_issuer + def can_refresh_token(self) -> bool: """Check if token can be refreshed.""" return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)