Skip to content

Commit 7848e68

Browse files
committed
Fix tests and pyright errors
1 parent d9c751f commit 7848e68

File tree

5 files changed

+58
-40
lines changed

5 files changed

+58
-40
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ async def main():
814814
The SDK includes [authorization support](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization) for connecting to protected MCP servers:
815815

816816
```python
817-
from mcp.client.auth import OAuthClientProvider, ClientCredentialsProvider, TokenStorage
817+
from mcp.client.auth import OAuthClientProvider, TokenStorage
818818
from mcp.client.session import ClientSession
819819
from mcp.client.streamable_http import streamablehttp_client
820820
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken

examples/servers/simple-auth/mcp_simple_auth/server.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,24 @@ async def exchange_refresh_token(
247247
"""Exchange refresh token"""
248248
raise NotImplementedError("Not supported")
249249

250+
async def exchange_client_credentials(
251+
self, client: OAuthClientInformationFull, scopes: list[str]
252+
) -> OAuthToken:
253+
"""Exchange client credentials for an access token."""
254+
token = f"mcp_{secrets.token_hex(32)}"
255+
self.tokens[token] = AccessToken(
256+
token=token,
257+
client_id=client.client_id,
258+
scopes=scopes,
259+
expires_at=int(time.time()) + 3600,
260+
)
261+
return OAuthToken(
262+
access_token=token,
263+
token_type="bearer",
264+
expires_in=3600,
265+
scope=" ".join(scopes),
266+
)
267+
250268
async def revoke_token(
251269
self, token: str, token_type_hint: str | None = None
252270
) -> None:

src/mcp/server/auth/handlers/register.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async def handle(self, request: Request) -> Response:
7474
),
7575
status_code=400,
7676
)
77-
grant_types_set = set(client_metadata.grant_types)
77+
grant_types_set: set[str] = set(client_metadata.grant_types)
7878
valid_sets = [
7979
{"authorization_code", "refresh_token"},
8080
{"client_credentials"},

tests/client/test_auth.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from inline_snapshot import snapshot
1414
from pydantic import AnyHttpUrl
1515

16-
from mcp.client.auth import ClientCredentialsProvider, OAuthClientProvider
16+
from mcp.client.auth import (
17+
ClientCredentialsProvider,
18+
OAuthClientProvider,
19+
_discover_oauth_metadata,
20+
_get_authorization_base_url,
21+
)
1722
from mcp.server.auth.routes import build_metadata
1823
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
1924
from mcp.shared.auth import (
@@ -190,21 +195,19 @@ def test_get_authorization_base_url(self, oauth_provider):
190195
"""Test authorization base URL extraction."""
191196
# Test with path
192197
assert (
193-
oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp")
198+
_get_authorization_base_url("https://api.example.com/v1/mcp")
194199
== "https://api.example.com"
195200
)
196201

197202
# Test with no path
198203
assert (
199-
oauth_provider._get_authorization_base_url("https://api.example.com")
204+
_get_authorization_base_url("https://api.example.com")
200205
== "https://api.example.com"
201206
)
202207

203208
# Test with port
204209
assert (
205-
oauth_provider._get_authorization_base_url(
206-
"https://api.example.com:8080/path/to/mcp"
207-
)
210+
_get_authorization_base_url("https://api.example.com:8080/path/to/mcp")
208211
== "https://api.example.com:8080"
209212
)
210213

@@ -224,7 +227,7 @@ async def test_discover_oauth_metadata_success(
224227
mock_response.json.return_value = metadata_response
225228
mock_client.get.return_value = mock_response
226229

227-
result = await oauth_provider._discover_oauth_metadata(
230+
result = await _discover_oauth_metadata(
228231
"https://api.example.com/v1/mcp"
229232
)
230233

@@ -253,7 +256,7 @@ async def test_discover_oauth_metadata_not_found(self, oauth_provider):
253256
mock_response.status_code = 404
254257
mock_client.get.return_value = mock_response
255258

256-
result = await oauth_provider._discover_oauth_metadata(
259+
result = await _discover_oauth_metadata(
257260
"https://api.example.com/v1/mcp"
258261
)
259262

@@ -280,7 +283,7 @@ async def test_discover_oauth_metadata_cors_fallback(
280283
mock_response_success, # Second call succeeds
281284
]
282285

283-
result = await oauth_provider._discover_oauth_metadata(
286+
result = await _discover_oauth_metadata(
284287
"https://api.example.com/v1/mcp"
285288
)
286289

@@ -334,9 +337,7 @@ async def test_register_oauth_client_fallback_endpoint(
334337
mock_client.post.return_value = mock_response
335338

336339
# Mock metadata discovery to return None (fallback)
337-
with patch.object(
338-
oauth_provider, "_discover_oauth_metadata", return_value=None
339-
):
340+
with patch("mcp.client.auth._discover_oauth_metadata", return_value=None):
340341
result = await oauth_provider._register_oauth_client(
341342
"https://api.example.com/v1/mcp",
342343
oauth_provider.client_metadata,
@@ -363,9 +364,7 @@ async def test_register_oauth_client_failure(self, oauth_provider):
363364
mock_client.post.return_value = mock_response
364365

365366
# Mock metadata discovery to return None (fallback)
366-
with patch.object(
367-
oauth_provider, "_discover_oauth_metadata", return_value=None
368-
):
367+
with patch("mcp.client.auth._discover_oauth_metadata", return_value=None):
369368
with pytest.raises(httpx.HTTPStatusError):
370369
await oauth_provider._register_oauth_client(
371370
"https://api.example.com/v1/mcp",
@@ -993,26 +992,26 @@ def test_build_metadata(
993992
revocation_options=RevocationOptions(enabled=True),
994993
)
995994

996-
assert metadata == snapshot(
997-
OAuthMetadata(
998-
issuer=AnyHttpUrl(issuer_url),
999-
authorization_endpoint=AnyHttpUrl(authorization_endpoint),
1000-
token_endpoint=AnyHttpUrl(token_endpoint),
1001-
registration_endpoint=AnyHttpUrl(registration_endpoint),
1002-
scopes_supported=["read", "write", "admin"],
1003-
grant_types_supported=[
1004-
"authorization_code",
1005-
"refresh_token",
1006-
"client_credentials",
1007-
],
1008-
token_endpoint_auth_methods_supported=["client_secret_post"],
1009-
service_documentation=AnyHttpUrl(service_documentation_url),
1010-
revocation_endpoint=AnyHttpUrl(revocation_endpoint),
1011-
revocation_endpoint_auth_methods_supported=["client_secret_post"],
1012-
code_challenge_methods_supported=["S256"],
1013-
)
995+
expected = OAuthMetadata(
996+
issuer=AnyHttpUrl(issuer_url),
997+
authorization_endpoint=AnyHttpUrl(authorization_endpoint),
998+
token_endpoint=AnyHttpUrl(token_endpoint),
999+
registration_endpoint=AnyHttpUrl(registration_endpoint),
1000+
scopes_supported=["read", "write", "admin"],
1001+
grant_types_supported=[
1002+
"authorization_code",
1003+
"refresh_token",
1004+
"client_credentials",
1005+
],
1006+
token_endpoint_auth_methods_supported=["client_secret_post"],
1007+
service_documentation=AnyHttpUrl(service_documentation_url),
1008+
revocation_endpoint=AnyHttpUrl(revocation_endpoint),
1009+
revocation_endpoint_auth_methods_supported=["client_secret_post"],
1010+
code_challenge_methods_supported=["S256"],
10141011
)
10151012

1013+
assert metadata == expected
1014+
10161015

10171016
class TestClientCredentialsProvider:
10181017
@pytest.mark.anyio

tests/server/fastmcp/resources/test_file_resources.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,12 @@ async def test_missing_file_error(self, temp_file: Path):
100100
with pytest.raises(ValueError, match="Error reading file"):
101101
await resource.read()
102102

103-
@pytest.mark.skipif(
104-
os.name == "nt", reason="File permissions behave differently on Windows"
105-
)
106-
@pytest.mark.anyio
107-
async def test_permission_error(self, temp_file: Path):
103+
@pytest.mark.skipif(
104+
os.name == "nt" or getattr(os, "geteuid", lambda: 0)() == 0,
105+
reason="File permissions behave differently on Windows or when running as root",
106+
)
107+
@pytest.mark.anyio
108+
async def test_permission_error(self, temp_file: Path):
108109
"""Test reading a file without permissions."""
109110
temp_file.chmod(0o000) # Remove all permissions
110111
try:

0 commit comments

Comments
 (0)