diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index 5228d034e4..fed8a70e77 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -4,7 +4,7 @@ from typing import Any from mcp.server.auth.provider import AccessToken, TokenVerifier -from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url +from mcp.shared._auth_utils import check_resource_allowed, resource_url_from_server_url logger = logging.getLogger(__name__) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index ddc61ef663..abef62bdf7 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -4,6 +4,8 @@ Implements authorization code flow with PKCE and automatic token refresh. """ +from __future__ import annotations as _annotations + import base64 import hashlib import logging @@ -13,11 +15,11 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field from typing import Any, Protocol -from urllib.parse import quote, urlencode, urljoin, urlparse +from urllib.parse import quote, urlencode, urljoin, urlparse, urlsplit, urlunsplit import anyio import httpx -from pydantic import BaseModel, Field, ValidationError +from pydantic import AnyUrl, BaseModel, Field, HttpUrl, ValidationError from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError from mcp.client.auth.utils import ( @@ -45,11 +47,6 @@ OAuthToken, ProtectedResourceMetadata, ) -from mcp.shared.auth_utils import ( - calculate_token_expiry, - check_resource_allowed, - resource_url_from_server_url, -) logger = logging.getLogger(__name__) @@ -61,7 +58,7 @@ class PKCEParameters(BaseModel): code_challenge: str = Field(..., min_length=43, max_length=128) @classmethod - def generate(cls) -> "PKCEParameters": + def generate(cls) -> PKCEParameters: """Generate new PKCE parameters.""" code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) digest = hashlib.sha256(code_verifier.encode()).digest() @@ -74,19 +71,15 @@ class TokenStorage(Protocol): async def get_tokens(self) -> OAuthToken | None: """Get stored tokens.""" - ... async def set_tokens(self, tokens: OAuthToken) -> None: """Store tokens.""" - ... async def get_client_info(self) -> OAuthClientInformationFull | None: """Get stored client information.""" - ... async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: """Store client information.""" - ... @dataclass @@ -124,7 +117,7 @@ def get_authorization_base_url(self, server_url: str) -> str: def update_token_expiry(self, token: OAuthToken) -> None: """Update token expiry time using shared util function.""" - self.token_expiry_time = calculate_token_expiry(token.expires_in) + self.token_expiry_time = _calculate_token_expiry(token.expires_in) def is_token_valid(self) -> bool: """Check if current token is valid.""" @@ -148,12 +141,12 @@ def get_resource_url(self) -> str: Uses PRM resource if it's a valid parent, otherwise uses canonical server URL. """ - resource = resource_url_from_server_url(self.server_url) + resource = _resource_url_from_server_url(self.server_url) # If PRM provides a resource that's a valid parent, use it if self.protected_resource_metadata and self.protected_resource_metadata.resource: prm_resource = str(self.protected_resource_metadata.resource) - if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): + if _check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): resource = prm_resource return resource @@ -614,3 +607,82 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Retry with new tokens self._add_auth_header(request) yield request + + +def _resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: + """Convert server URL to canonical resource URL per RFC 8707. + + RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". + Returns absolute URI with lowercase scheme/host for canonical form. + + Args: + url: Server URL to convert + + Returns: + Canonical resource URL string + """ + # Convert to string if needed + url_str = str(url) + + # Parse the URL and remove fragment, create canonical form + parsed = urlsplit(url_str) + canonical = urlunsplit(parsed._replace(scheme=parsed.scheme.lower(), netloc=parsed.netloc.lower(), fragment="")) + + return canonical + + +def _check_resource_allowed(requested_resource: str, configured_resource: str) -> bool: + """Check if a requested resource URL matches a configured resource URL. + + A requested resource matches if it has the same scheme, domain, port, + and its path starts with the configured resource's path. This allows + hierarchical matching where a token for a parent resource can be used + for child resources. + + Args: + requested_resource: The resource URL being requested + configured_resource: The resource URL that has been configured + + Returns: + True if the requested resource matches the configured resource + """ + # Parse both URLs + requested = urlparse(requested_resource) + configured = urlparse(configured_resource) + + # Compare scheme, host, and port (origin) + if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower(): + return False + + # Handle cases like requested=/foo and configured=/foo/ + requested_path = requested.path + configured_path = configured.path + + # If requested path is shorter, it cannot be a child + if len(requested_path) < len(configured_path): + return False + + # Check if the requested path starts with the configured path + # Ensure both paths end with / for proper comparison + # This ensures that paths like "/api123" don't incorrectly match "/api" + if not requested_path.endswith("/"): + requested_path += "/" + if not configured_path.endswith("/"): + configured_path += "/" + + return requested_path.startswith(configured_path) + + +def _calculate_token_expiry(expires_in: int | str | None) -> float | None: + """Calculate token expiry timestamp from expires_in seconds. + + Args: + expires_in: Seconds until token expiration (may be string from some servers) + + Returns: + Unix timestamp when token expires, or None if no expiry specified + """ + if expires_in is None: + return None # pragma: no cover + # Defensive: handle servers that return expires_in as string + return time.time() + int(expires_in) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index b61bf0b030..e77027e893 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,3 +1,5 @@ +from __future__ import annotations as _annotations + import logging from typing import Any, Protocol, overload @@ -22,7 +24,7 @@ class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch @@ -30,14 +32,14 @@ async def __call__( class ElicitationFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext["ClientSession", Any] + self, context: RequestContext[ClientSession, Any] ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch @@ -62,7 +64,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: return types.ErrorData( @@ -72,7 +74,7 @@ async def _default_sampling_callback( async def _default_elicitation_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( # pragma: no cover @@ -82,7 +84,7 @@ async def _default_elicitation_callback( async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext[ClientSession, Any], ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index c9ce81d204..fb246742f9 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -31,7 +31,7 @@ class SseServerParameters(BaseModel): - """Parameters for intializing a sse_client.""" + """Parameters for initializing a sse_client.""" # The endpoint URL. url: str diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index e8c8d9af87..94ebd4ec25 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -1,5 +1,6 @@ +from __future__ import annotations as _annotations + import json -import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -12,8 +13,6 @@ import mcp.types as types from mcp.shared.message import SessionMessage -logger = logging.getLogger(__name__) - @asynccontextmanager async def websocket_client( @@ -64,10 +63,7 @@ async def ws_reader(): await read_stream_writer.send(exc) async def ws_writer(): - """ - Reads JSON-RPC messages from write_stream_reader and - sends them to the server. - """ + """Reads JSON-RPC messages from write_stream_reader and sends them to the server.""" async with write_stream_reader: async for session_message in write_stream_reader: # Convert to a dict, then to JSON diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 1ae6d90d19..046a9ea9a1 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -11,8 +11,8 @@ from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.utilities.context_injection import find_context_parameter from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata +from mcp.shared._tool_name_validation import validate_and_warn_tool_name from mcp.shared.exceptions import UrlElicitationRequiredError -from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations if TYPE_CHECKING: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 3385e72c44..034911f43f 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -89,11 +89,11 @@ async def main(): from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared._tool_name_validation import validate_and_warn_tool_name from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError, UrlElicitationRequiredError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.shared.tool_name_validation import validate_and_warn_tool_name logger = logging.getLogger(__name__) diff --git a/src/mcp/server/models.py b/src/mcp/server/models.py index eb972e33a5..311b33086a 100644 --- a/src/mcp/server/models.py +++ b/src/mcp/server/models.py @@ -5,10 +5,7 @@ from pydantic import BaseModel -from mcp.types import ( - Icon, - ServerCapabilities, -) +from mcp.types import Icon, ServerCapabilities class InitializationOptions(BaseModel): diff --git a/src/mcp/shared/tool_name_validation.py b/src/mcp/shared/_tool_name_validation.py similarity index 98% rename from src/mcp/shared/tool_name_validation.py rename to src/mcp/shared/_tool_name_validation.py index f35efa5a61..e62eb93733 100644 --- a/src/mcp/shared/tool_name_validation.py +++ b/src/mcp/shared/_tool_name_validation.py @@ -9,7 +9,7 @@ See: https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names """ -from __future__ import annotations +from __future__ import annotations as _annotations import logging import re diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py deleted file mode 100644 index 8f3c542f22..0000000000 --- a/src/mcp/shared/auth_utils.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636).""" - -import time -from urllib.parse import urlparse, urlsplit, urlunsplit - -from pydantic import AnyUrl, HttpUrl - - -def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: - """Convert server URL to canonical resource URL per RFC 8707. - - RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". - Returns absolute URI with lowercase scheme/host for canonical form. - - Args: - url: Server URL to convert - - Returns: - Canonical resource URL string - """ - # Convert to string if needed - url_str = str(url) - - # Parse the URL and remove fragment, create canonical form - parsed = urlsplit(url_str) - canonical = urlunsplit(parsed._replace(scheme=parsed.scheme.lower(), netloc=parsed.netloc.lower(), fragment="")) - - return canonical - - -def check_resource_allowed(requested_resource: str, configured_resource: str) -> bool: - """Check if a requested resource URL matches a configured resource URL. - - A requested resource matches if it has the same scheme, domain, port, - and its path starts with the configured resource's path. This allows - hierarchical matching where a token for a parent resource can be used - for child resources. - - Args: - requested_resource: The resource URL being requested - configured_resource: The resource URL that has been configured - - Returns: - True if the requested resource matches the configured resource - """ - # Parse both URLs - requested = urlparse(requested_resource) - configured = urlparse(configured_resource) - - # Compare scheme, host, and port (origin) - if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower(): - return False - - # Handle cases like requested=/foo and configured=/foo/ - requested_path = requested.path - configured_path = configured.path - - # If requested path is shorter, it cannot be a child - if len(requested_path) < len(configured_path): - return False - - # Check if the requested path starts with the configured path - # Ensure both paths end with / for proper comparison - # This ensures that paths like "/api123" don't incorrectly match "/api" - if not requested_path.endswith("/"): - requested_path += "/" - if not configured_path.endswith("/"): - configured_path += "/" - - return requested_path.startswith(configured_path) - - -def calculate_token_expiry(expires_in: int | str | None) -> float | None: - """Calculate token expiry timestamp from expires_in seconds. - - Args: - expires_in: Seconds until token expiration (may be string from some servers) - - Returns: - Unix timestamp when token expires, or None if no expiry specified - """ - if expires_in is None: - return None # pragma: no cover - # Defensive: handle servers that return expires_in as string - return time.time() + int(expires_in) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index c7c6dbabc2..eaaa92d3e5 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -26,8 +26,7 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS Creates a pair of bidirectional memory streams for client-server communication. Returns: - A tuple of (client_streams, server_streams) where each is a tuple of - (read_stream, write_stream) + A tuple of (client_streams, server_streams) where each is a tuple of (read_stream, write_stream). """ # Create streams for both directions server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -36,12 +35,7 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS client_streams = (server_to_client_receive, client_to_server_send) server_streams = (client_to_server_receive, server_to_client_send) - async with ( - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - ): + async with server_to_client_receive, client_to_server_send, client_to_server_receive, server_to_client_send: yield client_streams, server_streams diff --git a/src/mcp/shared/response_router.py b/src/mcp/shared/response_router.py index 31796157fe..43042e0d9f 100644 --- a/src/mcp/shared/response_router.py +++ b/src/mcp/shared/response_router.py @@ -45,9 +45,8 @@ def route_response(self, request_id: RequestId, response: dict[str, Any]) -> boo response: The response result data Returns: - True if the response was handled, False otherwise + True if the response was handled, False otherwise. """ - ... # pragma: no cover def route_error(self, request_id: RequestId, error: ErrorData) -> bool: """ @@ -58,6 +57,5 @@ def route_error(self, request_id: RequestId, error: ErrorData) -> bool: error: The error data Returns: - True if the error was handled, False otherwise + True if the error was handled, False otherwise. """ - ... # pragma: no cover diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index f5d76b77ac..dad1216f86 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -268,12 +268,7 @@ async def send_request( self._progress_callbacks[request_id] = progress_callback try: - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request_data, - ) - + jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) # request read timeout takes precedence over session read timeout @@ -292,8 +287,7 @@ async def send_request( code=httpx.codes.REQUEST_TIMEOUT, message=( f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." + f"{request.__class__.__name__}. Waited {timeout} seconds." ), ) ) @@ -345,10 +339,7 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er await self._write_stream.send(session_message) async def _receive_loop(self) -> None: - async with ( - self._read_stream, - self._write_stream, - ): + async with self._read_stream, self._write_stream: try: async for message in self._read_stream: if isinstance(message, Exception): # pragma: no cover @@ -543,4 +534,3 @@ async def _handle_incoming( req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, ) -> None: """A generic handler for incoming messages. Overwritten by subclasses.""" - pass # pragma: no cover diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index 5b12dc6775..d8a666e973 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -1,6 +1,9 @@ """Tests for OAuth 2.0 Resource Indicators utilities.""" -from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url +# TODO(Marcelo): We should not have tests for private functions. We can test the public functions with the behavior +# that uses these private functions instead. +from mcp.client.auth.oauth2 import _check_resource_allowed as check_resource_allowed +from mcp.client.auth.oauth2 import _resource_url_from_server_url as resource_url_from_server_url class TestResourceUrlFromServerUrl: diff --git a/tests/shared/test_tool_name_validation.py b/tests/shared/test_tool_name_validation.py index 4746f3f9f8..d0cd332eec 100644 --- a/tests/shared/test_tool_name_validation.py +++ b/tests/shared/test_tool_name_validation.py @@ -4,7 +4,7 @@ import pytest -from mcp.shared.tool_name_validation import ( +from mcp.shared._tool_name_validation import ( issue_tool_name_warning, validate_and_warn_tool_name, validate_tool_name,