Skip to content

Commit 95267ed

Browse files
committed
test step-up auth flow
1 parent 2efec15 commit 95267ed

File tree

2 files changed

+108
-15
lines changed

2 files changed

+108
-15
lines changed

src/mcp/client/auth.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,12 @@ def _configure_scope_selection(self, init_response: httpx.Response) -> None:
284284
if www_authenticate_scope is not None:
285285
# Priority 1: WWW-Authenticate header scope
286286
self.context.client_metadata.scope = www_authenticate_scope
287-
elif self.context.protected_resource_metadata is not None and self.context.protected_resource_metadata.scopes_supported is not None:
287+
elif (
288+
self.context.protected_resource_metadata is not None
289+
and self.context.protected_resource_metadata.scopes_supported is not None
290+
):
288291
# Priority 2: PRM scopes_supported
289-
self.context.client_metadata.scope = " ".join(
290-
self.context.protected_resource_metadata.scopes_supported
291-
)
292+
self.context.client_metadata.scope = " ".join(self.context.protected_resource_metadata.scopes_supported)
292293
else:
293294
# Priority 3: Omit scope parameter
294295
self.context.client_metadata.scope = None
@@ -592,12 +593,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
592593
self._add_auth_header(request)
593594
yield request
594595
elif response.status_code == 403:
595-
try:
596-
# Step 1: Extract error field from WWW-Authenticate header
597-
error = self._extract_field_from_www_auth(response, "error")
596+
# Step 1: Extract error field from WWW-Authenticate header
597+
error = self._extract_field_from_www_auth(response, "error")
598598

599-
# Step 2: Check if we need to step-up authorization
600-
if error == "insufficient_scope":
599+
# Step 2: Check if we need to step-up authorization
600+
if error == "insufficient_scope":
601+
try:
601602
# Step 2a: Update the required scopes
602603
self._configure_scope_selection(response)
603604

@@ -608,10 +609,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
608609
token_request = await self._exchange_token(auth_code, code_verifier)
609610
token_response = yield token_request
610611
await self._handle_token_response(token_response)
611-
except Exception:
612-
logger.exception("OAuth flow error")
613-
raise
612+
except Exception:
613+
logger.exception("OAuth flow error")
614+
raise
614615

615-
# Retry with new tokens
616-
self._add_auth_header(request)
617-
yield request
616+
# Retry with new tokens
617+
self._add_auth_header(request)
618+
yield request

tests/client/test_auth.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,98 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth(
858858
# Verify exactly one request was yielded (no double-sending)
859859
assert request_yields == 1, f"Expected 1 request yield, got {request_yields}"
860860

861+
@pytest.mark.anyio
862+
async def test_403_insufficient_scope_updates_scope_from_header(
863+
self,
864+
oauth_provider: OAuthClientProvider,
865+
mock_storage: MockTokenStorage,
866+
valid_tokens: OAuthToken,
867+
):
868+
"""Test that 403 response correctly updates scope from WWW-Authenticate header."""
869+
# Pre-store valid tokens and client info
870+
client_info = OAuthClientInformationFull(
871+
client_id="test_client_id",
872+
client_secret="test_client_secret",
873+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
874+
)
875+
await mock_storage.set_tokens(valid_tokens)
876+
await mock_storage.set_client_info(client_info)
877+
oauth_provider.context.current_tokens = valid_tokens
878+
oauth_provider.context.token_expiry_time = time.time() + 1800
879+
oauth_provider.context.client_info = client_info
880+
oauth_provider._initialized = True
881+
882+
# Original scope
883+
assert oauth_provider.context.client_metadata.scope == "read write"
884+
885+
redirect_captured = False
886+
captured_state = None
887+
888+
async def capture_redirect(url: str) -> None:
889+
nonlocal redirect_captured, captured_state
890+
redirect_captured = True
891+
# Verify the new scope is included in authorization URL
892+
assert "scope=admin%3Awrite+admin%3Adelete" in url or "scope=admin:write+admin:delete" in url.replace(
893+
"%3A", ":"
894+
).replace("+", " ")
895+
# Extract state from redirect URL
896+
from urllib.parse import parse_qs, urlparse
897+
898+
parsed = urlparse(url)
899+
params = parse_qs(parsed.query)
900+
captured_state = params.get("state", [None])[0]
901+
902+
oauth_provider.context.redirect_handler = capture_redirect
903+
904+
# Mock callback
905+
async def mock_callback() -> tuple[str, str | None]:
906+
return "auth_code", captured_state
907+
908+
oauth_provider.context.callback_handler = mock_callback
909+
910+
test_request = httpx.Request("GET", "https://api.example.com/mcp")
911+
auth_flow = oauth_provider.async_auth_flow(test_request)
912+
913+
# First request
914+
request = await auth_flow.__anext__()
915+
916+
# Send 403 with new scope requirement
917+
response_403 = httpx.Response(
918+
403,
919+
headers={"WWW-Authenticate": 'Bearer error="insufficient_scope", scope="admin:write admin:delete"'},
920+
request=request,
921+
)
922+
923+
# Trigger step-up - should get token exchange request
924+
token_exchange_request = await auth_flow.asend(response_403)
925+
926+
# Verify scope was updated
927+
assert oauth_provider.context.client_metadata.scope == "admin:write admin:delete"
928+
assert redirect_captured
929+
930+
# Complete the flow with successful token response
931+
token_response = httpx.Response(
932+
200,
933+
json={
934+
"access_token": "new_token_with_new_scope",
935+
"token_type": "Bearer",
936+
"expires_in": 3600,
937+
"scope": "admin:write admin:delete",
938+
},
939+
request=token_exchange_request,
940+
)
941+
942+
# Should get final retry request
943+
final_request = await auth_flow.asend(token_response)
944+
945+
# Send success response - flow should complete
946+
success_response = httpx.Response(200, request=final_request)
947+
try:
948+
await auth_flow.asend(success_response)
949+
pytest.fail("Should have stopped after successful response")
950+
except StopAsyncIteration:
951+
pass # Expected
952+
861953

862954
@pytest.mark.parametrize(
863955
(

0 commit comments

Comments
 (0)