Skip to content

Commit e10f7c9

Browse files
committed
Remove 2LO in this branch, limit to RFC7523
1 parent ed23997 commit e10f7c9

File tree

3 files changed

+158
-227
lines changed

3 files changed

+158
-227
lines changed

src/mcp/client/auth.py

Lines changed: 124 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,40 @@ class JWTParameters(BaseModel):
8080
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
8181
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
8282

83+
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
84+
if self.assertion is not None:
85+
# Prebuilt JWT (e.g. acquired out-of-band)
86+
assertion = self.assertion
87+
else:
88+
if not self.jwt_signing_key:
89+
raise OAuthFlowError("Missing signing key for JWT bearer grant")
90+
if not self.issuer:
91+
raise OAuthFlowError("Missing issuer for JWT bearer grant")
92+
if not self.subject:
93+
raise OAuthFlowError("Missing subject for JWT bearer grant")
94+
95+
audience = self.audience if self.audience else with_audience_fallback
96+
if not audience:
97+
raise OAuthFlowError("Missing audience for JWT bearer grant")
98+
99+
now = int(time.time())
100+
claims: dict[str, Any] = {
101+
"iss": self.issuer,
102+
"sub": self.subject,
103+
"aud": audience,
104+
"exp": now + self.jwt_lifetime_seconds,
105+
"iat": now,
106+
"jti": str(uuid4()),
107+
}
108+
claims.update(self.claims or {})
109+
110+
assertion = jwt.encode(
111+
claims,
112+
self.jwt_signing_key,
113+
algorithm=self.jwt_signing_algorithm or "RS256",
114+
)
115+
return assertion
116+
83117

84118
class TokenStorage(Protocol):
85119
"""Protocol for token storage implementations."""
@@ -111,7 +145,6 @@ class OAuthContext:
111145
redirect_handler: Callable[[str], Awaitable[None]] | None
112146
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
113147
timeout: float = 300.0
114-
jwt_parameters: JWTParameters | None = None
115148

116149
# Discovered metadata
117150
protected_resource_metadata: ProtectedResourceMetadata | None = None
@@ -213,7 +246,6 @@ def __init__(
213246
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
214247
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
215248
timeout: float = 300.0,
216-
jwt_parameters: JWTParameters | None = None,
217249
):
218250
"""Initialize OAuth2 authentication."""
219251
self.context = OAuthContext(
@@ -223,7 +255,6 @@ def __init__(
223255
redirect_handler=redirect_handler,
224256
callback_handler=callback_handler,
225257
timeout=timeout,
226-
jwt_parameters=jwt_parameters,
227258
)
228259
self._initialized = False
229260

@@ -334,16 +365,9 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
334365

335366
async def _perform_authorization(self) -> httpx.Request:
336367
"""Perform the authorization flow."""
337-
if "client_credentials" in self.context.client_metadata.grant_types:
338-
token_request = await self._exchange_token_client_credentials()
339-
return token_request
340-
elif "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
341-
token_request = await self._exchange_token_jwt_bearer()
342-
return token_request
343-
else:
344-
auth_code, code_verifier = await self._perform_authorization_code_grant()
345-
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
346-
return token_request
368+
auth_code, code_verifier = await self._perform_authorization_code_grant()
369+
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
370+
return token_request
347371

348372
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
349373
"""Perform the authorization redirect and get auth code."""
@@ -406,21 +430,25 @@ def _get_token_endpoint(self) -> str:
406430
token_url = urljoin(auth_base_url, "/token")
407431
return token_url
408432

409-
async def _exchange_token_authorization_code(self, auth_code: str, code_verifier: str) -> httpx.Request:
433+
async def _exchange_token_authorization_code(
434+
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] = {}
435+
) -> httpx.Request:
410436
"""Build token exchange request for authorization_code flow."""
411437
if self.context.client_metadata.redirect_uris is None:
412438
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
413439
if not self.context.client_info:
414440
raise OAuthFlowError("Missing client info")
415441

416442
token_url = self._get_token_endpoint()
417-
token_data = {
418-
"grant_type": "authorization_code",
419-
"code": auth_code,
420-
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
421-
"client_id": self.context.client_info.client_id,
422-
"code_verifier": code_verifier,
423-
}
443+
token_data.update(
444+
{
445+
"grant_type": "authorization_code",
446+
"code": auth_code,
447+
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
448+
"client_id": self.context.client_info.client_id,
449+
"code_verifier": code_verifier,
450+
}
451+
)
424452

