Skip to content

Commit 482c05e

Browse files
authored
Merge pull request #16 from sacha-development-stuff/codex/resolve-merge-conflicts-and-check-functionality
Fix merge conflicts and restore auth features
2 parents c1d0acc + 94cefe3 commit 482c05e

File tree

2 files changed

+90
-10
lines changed

2 files changed

+90
-10
lines changed

src/mcp/client/auth.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
486486
# Retry with new tokens
487487
self._add_auth_header(request)
488488
yield request
489+
490+
489491
class ClientCredentialsProvider(httpx.Auth):
490492
"""HTTPX auth using the OAuth2 client credentials grant."""
491493

@@ -508,19 +510,48 @@ def __init__(
508510

509511
self._token_lock = anyio.Lock()
510512

513+
def _get_authorization_base_url(self, server_url: str) -> str:
514+
"""Return base authorization server URL without path."""
515+
parsed = urlparse(server_url)
516+
return f"{parsed.scheme}://{parsed.netloc}"
517+
518+
async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
519+
"""Discover OAuth server metadata for client credentials."""
520+
auth_base_url = self._get_authorization_base_url(server_url)
521+
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
522+
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
523+
524+
async with httpx.AsyncClient() as client:
525+
try:
526+
response = await client.get(url, headers=headers)
527+
if response.status_code == 404:
528+
return None
529+
response.raise_for_status()
530+
return OAuthMetadata.model_validate(response.json())
531+
except Exception:
532+
try:
533+
response = await client.get(url)
534+
if response.status_code == 404:
535+
return None
536+
response.raise_for_status()
537+
return OAuthMetadata.model_validate(response.json())
538+
except Exception:
539+
logger.exception("Failed to discover OAuth metadata")
540+
return None
541+
511542
async def _register_oauth_client(
512543
self,
513544
server_url: str,
514545
client_metadata: OAuthClientMetadata,
515546
metadata: OAuthMetadata | None = None,
516547
) -> OAuthClientInformationFull:
517548
if not metadata:
518-
metadata = await _discover_oauth_metadata(server_url)
549+
metadata = await self._discover_oauth_metadata(server_url)
519550

520551
if metadata and metadata.registration_endpoint:
521552
registration_url = str(metadata.registration_endpoint)
522553
else:
523-
auth_base_url = _get_authorization_base_url(server_url)
554+
auth_base_url = self._get_authorization_base_url(server_url)
524555
registration_url = urljoin(auth_base_url, "/register")
525556

526557
if client_metadata.scope is None and metadata and metadata.scopes_supported is not None:
@@ -582,14 +613,14 @@ async def _get_or_register_client(self) -> OAuthClientInformationFull:
582613

583614
async def _request_token(self) -> None:
584615
if not self._metadata:
585-
self._metadata = await _discover_oauth_metadata(self.server_url)
616+
self._metadata = await self._discover_oauth_metadata(self.server_url)
586617

587618
client_info = await self._get_or_register_client()
588619

589620
if self._metadata and self._metadata.token_endpoint:
590621
token_url = str(self._metadata.token_endpoint)
591622
else:
592-
auth_base_url = _get_authorization_base_url(self.server_url)
623+
auth_base_url = self._get_authorization_base_url(self.server_url)
593624
token_url = urljoin(auth_base_url, "/token")
594625

595626
token_data = {
@@ -671,14 +702,14 @@ def __init__(
671702

672703
async def _request_token(self) -> None:
673704
if not self._metadata:
674-
self._metadata = await _discover_oauth_metadata(self.server_url)
705+
self._metadata = await self._discover_oauth_metadata(self.server_url)
675706

676707
client_info = await self._get_or_register_client()
677708

678709
if self._metadata and self._metadata.token_endpoint:
679710
token_url = str(self._metadata.token_endpoint)
680711
else:
681-
auth_base_url = _get_authorization_base_url(self.server_url)
712+
auth_base_url = self._get_authorization_base_url(self.server_url)
682713
token_url = urljoin(auth_base_url, "/token")
683714

684715
subject_token = await self.subject_token_supplier()

tests/client/test_auth.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,24 @@
22
Tests for refactored OAuth client authentication implementation.
33
"""
44

5-
import time
65
import asyncio
6+
import time
7+
from unittest.mock import AsyncMock, Mock, patch
78

89
import httpx
910
import pytest
1011
from pydantic import AnyHttpUrl, AnyUrl
11-
from unittest.mock import AsyncMock, Mock, patch
1212

13-
from mcp.client.auth import OAuthClientProvider, PKCEParameters
13+
from mcp.client.auth import (
14+
ClientCredentialsProvider,
15+
OAuthClientProvider,
16+
PKCEParameters,
17+
TokenExchangeProvider,
18+
)
1419
from mcp.shared.auth import (
1520
OAuthClientInformationFull,
1621
OAuthClientMetadata,
22+
OAuthMetadata,
1723
OAuthToken,
1824
)
1925

@@ -81,6 +87,8 @@ async def callback_handler() -> tuple[str, str | None]:
8187
redirect_handler=redirect_handler,
8288
callback_handler=callback_handler,
8389
)
90+
91+
8492
@pytest.fixture
8593
def client_credentials_metadata():
8694
return OAuthClientMetadata(
@@ -92,6 +100,45 @@ def client_credentials_metadata():
92100
token_endpoint_auth_method="client_secret_post",
93101
)
94102

103+
104+
@pytest.fixture
105+
def oauth_metadata():
106+
return OAuthMetadata(
107+
issuer=AnyHttpUrl("https://auth.example.com"),
108+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
109+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
110+
registration_endpoint=AnyHttpUrl("https://auth.example.com/register"),
111+
scopes_supported=["read", "write", "admin"],
112+
response_types_supported=["code"],
113+
grant_types_supported=["authorization_code", "refresh_token", "client_credentials"],
114+
code_challenge_methods_supported=["S256"],
115+
)
116+
117+
118+
@pytest.fixture
119+
def oauth_client_info():
120+
return OAuthClientInformationFull(
121+
client_id="test_client_id",
122+
client_secret="test_client_secret",
123+
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
124+
client_name="Test Client",
125+
grant_types=["authorization_code", "refresh_token"],
126+
response_types=["code"],
127+
scope="read write",
128+
)
129+
130+
131+
@pytest.fixture
132+
def oauth_token():
133+
return OAuthToken(
134+
access_token="test_access_token",
135+
token_type="bearer",
136+
expires_in=3600,
137+
refresh_token="test_refresh_token",
138+
scope="read write",
139+
)
140+
141+
95142
@pytest.fixture
96143
async def client_credentials_provider(client_credentials_metadata, mock_storage):
97144
return ClientCredentialsProvider(
@@ -100,6 +147,7 @@ async def client_credentials_provider(client_credentials_metadata, mock_storage)
100147
storage=mock_storage,
101148
)
102149

150+
103151
@pytest.fixture
104152
async def token_exchange_provider(client_credentials_metadata, mock_storage):
105153
return TokenExchangeProvider(
@@ -342,6 +390,8 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v
342390
await auth_flow.asend(response)
343391
except StopAsyncIteration:
344392
pass # Expected
393+
394+
345395
class TestClientCredentialsProvider:
346396
@pytest.mark.anyio
347397
async def test_request_token_success(
@@ -417,4 +467,3 @@ async def test_request_token_success(
417467

418468
mock_client.post.assert_called_once()
419469
assert token_exchange_provider._current_tokens.access_token == oauth_token.access_token
420-

0 commit comments

Comments
 (0)