Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class OAuthContext:
# Discovery state for fallback support
discovery_base_url: str | None = None
discovery_pathname: str | None = None
# Optional expected issuer for access tokens (JWT iss claim)
expected_issuer: str | None = None

def get_authorization_base_url(self, server_url: str) -> str:
"""Extract base URL by removing path component."""
Expand All @@ -126,12 +128,64 @@ def update_token_expiry(self, token: OAuthToken) -> None:

def is_token_valid(self) -> bool:
"""Check if current token is valid."""
return bool(
# Basic existence and expiry checks
basic_valid = bool(
self.current_tokens
and self.current_tokens.access_token
and (not self.token_expiry_time or time.time() <= self.token_expiry_time)
)

if not basic_valid:
return False

# If no expected issuer is configured, behave as before
if not getattr(self, "expected_issuer", None):
return True

# If expected_issuer is set, ensure token issuer matches
try:
return self._token_issuer_matches(self.current_tokens.access_token)
except Exception:
# On any parsing issue, treat token as invalid
logger.exception("Failed to validate token issuer")
return False

def _token_issuer_matches(self, token: str) -> bool:
"""Decode a JWT access token (no signature verification) and compare its 'iss' claim.

This performs a safe, minimal check: split the token, base64-decode the payload,
parse JSON, and compare the 'iss' field to self.expected_issuer. Returns False
if the token is malformed or the claim is missing/mismatched.
"""
# JWTs are in the form header.payload.signature
parts = token.split(".")
if len(parts) < 2:
return False

payload_b64 = parts[1]

# Add padding for base64 if necessary
padding = "=" * (-len(payload_b64) % 4)
payload_b64 += padding

try:
payload_bytes = base64.urlsafe_b64decode(payload_b64.encode())
except Exception:
return False

try:
import json

payload = json.loads(payload_bytes)
except Exception:
return False

iss = payload.get("iss")
if not iss:
return False

return iss == self.expected_issuer

def can_refresh_token(self) -> bool:
"""Check if token can be refreshed."""
return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)
Expand Down
Loading