425453
# Only include resource param if conditions are met
426454
if self.context.should_include_resource_param(self.context.protocol_version):
@@ -433,131 +461,6 @@ async def _exchange_token_authorization_code(self, auth_code: str, code_verifier
433461
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
434462
)
435463

436-
async def _exchange_token_client_credentials(self) -> httpx.Request:
437-
"""Build token exchange request for client_credentials flow."""
438-
if not self.context.client_info:
439-
raise OAuthFlowError("Missing client info")
440-
441-
token_url = self._get_token_endpoint()
442-
token_data = {
443-
"grant_type": "client_credentials",
444-
}
445-
446-
headers = {"Content-Type": "application/x-www-form-urlencoded"}
447-
448-
# Only include resource param if conditions are met
449-
if self.context.should_include_resource_param(self.context.protocol_version):
450-
token_data["resource"] = self.context.get_resource_url() # RFC 8707
451-
452-
if self.context.client_metadata.scope:
453-
token_data["scope"] = self.context.client_metadata.scope
454-
455-
if self.context.client_metadata.token_endpoint_auth_method == "client_secret_post":
456-
# Include in request body
457-
if self.context.client_info.client_id:
458-
token_data["client_id"] = self.context.client_info.client_id
459-
if self.context.client_info.client_secret:
460-
token_data["client_secret"] = self.context.client_info.client_secret
461-
elif self.context.client_metadata.token_endpoint_auth_method == "client_secret_basic":
462-
# Include as Basic auth header
463-
if not self.context.client_info.client_id:
464-
raise OAuthTokenError("Missing client_id in Basic auth flow")
465-
if not self.context.client_info.client_secret:
466-
raise OAuthTokenError("Missing client_secret in Basic auth flow")
467-
raw_auth = f"{self.context.client_info.client_id}:{self.context.client_info.client_secret}"
468-
headers["Authorization"] = f"Basic {base64.b64encode(raw_auth.encode()).decode()}"
469-
elif self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
470-
# Use JWT assertion for client authentication
471-
if not self.context.jwt_parameters:
472-
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
473-
474-
if self.context.jwt_parameters.assertion is not None:
475-
# Prebuilt JWT (e.g. acquired out-of-band)
476-
assertion = self.context.jwt_parameters.assertion
477-
else:
478-
if not self.context.jwt_parameters.jwt_signing_key:
479-
raise OAuthTokenError("Missing JWT signing key for private_key_jwt flow")
480-
if not self.context.jwt_parameters.jwt_signing_algorithm:
481-
raise OAuthTokenError("Missing JWT signing algorithm for private_key_jwt flow")
482-
483-
now = int(time.time())
484-
claims = {
485-
"iss": self.context.jwt_parameters.issuer,
486-
"sub": self.context.jwt_parameters.subject,
487-
"aud": self.context.jwt_parameters.audience if self.context.jwt_parameters.audience else token_url,
488-
"exp": now + self.context.jwt_parameters.jwt_lifetime_seconds,
489-
"iat": now,
490-
"jti": str(uuid4()),
491-
}
492-
claims.update(self.context.jwt_parameters.claims or {})
493-
494-
assertion = jwt.encode(
495-
claims,
496-
self.context.jwt_parameters.jwt_signing_key,
497-
algorithm=self.context.jwt_parameters.jwt_signing_algorithm or "RS256",
498-
)
499-
500-
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
501-
token_data["client_assertion"] = assertion
502-
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
503-
# We need to set the audience to the token endpoint, the audience is difference from the one in claims
504-
# it represents the resource server that will validate the token
505-
token_data["audience"] = self.context.get_resource_url()
506-
507-
return httpx.Request("POST", token_url, data=token_data, headers=headers)
508-
509-
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
510-
"""Build token exchange request for JWT bearer grant."""
511-
if not self.context.client_info:
512-
raise OAuthFlowError("Missing client info")
513-
if not self.context.jwt_parameters:
514-
raise OAuthFlowError("Missing JWT parameters")
515-
516-
token_url = self._get_token_endpoint()
517-
518-
if self.context.jwt_parameters.assertion is not None:
519-
# Prebuilt JWT (e.g. acquired out-of-band)
520-
assertion = self.context.jwt_parameters.assertion
521-
else:
522-
if not self.context.jwt_parameters.jwt_signing_key:
523-
raise OAuthFlowError("Missing signing key for JWT bearer grant")
524-
if not self.context.jwt_parameters.issuer:
525-
raise OAuthFlowError("Missing issuer for JWT bearer grant")
526-
if not self.context.jwt_parameters.subject:
527-
raise OAuthFlowError("Missing subject for JWT bearer grant")
528-
529-
now = int(time.time())
530-
claims = {
531-
"iss": self.context.jwt_parameters.issuer,
532-
"sub": self.context.jwt_parameters.subject,
533-
"aud": token_url,
534-
"exp": now + self.context.jwt_parameters.jwt_lifetime_seconds,
535-
"iat": now,
536-
"jti": str(uuid4()),
537-
}
538-
claims.update(self.context.jwt_parameters.claims or {})
539-
540-
assertion = jwt.encode(
541-
claims,
542-
self.context.jwt_parameters.jwt_signing_key,
543-
algorithm=self.context.jwt_parameters.jwt_signing_algorithm or "RS256",
544-
)
545-
546-
token_data = {
547-
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
548-
"assertion": assertion,
549-
}
550-
551-
if self.context.should_include_resource_param(self.context.protocol_version):
552-
token_data["resource"] = self.context.get_resource_url()
553-
554-
if self.context.client_metadata.scope:
555-
token_data["scope"] = self.context.client_metadata.scope
556-
557-
return httpx.Request(
558-
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
559-
)
560-
561464
async def _handle_token_response(self, response: httpx.Response) -> None:
562465
"""Handle token exchange response."""
563466
if response.status_code != 200:
@@ -720,3 +623,78 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
720623
# Retry with new tokens
721624
self._add_auth_header(request)
722625
yield request
626+
627+
628+
class RFC7523OAuthClientProvider(OAuthClientProvider):
629+
"""OAuth client provider for RFC7532 clients."""
630+
631+
jwt_parameters: JWTParameters | None = None
632+
633+
def __init__(
634+
self,
635+
server_url: str,
636+
client_metadata: OAuthClientMetadata,
637+
storage: TokenStorage,
638+
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
639+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
640+
timeout: float = 300.0,
641+
jwt_parameters: JWTParameters | None = None,
642+
) -> None:
643+
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
644+
self.jwt_parameters = jwt_parameters
645+
646+
async def _exchange_token_authorization_code(
647+
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] = {}
648+
) -> httpx.Request:
649+
"""Build token exchange request for authorization_code flow."""
650+
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
651+
self._add_client_authentication_jwt(token_data=token_data)
652+
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
653+
654+
async def _perform_authorization(self) -> httpx.Request:
655+
"""Perform the authorization flow."""
656+
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
657+
token_request = await self._exchange_token_jwt_bearer()
658+
return token_request
659+
else:
660+
return await super()._perform_authorization()
661+
662+
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):
663+
"""Add JWT assertion for client authentication to token endpoint parameters."""
664+
if not self.jwt_parameters:
665+
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
666+
667+
token_url = self._get_token_endpoint()
668+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=token_url)
669+
670+
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
671+
token_data["client_assertion"] = assertion
672+
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
673+
# We need to set the audience to the token endpoint, the audience is difference from the one in claims
674+
# it represents the resource server that will validate the token
675+
token_data["audience"] = self.context.get_resource_url()
676+
677+
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
678+
"""Build token exchange request for JWT bearer grant."""
679+
if not self.context.client_info:
680+
raise OAuthFlowError("Missing client info")
681+
if not self.jwt_parameters:
682+
raise OAuthFlowError("Missing JWT parameters")
683+
684+
token_url = self._get_token_endpoint()
685+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=token_url)
686+
687+
token_data = {
688+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
689+
"assertion": assertion,
690+
}
691+
692+
if self.context.should_include_resource_param(self.context.protocol_version):
693+
token_data["resource"] = self.context.get_resource_url()
694+
695+
if self.context.client_metadata.scope:
696+
token_data["scope"] = self.context.client_metadata.scope
697+
698+
return httpx.Request(
699+
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
700+
)

src/mcp/shared/auth.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,9 @@ class OAuthClientMetadata(BaseModel):
4343

4444
redirect_uris: list[AnyUrl] | None = Field(..., min_length=1)
4545
# supported auth methods for the token endpoint
46-
token_endpoint_auth_method: Literal["none", "client_secret_basic", "client_secret_post", "private_key_jwt"] = (
47-
"client_secret_post"
48-
)
46+
token_endpoint_auth_method: Literal["none", "client_secret_post", "private_key_jwt"] = "client_secret_post"
4947
# supported grant_types of this implementation
50-
grant_types: list[
51-
Literal[
52-
"authorization_code", "client_credentials", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"
53-
]
54-
] = [
48+
grant_types: list[Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"]] = [
5549
"authorization_code",
5650
"refresh_token",
5751
]

0 commit comments

Comments
 (0)