diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index 05cf60bd1..343d53b97 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -6,6 +6,9 @@ on: permissions: contents: read +env: + COLUMNS: 150 + jobs: pre-commit: runs-on: ubuntu-latest diff --git a/README.md b/README.md index 94a5eab01..419410603 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession # Mock database class for example @@ -242,7 +243,7 @@ mcp = FastMCP("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context) -> str: +def query_db(ctx: Context[ServerSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() @@ -314,12 +315,13 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -445,7 +447,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2): + def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] self.setting1 = setting1 self.setting2 = setting2 @@ -571,12 +573,13 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -694,6 +697,7 @@ Request additional information from users. This example shows an Elicitation dur from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession mcp = FastMCP(name="Elicitation Example") @@ -709,12 +713,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table( - date: str, - time: str, - party_size: int, - ctx: Context, -) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: """Book a table with date availability check.""" # Check if date is available if date == "2024-12-25": @@ -750,13 +749,14 @@ Tools can interact with LLMs through sampling (generating text): ```python from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent mcp = FastMCP(name="Sampling Example") @mcp.tool() -async def generate_poem(topic: str, ctx: Context) -> str: +async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: """Generate a poem using LLM sampling.""" prompt = f"Write a short poem about {topic}" @@ -785,12 +785,13 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context) -> str: +async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") @@ -1244,6 +1245,7 @@ Run from the repository root: from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from typing import Any import mcp.server.stdio import mcp.types as types @@ -1272,7 +1274,7 @@ class Database: @asynccontextmanager -async def server_lifespan(_server: Server) -> AsyncIterator[dict]: +async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]: """Manage server startup and shutdown lifecycle.""" # Initialize resources on startup db = await Database.connect() @@ -1304,7 +1306,7 @@ async def handle_list_tools() -> list[types.Tool]: @server.call_tool() -async def query_db(name: str, arguments: dict) -> list[types.TextContent]: +async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: """Handle database query tool call.""" if name != "query_db": raise ValueError(f"Unknown tool: {name}") @@ -1558,7 +1560,7 @@ server_params = StdioServerParameters( # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext, params: types.CreateMessageRequestParams + context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 7a9e32279..19d6dcef8 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -188,9 +188,7 @@ async def _default_redirect_handler(authorization_url: str) -> None: # Create OAuth authentication handler using the new interface oauth_auth = OAuthClientProvider( server_url=self.server_url.replace("/mcp", ""), - client_metadata=OAuthClientMetadata.model_validate( - client_metadata_dict - ), + client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), storage=InMemoryTokenStorage(), redirect_handler=_default_redirect_handler, callback_handler=callback_handler, @@ -322,9 +320,7 @@ async def interactive_loop(self): await self.call_tool(tool_name, arguments) else: - print( - "āŒ Unknown command. Try 'list', 'call ', or 'quit'" - ) + print("āŒ Unknown command. Try 'list', 'call ', or 'quit'") except KeyboardInterrupt: print("\n\nšŸ‘‹ Goodbye!") diff --git a/examples/clients/simple-auth-client/pyproject.toml b/examples/clients/simple-auth-client/pyproject.toml index 5ae7c6b9d..2b9308432 100644 --- a/examples/clients/simple-auth-client/pyproject.toml +++ b/examples/clients/simple-auth-client/pyproject.toml @@ -39,7 +39,7 @@ select = ["E", "F", "I"] ignore = [] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py310" [tool.uv] diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index b97b85080..65e0dde03 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -12,9 +12,7 @@ from mcp.client.stdio import stdio_client # Configure logging -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") class Configuration: @@ -75,29 +73,19 @@ def __init__(self, name: str, config: dict[str, Any]) -> None: async def initialize(self) -> None: """Initialize the server connection.""" - command = ( - shutil.which("npx") - if self.config["command"] == "npx" - else self.config["command"] - ) + command = shutil.which("npx") if self.config["command"] == "npx" else self.config["command"] if command is None: raise ValueError("The command must be a valid string and cannot be None.") server_params = StdioServerParameters( command=command, args=self.config["args"], - env={**os.environ, **self.config["env"]} - if self.config.get("env") - else None, + env={**os.environ, **self.config["env"]} if self.config.get("env") else None, ) try: - stdio_transport = await self.exit_stack.enter_async_context( - stdio_client(server_params) - ) + stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) read, write = stdio_transport - session = await self.exit_stack.enter_async_context( - ClientSession(read, write) - ) + session = await self.exit_stack.enter_async_context(ClientSession(read, write)) await session.initialize() self.session = session except Exception as e: @@ -122,10 +110,7 @@ async def list_tools(self) -> list[Any]: for item in tools_response: if isinstance(item, tuple) and item[0] == "tools": - tools.extend( - Tool(tool.name, tool.description, tool.inputSchema, tool.title) - for tool in item[1] - ) + tools.extend(Tool(tool.name, tool.description, tool.inputSchema, tool.title) for tool in item[1]) return tools @@ -164,9 +149,7 @@ async def execute_tool( except Exception as e: attempt += 1 - logging.warning( - f"Error executing tool: {e}. Attempt {attempt} of {retries}." - ) + logging.warning(f"Error executing tool: {e}. Attempt {attempt} of {retries}.") if attempt < retries: logging.info(f"Retrying in {delay} seconds...") await asyncio.sleep(delay) @@ -209,9 +192,7 @@ def format_for_llm(self) -> str: args_desc = [] if "properties" in self.input_schema: for param_name, param_info in self.input_schema["properties"].items(): - arg_desc = ( - f"- {param_name}: {param_info.get('description', 'No description')}" - ) + arg_desc = f"- {param_name}: {param_info.get('description', 'No description')}" if param_name in self.input_schema.get("required", []): arg_desc += " (required)" args_desc.append(arg_desc) @@ -281,10 +262,7 @@ def get_response(self, messages: list[dict[str, str]]) -> str: logging.error(f"Status code: {status_code}") logging.error(f"Response details: {e.response.text}") - return ( - f"I encountered an error: {error_message}. " - "Please try again or rephrase your request." - ) + return f"I encountered an error: {error_message}. Please try again or rephrase your request." class ChatSession: @@ -323,17 +301,13 @@ async def process_llm_response(self, llm_response: str) -> str: tools = await server.list_tools() if any(tool.name == tool_call["tool"] for tool in tools): try: - result = await server.execute_tool( - tool_call["tool"], tool_call["arguments"] - ) + result = await server.execute_tool(tool_call["tool"], tool_call["arguments"]) if isinstance(result, dict) and "progress" in result: progress = result["progress"] total = result["total"] percentage = (progress / total) * 100 - logging.info( - f"Progress: {progress}/{total} ({percentage:.1f}%)" - ) + logging.info(f"Progress: {progress}/{total} ({percentage:.1f}%)") return f"Tool execution result: {result}" except Exception as e: @@ -408,9 +382,7 @@ async def start(self) -> None: final_response = self.llm_client.get_response(messages) logging.info("\nFinal response: %s", final_response) - messages.append( - {"role": "assistant", "content": final_response} - ) + messages.append({"role": "assistant", "content": final_response}) else: messages.append({"role": "assistant", "content": llm_response}) @@ -426,10 +398,7 @@ async def main() -> None: """Initialize and run the chat session.""" config = Configuration() server_config = config.load_config("servers_config.json") - servers = [ - Server(name, srv_config) - for name, srv_config in server_config["mcpServers"].items() - ] + servers = [Server(name, srv_config) for name, srv_config in server_config["mcpServers"].items()] llm_client = LLMClient(config.llm_api_key) chat_session = ChatSession(servers, llm_client) await chat_session.start() diff --git a/examples/clients/simple-chatbot/pyproject.toml b/examples/clients/simple-chatbot/pyproject.toml index d88b8f6d2..b699ecc32 100644 --- a/examples/clients/simple-chatbot/pyproject.toml +++ b/examples/clients/simple-chatbot/pyproject.toml @@ -41,7 +41,7 @@ select = ["E", "F", "I"] ignore = [] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py310" [tool.uv] diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 53595778b..ac449ebff 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -45,7 +45,8 @@ class ResourceServerSettings(BaseSettings): # RFC 8707 resource validation oauth_strict: bool = False - def __init__(self, **data): + # TODO(Marcelo): Is this even needed? I didn't have time to check. + def __init__(self, **data: Any): """Initialize settings with values from environment variables.""" super().__init__(**data) diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index aa813b542..0f1092d7d 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -46,7 +46,7 @@ class SimpleAuthSettings(BaseSettings): mcp_scope: str = "user" -class SimpleOAuthProvider(OAuthAuthorizationServerProvider): +class SimpleOAuthProvider(OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]): """ Simple OAuth provider for demo purposes. @@ -116,7 +116,7 @@ async def get_login_page(self, state: str) -> HTMLResponse:

This is a simplified authentication demo. Use the demo credentials below:

Username: demo_user
Password: demo_password

- +
@@ -264,7 +264,8 @@ async def exchange_refresh_token( """Exchange refresh token - not supported in this example.""" raise NotImplementedError("Refresh tokens not supported") - async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: + # TODO(Marcelo): The type hint is wrong. We need to fix, and test to check if it works. + async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: # type: ignore """Revoke a token.""" if token in self.tokens: del self.tokens[token] diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index de3140238..5228d034e 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -1,6 +1,7 @@ """Example token verifier implementation using OAuth 2.0 Token Introspection (RFC 7662).""" import logging +from typing import Any from mcp.server.auth.provider import AccessToken, TokenVerifier from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url @@ -79,13 +80,13 @@ async def verify_token(self, token: str) -> AccessToken | None: logger.warning(f"Token introspection failed: {e}") return None - def _validate_resource(self, token_data: dict) -> bool: + def _validate_resource(self, token_data: dict[str, Any]) -> bool: """Validate token was issued for this resource server.""" if not self.server_url or not self.resource_url: return False # Fail if strict validation requested but URLs missing # Check 'aud' claim first (standard JWT audience) - aud = token_data.get("aud") + aud: list[str] | str | None = token_data.get("aud") if isinstance(aud, list): for audience in aud: if self._is_valid_resource(audience): diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index b562cc932..76b598f93 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -2,22 +2,19 @@ import click import mcp.types as types from mcp.server.lowlevel import Server +from starlette.requests import Request -def create_messages( - context: str | None = None, topic: str | None = None -) -> list[types.PromptMessage]: +def create_messages(context: str | None = None, topic: str | None = None) -> list[types.PromptMessage]: """Create the messages for the prompt.""" - messages = [] + messages: list[types.PromptMessage] = [] # Add context if provided if context: messages.append( types.PromptMessage( role="user", - content=types.TextContent( - type="text", text=f"Here is some relevant context: {context}" - ), + content=types.TextContent(type="text", text=f"Here is some relevant context: {context}"), ) ) @@ -28,11 +25,7 @@ def create_messages( else: prompt += "whatever questions I may have." - messages.append( - types.PromptMessage( - role="user", content=types.TextContent(type="text", text=prompt) - ) - ) + messages.append(types.PromptMessage(role="user", content=types.TextContent(type="text", text=prompt))) return messages @@ -54,8 +47,7 @@ async def list_prompts() -> list[types.Prompt]: types.Prompt( name="simple", title="Simple Assistant Prompt", - description="A simple prompt that can take optional context and topic " - "arguments", + description="A simple prompt that can take optional context and topic arguments", arguments=[ types.PromptArgument( name="context", @@ -72,9 +64,7 @@ async def list_prompts() -> list[types.Prompt]: ] @app.get_prompt() - async def get_prompt( - name: str, arguments: dict[str, str] | None = None - ) -> types.GetPromptResult: + async def get_prompt(name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: if name != "simple": raise ValueError(f"Unknown prompt: {name}") @@ -82,9 +72,7 @@ async def get_prompt( arguments = {} return types.GetPromptResult( - messages=create_messages( - context=arguments.get("context"), topic=arguments.get("topic") - ), + messages=create_messages(context=arguments.get("context"), topic=arguments.get("topic")), description="A simple prompt with optional context and topic arguments", ) @@ -96,13 +84,9 @@ async def get_prompt( sse = SseServerTransport("/messages/") - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) + async def handle_sse(request: Request): + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] + await app.run(streams[0], streams[1], app.create_initialization_options()) return Response() starlette_app = Starlette( @@ -121,9 +105,7 @@ async def handle_sse(request): async def arun(): async with stdio_server() as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) + await app.run(streams[0], streams[1], app.create_initialization_options()) anyio.run(arun) diff --git a/examples/servers/simple-prompt/pyproject.toml b/examples/servers/simple-prompt/pyproject.toml index 1ef968d40..035b4134a 100644 --- a/examples/servers/simple-prompt/pyproject.toml +++ b/examples/servers/simple-prompt/pyproject.toml @@ -40,7 +40,7 @@ select = ["E", "F", "I"] ignore = [] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py310" [tool.uv] diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index cef29b851..002d7ad10 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -3,6 +3,7 @@ import mcp.types as types from mcp.server.lowlevel import Server from pydantic import AnyUrl, FileUrl +from starlette.requests import Request SAMPLE_RESOURCES = { "greeting": { @@ -63,13 +64,9 @@ async def read_resource(uri: AnyUrl) -> str | bytes: sse = SseServerTransport("/messages/") - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) + async def handle_sse(request: Request): + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] + await app.run(streams[0], streams[1], app.create_initialization_options()) return Response() starlette_app = Starlette( @@ -88,9 +85,7 @@ async def handle_sse(request): async def arun(): async with stdio_server() as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) + await app.run(streams[0], streams[1], app.create_initialization_options()) anyio.run(arun) diff --git a/examples/servers/simple-resource/pyproject.toml b/examples/servers/simple-resource/pyproject.toml index cbab1ca47..29906ecc5 100644 --- a/examples/servers/simple-resource/pyproject.toml +++ b/examples/servers/simple-resource/pyproject.toml @@ -40,7 +40,7 @@ select = ["E", "F", "I"] ignore = [] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py310" [tool.uv] diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index 68f3ac6a6..3071e8231 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -1,6 +1,7 @@ import contextlib import logging from collections.abc import AsyncIterator +from typing import Any import anyio import click @@ -41,7 +42,7 @@ def main( app = Server("mcp-streamable-http-stateless-demo") @app.call_tool() - async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) @@ -61,10 +62,7 @@ async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: return [ types.TextContent( type="text", - text=( - f"Sent {count} notifications with {interval}s interval" - f" for caller: {caller}" - ), + text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), ) ] @@ -73,10 +71,7 @@ async def list_tools() -> list[types.Tool]: return [ types.Tool( name="start-notification-stream", - description=( - "Sends a stream of notifications with configurable count" - " and interval" - ), + description=("Sends a stream of notifications with configurable count and interval"), inputSchema={ "type": "object", "required": ["interval", "count", "caller"], @@ -91,9 +86,7 @@ async def list_tools() -> list[types.Tool]: }, "caller": { "type": "string", - "description": ( - "Identifier of the caller to include in notifications" - ), + "description": ("Identifier of the caller to include in notifications"), }, }, }, @@ -108,9 +101,7 @@ async def list_tools() -> list[types.Tool]: stateless=True, ) - async def handle_streamable_http( - scope: Scope, receive: Receive, send: Send - ) -> None: + async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: await session_manager.handle_request(scope, receive, send) @contextlib.asynccontextmanager diff --git a/examples/servers/simple-streamablehttp-stateless/pyproject.toml b/examples/servers/simple-streamablehttp-stateless/pyproject.toml index d2b089451..66d7de833 100644 --- a/examples/servers/simple-streamablehttp-stateless/pyproject.toml +++ b/examples/servers/simple-streamablehttp-stateless/pyproject.toml @@ -29,8 +29,8 @@ select = ["E", "F", "I"] ignore = [] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py310" [tool.uv] -dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py index 28c58149f..ee52cdbe7 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py @@ -10,13 +10,7 @@ from dataclasses import dataclass from uuid import uuid4 -from mcp.server.streamable_http import ( - EventCallback, - EventId, - EventMessage, - EventStore, - StreamId, -) +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId from mcp.types import JSONRPCMessage logger = logging.getLogger(__name__) @@ -54,14 +48,10 @@ def __init__(self, max_events_per_stream: int = 100): # event_id -> EventEntry for quick lookup self.event_index: dict[EventId, EventEntry] = {} - async def store_event( - self, stream_id: StreamId, message: JSONRPCMessage - ) -> EventId: + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: """Stores an event with a generated event ID.""" event_id = str(uuid4()) - event_entry = EventEntry( - event_id=event_id, stream_id=stream_id, message=message - ) + event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message) # Get or create deque for this stream if stream_id not in self.streams: diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 9c25cc569..cf9200ce7 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,6 +1,7 @@ import contextlib import logging from collections.abc import AsyncIterator +from typing import Any import anyio import click @@ -45,7 +46,7 @@ def main( app = Server("mcp-streamable-http-demo") @app.call_tool() - async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: + async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) @@ -54,10 +55,7 @@ async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: # Send the specified number of notifications with the given interval for i in range(count): # Include more detailed message for resumability demonstration - notification_msg = ( - f"[{i + 1}/{count}] Event from '{caller}' - " - f"Use Last-Event-ID to resume if disconnected" - ) + notification_msg = f"[{i + 1}/{count}] Event from '{caller}' - Use Last-Event-ID to resume if disconnected" await ctx.session.send_log_message( level="info", data=notification_msg, @@ -79,10 +77,7 @@ async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: return [ types.TextContent( type="text", - text=( - f"Sent {count} notifications with {interval}s interval" - f" for caller: {caller}" - ), + text=(f"Sent {count} notifications with {interval}s interval for caller: {caller}"), ) ] @@ -91,10 +86,7 @@ async def list_tools() -> list[types.Tool]: return [ types.Tool( name="start-notification-stream", - description=( - "Sends a stream of notifications with configurable count" - " and interval" - ), + description=("Sends a stream of notifications with configurable count and interval"), inputSchema={ "type": "object", "required": ["interval", "count", "caller"], @@ -109,9 +101,7 @@ async def list_tools() -> list[types.Tool]: }, "caller": { "type": "string", - "description": ( - "Identifier of the caller to include in notifications" - ), + "description": ("Identifier of the caller to include in notifications"), }, }, }, @@ -136,9 +126,7 @@ async def list_tools() -> list[types.Tool]: ) # ASGI handler for streamable HTTP connections - async def handle_streamable_http( - scope: Scope, receive: Receive, send: Send - ) -> None: + async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: await session_manager.handle_request(scope, receive, send) @contextlib.asynccontextmanager diff --git a/examples/servers/simple-streamablehttp/pyproject.toml b/examples/servers/simple-streamablehttp/pyproject.toml index c35887d1f..e5ec6e08a 100644 --- a/examples/servers/simple-streamablehttp/pyproject.toml +++ b/examples/servers/simple-streamablehttp/pyproject.toml @@ -29,8 +29,8 @@ select = ["E", "F", "I"] ignore = [] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py310" [tool.uv] -dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index bf3683c9e..5b2b7d068 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -1,16 +1,17 @@ +from typing import Any + import anyio import click import mcp.types as types from mcp.server.lowlevel import Server from mcp.shared._httpx_utils import create_mcp_http_client +from starlette.requests import Request async def fetch_website( url: str, ) -> list[types.ContentBlock]: - headers = { - "User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)" - } + headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"} async with create_mcp_http_client(headers=headers) as client: response = await client.get(url) response.raise_for_status() @@ -29,7 +30,7 @@ def main(port: int, transport: str) -> int: app = Server("mcp-website-fetcher") @app.call_tool() - async def fetch_tool(name: str, arguments: dict) -> list[types.ContentBlock]: + async def fetch_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentBlock]: if name != "fetch": raise ValueError(f"Unknown tool: {name}") if "url" not in arguments: @@ -64,13 +65,9 @@ async def list_tools() -> list[types.Tool]: sse = SseServerTransport("/messages/") - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) + async def handle_sse(request: Request): + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: # type: ignore[reportPrivateUsage] + await app.run(streams[0], streams[1], app.create_initialization_options()) return Response() starlette_app = Starlette( @@ -89,9 +86,7 @@ async def handle_sse(request): async def arun(): async with stdio_server() as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) + await app.run(streams[0], streams[1], app.create_initialization_options()) anyio.run(arun) diff --git a/examples/servers/simple-tool/pyproject.toml b/examples/servers/simple-tool/pyproject.toml index c690aad97..ba7521691 100644 --- a/examples/servers/simple-tool/pyproject.toml +++ b/examples/servers/simple-tool/pyproject.toml @@ -40,7 +40,7 @@ select = ["E", "F", "I"] ignore = [] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py310" [tool.uv] diff --git a/examples/snippets/clients/parsing_tool_results.py b/examples/snippets/clients/parsing_tool_results.py index 0a3b3997c..515873546 100644 --- a/examples/snippets/clients/parsing_tool_results.py +++ b/examples/snippets/clients/parsing_tool_results.py @@ -34,7 +34,7 @@ async def parse_tool_results(): resource = content.resource if isinstance(resource, types.TextResourceContents): print(f"Config from {resource.uri}: {resource.text}") - elif isinstance(resource, types.BlobResourceContents): + else: print(f"Binary data from {resource.uri}") # Example 4: Parsing image content diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index 74a6f09df..ac978035d 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -22,7 +22,7 @@ # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext, params: types.CreateMessageRequestParams + context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 6d150cd6c..2c8a3b35a 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,6 +1,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession mcp = FastMCP(name="Elicitation Example") @@ -16,12 +17,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table( - date: str, - time: str, - party_size: int, - ctx: Context, -) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: """Book a table with date availability check.""" # Check if date is available if date == "2024-12-25": diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 37d04b597..62278b6aa 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession # Mock database class for example @@ -50,7 +51,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context) -> str: +def query_db(ctx: Context[ServerSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() diff --git a/examples/snippets/servers/lowlevel/lifespan.py b/examples/snippets/servers/lowlevel/lifespan.py index 61a9fe78e..ada373122 100644 --- a/examples/snippets/servers/lowlevel/lifespan.py +++ b/examples/snippets/servers/lowlevel/lifespan.py @@ -5,6 +5,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from typing import Any import mcp.server.stdio import mcp.types as types @@ -33,7 +34,7 @@ async def query(self, query_str: str) -> list[dict[str, str]]: @asynccontextmanager -async def server_lifespan(_server: Server) -> AsyncIterator[dict]: +async def server_lifespan(_server: Server) -> AsyncIterator[dict[str, Any]]: """Manage server startup and shutdown lifecycle.""" # Initialize resources on startup db = await Database.connect() @@ -65,7 +66,7 @@ async def handle_list_tools() -> list[types.Tool]: @server.call_tool() -async def query_db(name: str, arguments: dict) -> list[types.TextContent]: +async def query_db(name: str, arguments: dict[str, Any]) -> list[types.TextContent]: """Handle database query tool call.""" if name != "query_db": raise ValueError(f"Unknown tool: {name}") diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 96f0bc141..833bc8905 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,10 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context) -> str: +async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/snippets/servers/sampling.py b/examples/snippets/servers/sampling.py index 230b15fcf..0099836c2 100644 --- a/examples/snippets/servers/sampling.py +++ b/examples/snippets/servers/sampling.py @@ -1,11 +1,12 @@ from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession from mcp.types import SamplingMessage, TextContent mcp = FastMCP(name="Sampling Example") @mcp.tool() -async def generate_poem(topic: str, ctx: Context) -> str: +async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str: """Generate a poem using LLM sampling.""" prompt = f"Write a short poem about {topic}" diff --git a/examples/snippets/servers/structured_output.py b/examples/snippets/servers/structured_output.py index 263f6be51..021ffb169 100644 --- a/examples/snippets/servers/structured_output.py +++ b/examples/snippets/servers/structured_output.py @@ -71,7 +71,7 @@ def get_user(user_id: str) -> UserProfile: # Classes WITHOUT type hints cannot be used for structured output class UntypedConfig: - def __init__(self, setting1, setting2): + def __init__(self, setting1, setting2): # type: ignore[reportMissingParameterType] self.setting1 = setting1 self.setting2 = setting2 diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index d62e62dd1..2ac458f6a 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,10 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context, steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") diff --git a/pyproject.toml b/pyproject.toml index 474c58f6e..5da7a7cb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "anyio>=4.5", "httpx>=0.27", "httpx-sse>=0.4", - "pydantic>=2.8.0,<3.0.0", + "pydantic>=2.11.0,<3.0.0", "starlette>=0.27", "python-multipart>=0.0.9", "sse-starlette>=1.6.1", @@ -88,10 +88,19 @@ Issues = "https://github.com/modelcontextprotocol/python-sdk/issues" packages = ["src/mcp"] [tool.pyright] +typeCheckingMode = "strict" include = ["src/mcp", "tests", "examples/servers", "examples/snippets"] venvPath = "." venv = ".venv" -strict = ["src/mcp/**/*.py"] +# The FastAPI style of using decorators in tests gives a `reportUnusedFunction` error. +# See https://github.com/microsoft/pyright/issues/7771 for more details. +# TODO(Marcelo): We should remove `reportPrivateUsage = false`. The idea is that we should test the workflow that uses +# those private functions instead of testing the private functions directly. It makes it easier to maintain the code source +# and refactor code that is not public. +executionEnvironments = [ + { root = "tests", reportUnusedFunction = false, reportPrivateUsage = false }, + { root = "examples/servers", reportUnusedFunction = false }, +] [tool.ruff.lint] select = ["C4", "E", "F", "I", "PERF", "UP"] diff --git a/scripts/update_readme_snippets.py b/scripts/update_readme_snippets.py index 76d40277c..d325333ff 100755 --- a/scripts/update_readme_snippets.py +++ b/scripts/update_readme_snippets.py @@ -29,7 +29,7 @@ def get_github_url(file_path: str) -> str: return f"{base_url}/{file_path}" -def process_snippet_block(match: re.Match, check_mode: bool = False) -> str: +def process_snippet_block(match: re.Match[str], check_mode: bool = False) -> str: """Process a single snippet-source block. Args: diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index c3ddd0de4..fb354ba7f 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,10 +1,11 @@ import subprocess import sys from pathlib import Path +from typing import Any import pytest -from mcp.cli.cli import _build_uv_command, _get_npx_command, _parse_file_path +from mcp.cli.cli import _build_uv_command, _get_npx_command, _parse_file_path # type: ignore[reportPrivateUsage] @pytest.mark.parametrize( @@ -14,7 +15,7 @@ ("foo.py:srv_obj", "srv_obj"), ], ) -def test_parse_file_path_accepts_valid_specs(tmp_path, spec, expected_obj): +def test_parse_file_path_accepts_valid_specs(tmp_path: Path, spec: str, expected_obj: str | None): """Should accept valid file specs.""" file = tmp_path / spec.split(":")[0] file.write_text("x = 1") @@ -23,13 +24,13 @@ def test_parse_file_path_accepts_valid_specs(tmp_path, spec, expected_obj): assert obj == expected_obj -def test_parse_file_path_missing(tmp_path): +def test_parse_file_path_missing(tmp_path: Path): """Should system exit if a file is missing.""" with pytest.raises(SystemExit): _parse_file_path(str(tmp_path / "missing.py")) -def test_parse_file_exit_on_dir(tmp_path): +def test_parse_file_exit_on_dir(tmp_path: Path): """Should system exit if a directory is passed""" dir_path = tmp_path / "dir" dir_path.mkdir() @@ -68,17 +69,17 @@ def test_build_uv_command_adds_editable_and_packages(): ] -def test_get_npx_unix_like(monkeypatch): +def test_get_npx_unix_like(monkeypatch: pytest.MonkeyPatch): """Should return "npx" on unix-like systems.""" monkeypatch.setattr(sys, "platform", "linux") assert _get_npx_command() == "npx" -def test_get_npx_windows(monkeypatch): +def test_get_npx_windows(monkeypatch: pytest.MonkeyPatch): """Should return one of the npx candidates on Windows.""" candidates = ["npx.cmd", "npx.exe", "npx"] - def fake_run(cmd, **kw): + def fake_run(cmd: list[str], **kw: Any) -> subprocess.CompletedProcess[bytes]: if cmd[0] in candidates: return subprocess.CompletedProcess(cmd, 0) else: @@ -89,11 +90,11 @@ def fake_run(cmd, **kw): assert _get_npx_command() in candidates -def test_get_npx_returns_none_when_npx_missing(monkeypatch): +def test_get_npx_returns_none_when_npx_missing(monkeypatch: pytest.MonkeyPatch): """Should give None if every candidate fails.""" monkeypatch.setattr(sys, "platform", "win32", raising=False) - def always_fail(*args, **kwargs): + def always_fail(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess[bytes]: raise subprocess.CalledProcessError(1, args[0]) monkeypatch.setattr(subprocess, "run", always_fail) diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 0c8283903..97014af9f 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -1,22 +1,22 @@ +from collections.abc import Callable, Generator from contextlib import asynccontextmanager +from typing import Any from unittest.mock import patch import pytest +from anyio.streams.memory import MemoryObjectSendStream import mcp.shared.memory from mcp.shared.message import SessionMessage -from mcp.types import ( - JSONRPCNotification, - JSONRPCRequest, -) +from mcp.types import JSONRPCNotification, JSONRPCRequest class SpyMemoryObjectSendStream: - def __init__(self, original_stream): + def __init__(self, original_stream: MemoryObjectSendStream[SessionMessage]): self.original_stream = original_stream self.sent_messages: list[SessionMessage] = [] - async def send(self, message): + async def send(self, message: SessionMessage): self.sent_messages.append(message) await self.original_stream.send(message) @@ -26,16 +26,12 @@ async def aclose(self): async def __aenter__(self): return self - async def __aexit__(self, *args): + async def __aexit__(self, *args: Any): await self.aclose() class StreamSpyCollection: - def __init__( - self, - client_spy: SpyMemoryObjectSendStream, - server_spy: SpyMemoryObjectSendStream, - ): + def __init__(self, client_spy: SpyMemoryObjectSendStream, server_spy: SpyMemoryObjectSendStream): self.client = client_spy self.server = server_spy @@ -80,7 +76,7 @@ def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNot @pytest.fixture -def stream_spy(): +def stream_spy() -> Generator[Callable[[], StreamSpyCollection], None, None]: """Fixture that provides spies for both client and server write streams. Example usage: @@ -103,7 +99,7 @@ async def test_something(stream_spy): server_spy = None # Store references to our spy objects - def capture_spies(c_spy, s_spy): + def capture_spies(c_spy: SpyMemoryObjectSendStream, s_spy: SpyMemoryObjectSendStream): nonlocal client_spy, server_spy client_spy = c_spy server_spy = s_spy diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb962bfc1..61d74df1e 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -11,12 +11,7 @@ from pydantic import AnyHttpUrl, AnyUrl from mcp.client.auth import OAuthClientProvider, PKCEParameters -from mcp.shared.auth import ( - OAuthClientInformationFull, - OAuthClientMetadata, - OAuthToken, - ProtectedResourceMetadata, -) +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata class MockTokenStorage: @@ -66,7 +61,7 @@ def valid_tokens(): @pytest.fixture -def oauth_provider(client_metadata, mock_storage): +def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): async def redirect_handler(url: str) -> None: """Mock redirect handler.""" pass @@ -115,7 +110,9 @@ class TestOAuthContext: """Test OAuth context functionality.""" @pytest.mark.anyio - async def test_oauth_provider_initialization(self, oauth_provider, client_metadata, mock_storage): + async def test_oauth_provider_initialization( + self, oauth_provider: OAuthClientProvider, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): """Test OAuthClientProvider basic setup.""" assert oauth_provider.context.server_url == "https://api.example.com/v1/mcp" assert oauth_provider.context.client_metadata == client_metadata @@ -123,7 +120,7 @@ async def test_oauth_provider_initialization(self, oauth_provider, client_metada assert oauth_provider.context.timeout == 300.0 assert oauth_provider.context is not None - def test_context_url_parsing(self, oauth_provider): + def test_context_url_parsing(self, oauth_provider: OAuthClientProvider): """Test get_authorization_base_url() extracts base URLs correctly.""" context = oauth_provider.context @@ -145,7 +142,7 @@ def test_context_url_parsing(self, oauth_provider): ) @pytest.mark.anyio - async def test_token_validity_checking(self, oauth_provider, mock_storage, valid_tokens): + async def test_token_validity_checking(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): """Test is_token_valid() and can_refresh_token() logic.""" context = oauth_provider.context @@ -180,7 +177,7 @@ async def test_token_validity_checking(self, oauth_provider, mock_storage, valid context.client_info = None assert not context.can_refresh_token() - def test_clear_tokens(self, oauth_provider, valid_tokens): + def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): """Test clear_tokens() removes token data.""" context = oauth_provider.context context.current_tokens = valid_tokens @@ -198,7 +195,9 @@ class TestOAuthFlow: """Test OAuth flow methods.""" @pytest.mark.anyio - async def test_discover_protected_resource_request(self, client_metadata, mock_storage): + async def test_discover_protected_resource_request( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): """Test protected resource discovery request building maintains backward compatibility.""" async def redirect_handler(url: str) -> None: @@ -236,7 +235,7 @@ async def callback_handler() -> tuple[str, str | None]: assert "mcp-protocol-version" in request.headers @pytest.mark.anyio - def test_create_oauth_metadata_request(self, oauth_provider): + def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider): """Test OAuth metadata discovery request building.""" request = oauth_provider._create_oauth_metadata_request("https://example.com") @@ -250,7 +249,7 @@ class TestOAuthFallback: """Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers.""" @pytest.mark.anyio - async def test_oauth_discovery_fallback_order(self, oauth_provider): + async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider): """Test fallback URL construction order.""" discovery_urls = oauth_provider._get_discovery_urls() @@ -262,7 +261,7 @@ async def test_oauth_discovery_fallback_order(self, oauth_provider): ] @pytest.mark.anyio - async def test_oauth_discovery_fallback_conditions(self, oauth_provider): + async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthClientProvider): """Test the conditions during which an AS metadata discovery fallback will be attempted.""" # Ensure no tokens are stored oauth_provider.context.current_tokens = None @@ -365,7 +364,7 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider): token_request = await auth_flow.asend(token_response) @pytest.mark.anyio - async def test_handle_metadata_response_success(self, oauth_provider): + async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider): """Test successful metadata response handling.""" # Create minimal valid OAuth metadata content = b"""{ @@ -381,7 +380,7 @@ async def test_handle_metadata_response_success(self, oauth_provider): assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/" @pytest.mark.anyio - async def test_register_client_request(self, oauth_provider): + async def test_register_client_request(self, oauth_provider: OAuthClientProvider): """Test client registration request building.""" request = await oauth_provider._register_client() @@ -391,7 +390,7 @@ async def test_register_client_request(self, oauth_provider): assert request.headers["Content-Type"] == "application/json" @pytest.mark.anyio - async def test_register_client_skip_if_registered(self, oauth_provider, mock_storage): + async def test_register_client_skip_if_registered(self, oauth_provider: OAuthClientProvider): """Test client registration is skipped if already registered.""" # Set existing client info client_info = OAuthClientInformationFull( @@ -405,7 +404,7 @@ async def test_register_client_skip_if_registered(self, oauth_provider, mock_sto assert request is None @pytest.mark.anyio - async def test_token_exchange_request(self, oauth_provider): + async def test_token_exchange_request(self, oauth_provider: OAuthClientProvider): """Test token exchange request building.""" # Set up required context oauth_provider.context.client_info = OAuthClientInformationFull( @@ -429,7 +428,7 @@ async def test_token_exchange_request(self, oauth_provider): assert "client_secret=test_secret" in content @pytest.mark.anyio - async def test_refresh_token_request(self, oauth_provider, valid_tokens): + async def test_refresh_token_request(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): """Test refresh token request building.""" # Set up required context oauth_provider.context.current_tokens = valid_tokens @@ -538,11 +537,11 @@ class TestRegistrationResponse: """Test client registration response handling.""" @pytest.mark.anyio - async def test_handle_registration_response_reads_before_accessing_text(self, oauth_provider): + async def test_handle_registration_response_reads_before_accessing_text(self, oauth_provider: OAuthClientProvider): """Test that response.aread() is called before accessing response.text.""" # Track if aread() was called - class MockResponse: + class MockResponse(httpx.Response): def __init__(self): self.status_code = 400 self._aread_called = False @@ -574,7 +573,9 @@ class TestAuthFlow: """Test the auth flow in httpx.""" @pytest.mark.anyio - async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, valid_tokens): + async def test_auth_flow_with_valid_tokens( + self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken + ): """Test auth flow when tokens are already valid.""" # Pre-store valid tokens await mock_storage.set_tokens(valid_tokens) @@ -600,7 +601,7 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v pass # Expected @pytest.mark.anyio - async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage): + async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvider): """Test auth flow when no tokens are available, triggering the full OAuth flow.""" # Ensure no tokens are stored oauth_provider.context.current_tokens = None @@ -810,7 +811,11 @@ class TestProtectedResourceWWWAuthenticate: ], ) def test_extract_resource_metadata_from_www_auth_valid_cases( - self, client_metadata, mock_storage, www_auth_header, expected_url + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + www_auth_header: str, + expected_url: str, ): """Test extraction of resource_metadata URL from various valid WWW-Authenticate headers.""" @@ -862,7 +867,12 @@ async def callback_handler() -> tuple[str, str | None]: ], ) def test_extract_resource_metadata_from_www_auth_invalid_cases( - self, client_metadata, mock_storage, status_code, www_auth_header, description + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + status_code: int, + www_auth_header: str | None, + description: str, ): """Test extraction returns None for invalid cases.""" diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index f7b031737..b31b704a4 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -1,15 +1,16 @@ +from collections.abc import Callable + import pytest from mcp.server.fastmcp import FastMCP -from mcp.shared.memory import ( - create_connected_server_and_client_session as create_session, -) +from mcp.shared.memory import create_connected_server_and_client_session as create_session + +from .conftest import StreamSpyCollection -# Mark the whole module for async tests pytestmark = pytest.mark.anyio -async def test_list_tools_cursor_parameter(stream_spy): +async def test_list_tools_cursor_parameter(stream_spy: Callable[[], StreamSpyCollection]): """Test that the cursor parameter is accepted for list_tools and that it is correctly passed to the server. @@ -64,7 +65,7 @@ async def test_tool_2() -> str: assert list_tools_requests[0].params["cursor"] == "" -async def test_list_resources_cursor_parameter(stream_spy): +async def test_list_resources_cursor_parameter(stream_spy: Callable[[], StreamSpyCollection]): """Test that the cursor parameter is accepted for list_resources and that it is correctly passed to the server. @@ -114,7 +115,7 @@ async def test_resource() -> str: assert list_resources_requests[0].params["cursor"] == "" -async def test_list_prompts_cursor_parameter(stream_spy): +async def test_list_prompts_cursor_parameter(stream_spy: Callable[[], StreamSpyCollection]): """Test that the cursor parameter is accepted for list_prompts and that it is correctly passed to the server. See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format @@ -163,7 +164,7 @@ async def test_prompt(name: str) -> str: assert list_prompts_requests[0].params["cursor"] == "" -async def test_list_resource_templates_cursor_parameter(stream_spy): +async def test_list_resource_templates_cursor_parameter(stream_spy: Callable[[], StreamSpyCollection]): """Test that the cursor parameter is accepted for list_resource_templates and that it is correctly passed to the server. diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index 242515b96..4e649b0eb 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -1,5 +1,6 @@ import logging from contextlib import contextmanager +from typing import Any from unittest.mock import patch import pytest @@ -53,7 +54,7 @@ async def list_tools(): ] @server.call_tool() - async def call_tool(name: str, arguments: dict): + async def call_tool(name: str, arguments: dict[str, Any]): # Return invalid structured content - age is string instead of integer # The low-level server will wrap this in CallToolResult return {"name": "John", "age": "invalid"} # Invalid: age should be int @@ -92,7 +93,7 @@ async def list_tools(): ] @server.call_tool() - async def call_tool(name: str, arguments: dict): + async def call_tool(name: str, arguments: dict[str, Any]): # Return invalid structured content - result is string instead of integer return {"result": "not_a_number"} # Invalid: should be int @@ -123,7 +124,7 @@ async def list_tools(): ] @server.call_tool() - async def call_tool(name: str, arguments: dict): + async def call_tool(name: str, arguments: dict[str, Any]): # Return invalid structured content - values should be integers return {"alice": "100", "bob": "85"} # Invalid: values should be int @@ -158,7 +159,7 @@ async def list_tools(): ] @server.call_tool() - async def call_tool(name: str, arguments: dict): + async def call_tool(name: str, arguments: dict[str, Any]): # Return structured content missing required field 'email' return {"name": "John", "age": 30} # Missing required 'email' @@ -170,17 +171,17 @@ async def call_tool(name: str, arguments: dict): assert "Invalid structured content returned by tool get_person" in str(exc_info.value) @pytest.mark.anyio - async def test_tool_not_listed_warning(self, caplog): + async def test_tool_not_listed_warning(self, caplog: pytest.LogCaptureFixture): """Test that client logs warning when tool is not in list_tools but has outputSchema""" server = Server("test-server") @server.list_tools() - async def list_tools(): + async def list_tools() -> list[Tool]: # Return empty list - tool is not listed return [] @server.call_tool() - async def call_tool(name: str, arguments: dict): + async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: # Server still responds to the tool call with structured content return {"result": 42} diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index 527884219..0752d649f 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -1,14 +1,12 @@ +from typing import Any from unittest.mock import patch import anyio import pytest -from mcp.shared.session import BaseSession -from mcp.types import ( - ClientRequest, - EmptyResult, - PingRequest, -) +from mcp.shared.message import SessionMessage +from mcp.shared.session import BaseSession, RequestId, SendResultT +from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest @pytest.mark.anyio @@ -20,13 +18,13 @@ async def test_send_request_stream_cleanup(): """ # Create a mock session with the minimal required functionality - class TestSession(BaseSession): - async def _send_response(self, request_id, response): + class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]): + async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: pass # Create streams - write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1) - read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1) + write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) + read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) # Create the session session = TestSession( @@ -37,14 +35,10 @@ async def _send_response(self, request_id, response): ) # Create a test request - request = ClientRequest( - PingRequest( - method="ping", - ) - ) + request = ClientRequest(PingRequest(method="ping")) # Patch the _write_stream.send method to raise an exception - async def mock_send(*args, **kwargs): + async def mock_send(*args: Any, **kwargs: Any): raise RuntimeError("Simulated network error") # Record the response streams before the test diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 16a887e00..c38cfeabc 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -5,11 +5,7 @@ import mcp from mcp import types -from mcp.client.session_group import ( - ClientSessionGroup, - SseServerParameters, - StreamableHttpParameters, -) +from mcp.client.session_group import ClientSessionGroup, SseServerParameters, StreamableHttpParameters from mcp.client.stdio import StdioServerParameters from mcp.shared.exceptions import McpError @@ -54,7 +50,7 @@ async def test_call_tool(self): mock_session = mock.AsyncMock() # --- Prepare Session Group --- - def hook(name, server_info): + def hook(name: str, server_info: types.Implementation) -> str: return f"{(server_info.name)}-{name}" mcp_session_group = ClientSessionGroup(component_name_hook=hook) @@ -79,7 +75,7 @@ def hook(name, server_info): {"name": "value1", "args": {}}, ) - async def test_connect_to_server(self, mock_exit_stack): + async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack): """Test connecting to a server and aggregating components.""" # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) @@ -116,7 +112,7 @@ async def test_connect_to_server(self, mock_exit_stack): mock_session.list_resources.assert_awaited_once() mock_session.list_prompts.assert_awaited_once() - async def test_connect_to_server_with_name_hook(self, mock_exit_stack): + async def test_connect_to_server_with_name_hook(self, mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook.""" # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) @@ -208,7 +204,7 @@ async def test_disconnect_from_server(self): # No mock arguments needed assert "res1" not in group._resources assert "prm1" not in group._prompts - async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack): + async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack: contextlib.AsyncExitStack): """Test McpError raised when connecting a server with a dup name.""" # --- Setup Pre-existing State --- group = ClientSessionGroup(exit_stack=mock_exit_stack) @@ -282,9 +278,9 @@ async def test_disconnect_non_existent_server(self): ) async def test_establish_session_parameterized( self, - server_params_instance, - client_type_name, # Just for clarity or conditional logic if needed - patch_target_for_client_func, + server_params_instance: StdioServerParameters | SseServerParameters | StreamableHttpParameters, + client_type_name: str, # Just for clarity or conditional logic if needed + patch_target_for_client_func: str, ): with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: with mock.patch(patch_target_for_client_func) as mock_specific_client_func: @@ -338,8 +334,10 @@ async def test_establish_session_parameterized( # --- Assertions --- # 1. Assert the correct specific client function was called if client_type_name == "stdio": + assert isinstance(server_params_instance, StdioServerParameters) mock_specific_client_func.assert_called_once_with(server_params_instance) elif client_type_name == "sse": + assert isinstance(server_params_instance, SseServerParameters) mock_specific_client_func.assert_called_once_with( url=server_params_instance.url, headers=server_params_instance.headers, @@ -347,6 +345,7 @@ async def test_establish_session_parameterized( sse_read_timeout=server_params_instance.sse_read_timeout, ) elif client_type_name == "streamablehttp": + assert isinstance(server_params_instance, StreamableHttpParameters) mock_specific_client_func.assert_called_once_with( url=server_params_instance.url, headers=server_params_instance.headers, diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 2abb42e5c..69dad4846 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -9,28 +9,25 @@ import pytest from mcp.client.session import ClientSession -from mcp.client.stdio import ( - StdioServerParameters, - _create_platform_compatible_process, - stdio_client, -) +from mcp.client.stdio import StdioServerParameters, _create_platform_compatible_process, stdio_client from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse -from tests.shared.test_win32_utils import escape_path_for_python + +from ..shared.test_win32_utils import escape_path_for_python # Timeout for cleanup of processes that ignore SIGTERM # This timeout ensures the test fails quickly if the cleanup logic doesn't have # proper fallback mechanisms (SIGINT/SIGKILL) for processes that ignore SIGTERM SIGTERM_IGNORING_PROCESS_TIMEOUT = 5.0 -tee: str = shutil.which("tee") # type: ignore -python: str = shutil.which("python") # type: ignore +tee = shutil.which("tee") @pytest.mark.anyio @pytest.mark.skipif(tee is None, reason="could not find tee command") async def test_stdio_context_manager_exiting(): + assert tee is not None async with stdio_client(StdioServerParameters(command=tee)) as (_, _): pass @@ -38,6 +35,7 @@ async def test_stdio_context_manager_exiting(): @pytest.mark.anyio @pytest.mark.skipif(tee is None, reason="could not find tee command") async def test_stdio_client(): + assert tee is not None server_parameters = StdioServerParameters(command=tee) async with stdio_client(server_parameters) as (read_stream, write_stream): @@ -52,7 +50,7 @@ async def test_stdio_client(): session_message = SessionMessage(message) await write_stream.send(session_message) - read_messages = [] + read_messages: list[JSONRPCMessage] = [] async with read_stream: async for message in read_stream: if isinstance(message, Exception): @@ -118,7 +116,7 @@ async def test_stdio_client_universal_cleanup(): """ import time import sys - + # Simulate a long-running process for i in range(100): time.sleep(0.1) @@ -136,7 +134,7 @@ async def test_stdio_client_universal_cleanup(): start_time = time.time() with anyio.move_on_after(8.0) as cancel_scope: - async with stdio_client(server_params) as (read_stream, write_stream): + async with stdio_client(server_params) as (_, _): # Immediately exit - this triggers cleanup while process is still running pass @@ -195,7 +193,7 @@ def sigint_handler(signum, frame): try: # Use anyio timeout to prevent test from hanging forever with anyio.move_on_after(5.0) as cancel_scope: - async with stdio_client(server_params) as (read_stream, write_stream): + async with stdio_client(server_params) as (_, _): # Let the process start and begin ignoring SIGTERM await anyio.sleep(0.5) # Exit context triggers cleanup - this should not hang @@ -532,7 +530,7 @@ async def test_stdio_client_graceful_stdin_exit(): script_content = textwrap.dedent( """ import sys - + # Read from stdin until it's closed try: while True: @@ -541,7 +539,7 @@ async def test_stdio_client_graceful_stdin_exit(): break except: pass - + # Exit gracefully sys.exit(0) """ @@ -556,7 +554,7 @@ async def test_stdio_client_graceful_stdin_exit(): # Use anyio timeout to prevent test from hanging forever with anyio.move_on_after(5.0) as cancel_scope: - async with stdio_client(server_params) as (read_stream, write_stream): + async with stdio_client(server_params) as (_, _): # Let the process start and begin reading stdin await anyio.sleep(0.2) # Exit context triggers cleanup - process should exit from stdin closure @@ -590,16 +588,16 @@ async def test_stdio_client_stdin_close_ignored(): import signal import sys import time - + # Set up SIGTERM handler to exit cleanly def sigterm_handler(signum, frame): sys.exit(0) - + signal.signal(signal.SIGTERM, sigterm_handler) - + # Close stdin immediately to simulate ignoring it sys.stdin.close() - + # Keep running until SIGTERM while True: time.sleep(0.1) @@ -615,7 +613,7 @@ def sigterm_handler(signum, frame): # Use anyio timeout to prevent test from hanging forever with anyio.move_on_after(7.0) as cancel_scope: - async with stdio_client(server_params) as (read_stream, write_stream): + async with stdio_client(server_params) as (_, _): # Let the process start await anyio.sleep(0.2) # Exit context triggers cleanup diff --git a/tests/issues/test_1027_win_unreachable_cleanup.py b/tests/issues/test_1027_win_unreachable_cleanup.py index cb2e05a68..637f7963b 100644 --- a/tests/issues/test_1027_win_unreachable_cleanup.py +++ b/tests/issues/test_1027_win_unreachable_cleanup.py @@ -12,13 +12,19 @@ import tempfile import textwrap from pathlib import Path +from typing import TYPE_CHECKING import anyio import pytest from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import _create_platform_compatible_process, stdio_client -from tests.shared.test_win32_utils import escape_path_for_python + +# TODO(Marcelo): This doesn't seem to be the right path. We should fix this. +if TYPE_CHECKING: + from ..shared.test_win32_utils import escape_path_for_python +else: + from tests.shared.test_win32_utils import escape_path_for_python @pytest.mark.anyio @@ -52,10 +58,10 @@ async def test_lifespan_cleanup_executed(): from pathlib import Path from contextlib import asynccontextmanager from mcp.server.fastmcp import FastMCP - + STARTUP_MARKER = {escape_path_for_python(startup_marker)} CLEANUP_MARKER = {escape_path_for_python(cleanup_marker)} - + @asynccontextmanager async def lifespan(server): # Write startup marker @@ -65,13 +71,13 @@ async def lifespan(server): finally: # This cleanup code now runs properly during shutdown Path(CLEANUP_MARKER).write_text("cleaned up") - + mcp = FastMCP("test-server", lifespan=lifespan) - + @mcp.tool() def echo(text: str) -> str: return text - + if __name__ == "__main__": mcp.run() """) @@ -160,10 +166,10 @@ async def test_stdin_close_triggers_cleanup(): from pathlib import Path from contextlib import asynccontextmanager from mcp.server.fastmcp import FastMCP - + STARTUP_MARKER = {escape_path_for_python(startup_marker)} CLEANUP_MARKER = {escape_path_for_python(cleanup_marker)} - + @asynccontextmanager async def lifespan(server): # Write startup marker @@ -173,13 +179,13 @@ async def lifespan(server): finally: # This cleanup code runs when stdin closes, enabling graceful shutdown Path(CLEANUP_MARKER).write_text("cleaned up") - + mcp = FastMCP("test-server", lifespan=lifespan) - + @mcp.tool() def echo(text: str) -> str: return text - + if __name__ == "__main__": # The server should exit gracefully when stdin closes try: diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index f87110a28..831736510 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -11,7 +11,7 @@ async def test_messages_are_executed_concurrently_tools(): server = FastMCP("test") event = anyio.Event() tool_started = anyio.Event() - call_order = [] + call_order: list[str] = [] @server.tool("sleep") async def sleep_tool(): @@ -52,7 +52,7 @@ async def test_messages_are_executed_concurrently_tools_and_resources(): server = FastMCP("test") event = anyio.Event() tool_started = anyio.Event() - call_order = [] + call_order: list[str] = [] @server.tool("sleep") async def sleep_tool(): diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index 3c63f00b7..3762b092b 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -12,6 +12,7 @@ JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + JSONRPCResponse, NotificationParams, ) @@ -23,8 +24,8 @@ async def test_request_id_match() -> None: custom_request_id = "test-123" # Create memory streams for communication - client_writer, client_reader = anyio.create_memory_object_stream(1) - server_writer, server_reader = anyio.create_memory_object_stream(1) + client_writer, client_reader = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage | Exception](1) # Server task to process the request async def run_server(): @@ -85,6 +86,9 @@ async def run_server(): response = await server_reader.receive() # Verify response ID matches request ID + assert isinstance(response, SessionMessage) + assert isinstance(response.message, JSONRPCMessage) + assert isinstance(response.message.root, JSONRPCResponse) assert response.message.root.id == custom_request_id, "Response ID should match request ID" # Cancel server task diff --git a/tests/issues/test_355_type_error.py b/tests/issues/test_355_type_error.py index 91416e5ca..7159308b2 100644 --- a/tests/issues/test_355_type_error.py +++ b/tests/issues/test_355_type_error.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession class Database: # Replace with your actual DB type @@ -44,7 +45,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context) -> str: +def query_db(ctx: Context[ServerSession, AppContext]) -> str: """Tool that uses initialized resources""" db = ctx.request_context.lifespan_context.db return db.query() diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index c3570a39c..5584abcae 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -3,15 +3,18 @@ from collections.abc import Sequence from datetime import timedelta from pathlib import Path +from typing import Any import anyio import pytest from anyio.abc import TaskStatus +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import types from mcp.client.session import ClientSession from mcp.server.lowlevel import Server from mcp.shared.exceptions import McpError +from mcp.shared.message import SessionMessage from mcp.types import ContentBlock, TextContent @@ -46,7 +49,7 @@ async def list_tools() -> list[types.Tool]: ] @server.call_tool() - async def slow_tool(name: str, arg) -> Sequence[ContentBlock]: + async def slow_tool(name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock]: nonlocal request_count request_count += 1 @@ -58,8 +61,8 @@ async def slow_tool(name: str, arg) -> Sequence[ContentBlock]: return [TextContent(type="text", text=f"unknown {request_count}")] async def server_handler( - read_stream, - write_stream, + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, ): with anyio.CancelScope() as scope: @@ -71,7 +74,11 @@ async def server_handler( raise_exceptions=True, ) - async def client(read_stream, write_stream, scope): + async def client( + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], + scope: anyio.CancelScope, + ): # Use a timeout that's: # - Long enough for fast operations (>10ms) # - Short enough for slow operations (<200ms) @@ -99,8 +106,8 @@ async def client(read_stream, write_stream, scope): scope.cancel() # Run server and client in separate task groups to avoid cancellation - server_writer, server_reader = anyio.create_memory_object_stream(1) - client_writer, client_reader = anyio.create_memory_object_stream(1) + server_writer, server_reader = anyio.create_memory_object_stream[SessionMessage](1) + client_writer, client_reader = anyio.create_memory_object_stream[SessionMessage](1) async with anyio.create_task_group() as tg: scope = await tg.start(server_handler, server_reader, client_writer) diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py index 97edb651e..065bc7841 100644 --- a/tests/issues/test_malformed_input.py +++ b/tests/issues/test_malformed_input.py @@ -1,6 +1,8 @@ # Claude Debug """Test for HackerOne vulnerability report #3156202 - malformed input DOS.""" +from typing import Any + import anyio import pytest @@ -118,7 +120,7 @@ async def test_multiple_concurrent_malformed_requests(): ), ): # Send multiple malformed requests concurrently - malformed_requests = [] + malformed_requests: list[SessionMessage] = [] for i in range(10): malformed_request = JSONRPCRequest( jsonrpc="2.0", @@ -137,7 +139,7 @@ async def test_multiple_concurrent_malformed_requests(): await anyio.sleep(0.2) # Verify we get error responses for all requests - error_responses = [] + error_responses: list[Any] = [] try: while True: response_message = write_receive_stream.receive_nowait() diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index 5bb0f969e..80c8bae21 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -11,16 +11,8 @@ from starlette.requests import Request from starlette.types import Message, Receive, Scope, Send -from mcp.server.auth.middleware.bearer_auth import ( - AuthenticatedUser, - BearerAuthBackend, - RequireAuthMiddleware, -) -from mcp.server.auth.provider import ( - AccessToken, - OAuthAuthorizationServerProvider, - ProviderTokenVerifier, -) +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, BearerAuthBackend, RequireAuthMiddleware +from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, ProviderTokenVerifier class MockOAuthProvider: @@ -31,7 +23,7 @@ class MockOAuthProvider: """ def __init__(self): - self.tokens = {} # token -> AccessToken + self.tokens: dict[str, AccessToken] = {} # token -> AccessToken def add_token(self, token: str, access_token: AccessToken) -> None: """Add a token to the provider.""" @@ -287,7 +279,7 @@ async def test_no_user(self): async def receive() -> Message: return {"type": "http.request"} - sent_messages = [] + sent_messages: list[Message] = [] async def send(message: Message) -> None: sent_messages.append(message) @@ -311,7 +303,7 @@ async def test_non_authenticated_user(self): async def receive() -> Message: return {"type": "http.request"} - sent_messages = [] + sent_messages: list[Message] = [] async def send(message: Message) -> None: sent_messages.append(message) @@ -340,7 +332,7 @@ async def test_missing_required_scope(self, valid_access_token: AccessToken): async def receive() -> Message: return {"type": "http.request"} - sent_messages = [] + sent_messages: list[Message] = [] async def send(message: Message) -> None: sent_messages.append(message) @@ -368,7 +360,7 @@ async def test_no_auth_credentials(self, valid_access_token: AccessToken): async def receive() -> Message: return {"type": "http.request"} - sent_messages = [] + sent_messages: list[Message] = [] async def send(message: Message) -> None: sent_messages.append(message) diff --git a/tests/server/auth/test_error_handling.py b/tests/server/auth/test_error_handling.py index 7846c8adb..f331b2cb2 100644 --- a/tests/server/auth/test_error_handling.py +++ b/tests/server/auth/test_error_handling.py @@ -3,6 +3,7 @@ """ import unittest.mock +from typing import TYPE_CHECKING, Any from urllib.parse import parse_qs, urlparse import httpx @@ -11,15 +12,14 @@ from pydantic import AnyHttpUrl from starlette.applications import Starlette -from mcp.server.auth.provider import ( - AuthorizeError, - RegistrationError, - TokenError, -) +from mcp.server.auth.provider import AuthorizeError, RegistrationError, TokenError from mcp.server.auth.routes import create_auth_routes -from tests.server.fastmcp.auth.test_auth_integration import ( - MockOAuthProvider, -) + +# TODO(Marcelo): This TYPE_CHECKING shouldn't be here, but pytest doesn't seem to get the module correctly. +if TYPE_CHECKING: + from ...server.fastmcp.auth.test_auth_integration import MockOAuthProvider +else: + from tests.server.fastmcp.auth.test_auth_integration import MockOAuthProvider @pytest.fixture @@ -29,7 +29,7 @@ def oauth_provider(): @pytest.fixture -def app(oauth_provider): +def app(oauth_provider: MockOAuthProvider): from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions # Enable client registration @@ -49,7 +49,7 @@ def app(oauth_provider): @pytest.fixture -def client(app): +def client(app: Starlette): transport = ASGITransport(app=app) # Use base_url without a path since routes are directly on the app return httpx.AsyncClient(transport=transport, base_url="http://localhost") @@ -74,7 +74,7 @@ def pkce_challenge(): @pytest.fixture -async def registered_client(client): +async def registered_client(client: httpx.AsyncClient) -> dict[str, Any]: """Create and register a test client.""" # Default client metadata client_metadata = { @@ -94,7 +94,7 @@ async def registered_client(client): class TestRegistrationErrorHandling: @pytest.mark.anyio - async def test_registration_error_handling(self, client, oauth_provider): + async def test_registration_error_handling(self, client: httpx.AsyncClient, oauth_provider: MockOAuthProvider): # Mock the register_client method to raise a registration error with unittest.mock.patch.object( oauth_provider, @@ -128,7 +128,13 @@ async def test_registration_error_handling(self, client, oauth_provider): class TestAuthorizeErrorHandling: @pytest.mark.anyio - async def test_authorize_error_handling(self, client, oauth_provider, registered_client, pkce_challenge): + async def test_authorize_error_handling( + self, + client: httpx.AsyncClient, + oauth_provider: MockOAuthProvider, + registered_client: dict[str, Any], + pkce_challenge: dict[str, str], + ): # Mock the authorize method to raise an authorize error with unittest.mock.patch.object( oauth_provider, @@ -165,7 +171,13 @@ async def test_authorize_error_handling(self, client, oauth_provider, registered class TestTokenErrorHandling: @pytest.mark.anyio - async def test_token_error_handling_auth_code(self, client, oauth_provider, registered_client, pkce_challenge): + async def test_token_error_handling_auth_code( + self, + client: httpx.AsyncClient, + oauth_provider: MockOAuthProvider, + registered_client: dict[str, Any], + pkce_challenge: dict[str, str], + ): # Register the client and get an auth code client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] @@ -218,7 +230,13 @@ async def test_token_error_handling_auth_code(self, client, oauth_provider, regi assert data["error_description"] == "The authorization code is invalid" @pytest.mark.anyio - async def test_token_error_handling_refresh_token(self, client, oauth_provider, registered_client, pkce_challenge): + async def test_token_error_handling_refresh_token( + self, + client: httpx.AsyncClient, + oauth_provider: MockOAuthProvider, + registered_client: dict[str, Any], + pkce_challenge: dict[str, str], + ): # Register the client and get tokens client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index e4a8f3f4c..e4bb17397 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -7,6 +7,7 @@ import secrets import time import unittest.mock +from typing import Any from urllib.parse import parse_qs, urlparse import httpx @@ -22,24 +23,17 @@ RefreshToken, construct_redirect_uri, ) -from mcp.server.auth.routes import ( - ClientRegistrationOptions, - RevocationOptions, - create_auth_routes, -) -from mcp.shared.auth import ( - OAuthClientInformationFull, - OAuthToken, -) +from mcp.server.auth.routes import ClientRegistrationOptions, RevocationOptions, create_auth_routes +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken # Mock OAuth provider for testing -class MockOAuthProvider(OAuthAuthorizationServerProvider): +class MockOAuthProvider(OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]): def __init__(self): - self.clients = {} - self.auth_codes = {} # code -> {client_id, code_challenge, redirect_uri} - self.tokens = {} # token -> {client_id, scopes, expires_at} - self.refresh_tokens = {} # refresh_token -> access_token + self.clients: dict[str, OAuthClientInformationFull] = {} + self.auth_codes: dict[str, AuthorizationCode] = {} # code -> {client_id, code_challenge, redirect_uri} + self.tokens: dict[str, AccessToken] = {} # token -> {client_id, scopes, expires_at} + self.refresh_tokens: dict[str, str] = {} # refresh_token -> access_token async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: return self.clients.get(client_id) @@ -196,7 +190,7 @@ def mock_oauth_provider(): @pytest.fixture -def auth_app(mock_oauth_provider): +def auth_app(mock_oauth_provider: MockOAuthProvider): # Create auth router auth_routes = create_auth_routes( mock_oauth_provider, @@ -217,13 +211,15 @@ def auth_app(mock_oauth_provider): @pytest.fixture -async def test_client(auth_app): +async def test_client(auth_app: Starlette): async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client: yield client @pytest.fixture -async def registered_client(test_client: httpx.AsyncClient, request): +async def registered_client( + test_client: httpx.AsyncClient, request: pytest.FixtureRequest +) -> OAuthClientInformationFull: """Create and register a test client. Parameters can be customized via indirect parameterization: @@ -259,7 +255,12 @@ def pkce_challenge(): @pytest.fixture -async def auth_code(test_client, registered_client, pkce_challenge, request): +async def auth_code( + test_client: httpx.AsyncClient, + registered_client: dict[str, Any], + pkce_challenge: dict[str, str], + request: pytest.FixtureRequest, +): """Get an authorization code. Parameters can be customized via indirect parameterization: @@ -300,7 +301,13 @@ async def auth_code(test_client, registered_client, pkce_challenge, request): @pytest.fixture -async def tokens(test_client, registered_client, auth_code, pkce_challenge, request): +async def tokens( + test_client: httpx.AsyncClient, + registered_client: dict[str, Any], + auth_code: dict[str, str], + pkce_challenge: dict[str, str], + request: pytest.FixtureRequest, +): """Exchange authorization code for tokens. Parameters can be customized via indirect parameterization: @@ -373,7 +380,12 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient): assert "error_description" in error_response # Contains validation error messages @pytest.mark.anyio - async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge): + async def test_token_invalid_auth_code( + self, + test_client: httpx.AsyncClient, + registered_client: dict[str, Any], + pkce_challenge: dict[str, str], + ): """Test token endpoint error - authorization code does not exist.""" # Try to use a non-existent authorization code response = await test_client.post( @@ -398,11 +410,11 @@ async def test_token_invalid_auth_code(self, test_client, registered_client, pkc @pytest.mark.anyio async def test_token_expired_auth_code( self, - test_client, - registered_client, - auth_code, - pkce_challenge, - mock_oauth_provider, + test_client: httpx.AsyncClient, + registered_client: dict[str, Any], + auth_code: dict[str, str], + pkce_challenge: dict[str, str], + mock_oauth_provider: MockOAuthProvider, ): """Test token endpoint error - authorization code has expired.""" # Get the current time for our time mocking @@ -451,7 +463,13 @@ async def test_token_expired_auth_code( ], indirect=True, ) - async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge): + async def test_token_redirect_uri_mismatch( + self, + test_client: httpx.AsyncClient, + registered_client: dict[str, Any], + auth_code: dict[str, str], + pkce_challenge: dict[str, str], + ): """Test token endpoint error - redirect URI mismatch.""" # Try to use the code with a different redirect URI response = await test_client.post( @@ -472,7 +490,9 @@ async def test_token_redirect_uri_mismatch(self, test_client, registered_client, assert "redirect_uri did not match" in error_response["error_description"] @pytest.mark.anyio - async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code): + async def test_token_code_verifier_mismatch( + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], auth_code: dict[str, str] + ): """Test token endpoint error - PKCE code verifier mismatch.""" # Try to use the code with an incorrect code verifier response = await test_client.post( @@ -493,7 +513,7 @@ async def test_token_code_verifier_mismatch(self, test_client, registered_client assert "incorrect code_verifier" in error_response["error_description"] @pytest.mark.anyio - async def test_token_invalid_refresh_token(self, test_client, registered_client): + async def test_token_invalid_refresh_token(self, test_client: httpx.AsyncClient, registered_client: dict[str, Any]): """Test token endpoint error - refresh token does not exist.""" # Try to use a non-existent refresh token response = await test_client.post( @@ -513,11 +533,10 @@ async def test_token_invalid_refresh_token(self, test_client, registered_client) @pytest.mark.anyio async def test_token_expired_refresh_token( self, - test_client, - registered_client, - auth_code, - pkce_challenge, - mock_oauth_provider, + test_client: httpx.AsyncClient, + registered_client: dict[str, Any], + auth_code: dict[str, str], + pkce_challenge: dict[str, str], ): """Test token endpoint error - refresh token has expired.""" # Step 1: First, let's create a token and refresh token at the current time @@ -560,7 +579,13 @@ async def test_token_expired_refresh_token( assert "refresh token has expired" in error_response["error_description"] @pytest.mark.anyio - async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge): + async def test_token_invalid_scope( + self, + test_client: httpx.AsyncClient, + registered_client: dict[str, Any], + auth_code: dict[str, str], + pkce_challenge: dict[str, str], + ): """Test token endpoint error - invalid scope in refresh token request.""" # Exchange authorization code for tokens token_response = await test_client.post( @@ -664,8 +689,9 @@ async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncCli @pytest.mark.anyio async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient): """Test client registration with empty redirect_uris array.""" + redirect_uris: list[str] = [] client_metadata = { - "redirect_uris": [], # Empty array + "redirect_uris": redirect_uris, # Empty array "client_name": "Test Client", } @@ -682,12 +708,7 @@ async def test_client_registration_empty_redirect_uris(self, test_client: httpx. ) @pytest.mark.anyio - async def test_authorize_form_post( - self, - test_client: httpx.AsyncClient, - mock_oauth_provider: MockOAuthProvider, - pkce_challenge, - ): + async def test_authorize_form_post(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]): """Test the authorization endpoint using POST with form-encoded data.""" # Register a client client_metadata = { @@ -730,7 +751,7 @@ async def test_authorization_get( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, - pkce_challenge, + pkce_challenge: dict[str, str], ): """Test the full authorization flow.""" # 1. Register a client @@ -836,7 +857,7 @@ async def test_authorization_get( assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None @pytest.mark.anyio - async def test_revoke_invalid_token(self, test_client, registered_client): + async def test_revoke_invalid_token(self, test_client: httpx.AsyncClient, registered_client: dict[str, Any]): """Test revoking an invalid token.""" response = await test_client.post( "/revoke", @@ -850,7 +871,7 @@ async def test_revoke_invalid_token(self, test_client, registered_client): assert response.status_code == 200 @pytest.mark.anyio - async def test_revoke_with_malformed_token(self, test_client, registered_client): + async def test_revoke_with_malformed_token(self, test_client: httpx.AsyncClient, registered_client: dict[str, Any]): response = await test_client.post( "/revoke", data={ @@ -874,10 +895,7 @@ async def test_client_registration_disallowed_scopes(self, test_client: httpx.As "scope": "read write profile admin", # 'admin' is not in valid_scopes } - response = await test_client.post( - "/register", - json=client_metadata, - ) + response = await test_client.post("/register", json=client_metadata) assert response.status_code == 400 error_data = response.json() assert "error" in error_data @@ -895,10 +913,7 @@ async def test_client_registration_default_scopes( # No scope specified } - response = await test_client.post( - "/register", - json=client_metadata, - ) + response = await test_client.post("/register", json=client_metadata) assert response.status_code == 201 client_info = response.json() @@ -920,10 +935,7 @@ async def test_client_registration_invalid_grant_type(self, test_client: httpx.A "grant_types": ["authorization_code"], } - response = await test_client.post( - "/register", - json=client_metadata, - ) + response = await test_client.post("/register", json=client_metadata) assert response.status_code == 400 error_data = response.json() assert "error" in error_data @@ -935,7 +947,7 @@ class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint.""" @pytest.mark.anyio - async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): + async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]): """Test authorization endpoint with missing client_id. According to the OAuth2.0 spec, if client_id is missing, the server should @@ -959,7 +971,7 @@ async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, assert "client_id" in response.text.lower() @pytest.mark.anyio - async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): + async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge: dict[str, str]): """Test authorization endpoint with invalid client_id. According to the OAuth2.0 spec, if client_id is invalid, the server should @@ -984,7 +996,7 @@ async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, @pytest.mark.anyio async def test_authorize_missing_redirect_uri( - self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] ): """Test authorization endpoint with missing redirect_uri. @@ -1010,7 +1022,7 @@ async def test_authorize_missing_redirect_uri( @pytest.mark.anyio async def test_authorize_invalid_redirect_uri( - self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] ): """Test authorization endpoint with invalid redirect_uri. @@ -1050,7 +1062,7 @@ async def test_authorize_invalid_redirect_uri( indirect=True, ) async def test_authorize_missing_redirect_uri_multiple_registered( - self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] ): """Test endpoint with missing redirect_uri with multiple registered URIs. @@ -1076,7 +1088,7 @@ async def test_authorize_missing_redirect_uri_multiple_registered( @pytest.mark.anyio async def test_authorize_unsupported_response_type( - self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] ): """Test authorization endpoint with unsupported response_type. @@ -1110,7 +1122,7 @@ async def test_authorize_unsupported_response_type( @pytest.mark.anyio async def test_authorize_missing_response_type( - self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] ): """Test authorization endpoint with missing response_type. @@ -1142,7 +1154,9 @@ async def test_authorize_missing_response_type( assert query_params["state"][0] == "test_state" @pytest.mark.anyio - async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client): + async def test_authorize_missing_pkce_challenge( + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any] + ): """Test authorization endpoint with missing PKCE code_challenge. Missing PKCE parameters should result in invalid_request error. @@ -1171,7 +1185,9 @@ async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncCl assert query_params["state"][0] == "test_state" @pytest.mark.anyio - async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge): + async def test_authorize_invalid_scope( + self, test_client: httpx.AsyncClient, registered_client: dict[str, Any], pkce_challenge: dict[str, str] + ): """Test authorization endpoint with invalid scope. Invalid scope should redirect with invalid_scope error. diff --git a/tests/server/fastmcp/prompts/test_base.py b/tests/server/fastmcp/prompts/test_base.py index 5b7b50e63..4e3a98aa8 100644 --- a/tests/server/fastmcp/prompts/test_base.py +++ b/tests/server/fastmcp/prompts/test_base.py @@ -1,13 +1,9 @@ +from typing import Any + import pytest from pydantic import FileUrl -from mcp.server.fastmcp.prompts.base import ( - AssistantMessage, - Message, - Prompt, - TextContent, - UserMessage, -) +from mcp.server.fastmcp.prompts.base import AssistantMessage, Message, Prompt, TextContent, UserMessage from mcp.types import EmbeddedResource, TextResourceContents @@ -65,7 +61,7 @@ async def fn() -> AssistantMessage: @pytest.mark.anyio async def test_fn_returns_multiple_messages(self): - expected = [ + expected: list[Message] = [ UserMessage("Hello, world!"), AssistantMessage("How can I help you today?"), UserMessage("I'm looking for a restaurant in the center of town."), @@ -160,7 +156,7 @@ async def fn() -> list[Message]: async def test_fn_returns_dict_with_resource(self): """Test returning a dict with resource content.""" - async def fn() -> dict: + async def fn() -> dict[str, Any]: return { "role": "user", "content": { diff --git a/tests/server/fastmcp/prompts/test_manager.py b/tests/server/fastmcp/prompts/test_manager.py index 82b234638..3239426f9 100644 --- a/tests/server/fastmcp/prompts/test_manager.py +++ b/tests/server/fastmcp/prompts/test_manager.py @@ -17,7 +17,7 @@ def fn() -> str: assert added == prompt assert manager.get_prompt("fn") == prompt - def test_add_duplicate_prompt(self, caplog): + def test_add_duplicate_prompt(self, caplog: pytest.LogCaptureFixture): """Test adding the same prompt twice.""" def fn() -> str: @@ -30,7 +30,7 @@ def fn() -> str: assert first == second assert "Prompt already exists" in caplog.text - def test_disable_warn_on_duplicate_prompts(self, caplog): + def test_disable_warn_on_duplicate_prompts(self, caplog: pytest.LogCaptureFixture): """Test disabling warning on duplicate prompts.""" def fn() -> str: diff --git a/tests/server/fastmcp/resources/test_function_resources.py b/tests/server/fastmcp/resources/test_function_resources.py index f59436ae3..f30c6e713 100644 --- a/tests/server/fastmcp/resources/test_function_resources.py +++ b/tests/server/fastmcp/resources/test_function_resources.py @@ -60,7 +60,7 @@ def get_data() -> bytes: async def test_json_conversion(self): """Test automatic JSON conversion of non-string results.""" - def get_data() -> dict: + def get_data() -> dict[str, str]: return {"key": "value"} resource = FunctionResource( diff --git a/tests/server/fastmcp/resources/test_resource_manager.py b/tests/server/fastmcp/resources/test_resource_manager.py index 4423e5315..bab0e9ad8 100644 --- a/tests/server/fastmcp/resources/test_resource_manager.py +++ b/tests/server/fastmcp/resources/test_resource_manager.py @@ -4,12 +4,7 @@ import pytest from pydantic import AnyUrl, FileUrl -from mcp.server.fastmcp.resources import ( - FileResource, - FunctionResource, - ResourceManager, - ResourceTemplate, -) +from mcp.server.fastmcp.resources import FileResource, FunctionResource, ResourceManager, ResourceTemplate @pytest.fixture @@ -57,7 +52,7 @@ def test_add_duplicate_resource(self, temp_file: Path): assert first == second assert manager.list_resources() == [resource] - def test_warn_on_duplicate_resources(self, temp_file: Path, caplog): + def test_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): """Test warning on duplicate resources.""" manager = ResourceManager() resource = FileResource( @@ -69,7 +64,7 @@ def test_warn_on_duplicate_resources(self, temp_file: Path, caplog): manager.add_resource(resource) assert "Resource already exists" in caplog.text - def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog): + def test_disable_warn_on_duplicate_resources(self, temp_file: Path, caplog: pytest.LogCaptureFixture): """Test disabling warning on duplicate resources.""" manager = ResourceManager(warn_on_duplicate_resources=False) resource = FileResource( diff --git a/tests/server/fastmcp/resources/test_resource_template.py b/tests/server/fastmcp/resources/test_resource_template.py index f47244361..f9b91a0a1 100644 --- a/tests/server/fastmcp/resources/test_resource_template.py +++ b/tests/server/fastmcp/resources/test_resource_template.py @@ -1,4 +1,5 @@ import json +from typing import Any import pytest from pydantic import BaseModel @@ -12,7 +13,7 @@ class TestResourceTemplate: def test_template_creation(self): """Test creating a template from a function.""" - def my_func(key: str, value: int) -> dict: + def my_func(key: str, value: int) -> dict[str, Any]: return {"key": key, "value": value} template = ResourceTemplate.from_function( @@ -23,13 +24,12 @@ def my_func(key: str, value: int) -> dict: assert template.uri_template == "test://{key}/{value}" assert template.name == "test" assert template.mime_type == "text/plain" # default - test_input = {"key": "test", "value": 42} - assert template.fn(**test_input) == my_func(**test_input) + assert template.fn(key="test", value=42) == my_func(key="test", value=42) def test_template_matches(self): """Test matching URIs against a template.""" - def my_func(key: str, value: int) -> dict: + def my_func(key: str, value: int) -> dict[str, Any]: return {"key": key, "value": value} template = ResourceTemplate.from_function( @@ -50,7 +50,7 @@ def my_func(key: str, value: int) -> dict: async def test_create_resource(self): """Test creating a resource from a template.""" - def my_func(key: str, value: int) -> dict: + def my_func(key: str, value: int) -> dict[str, Any]: return {"key": key, "value": value} template = ResourceTemplate.from_function( diff --git a/tests/server/fastmcp/servers/test_file_server.py b/tests/server/fastmcp/servers/test_file_server.py index b40778ea8..df7024552 100644 --- a/tests/server/fastmcp/servers/test_file_server.py +++ b/tests/server/fastmcp/servers/test_file_server.py @@ -7,7 +7,7 @@ @pytest.fixture() -def test_dir(tmp_path_factory) -> Path: +def test_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: """Create a temporary directory with test files.""" tmp = tmp_path_factory.mktemp("test_files") diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 20937d91d..f77e80e45 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -2,12 +2,17 @@ Test the elicitation feature using stdio transport. """ +from typing import Any + import pytest from pydantic import BaseModel, Field +from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext from mcp.shared.memory import create_connected_server_and_client_session -from mcp.types import ElicitResult, TextContent +from mcp.types import ElicitRequestParams, ElicitResult, TextContent # Shared schema for basic tests @@ -19,11 +24,8 @@ def create_ask_user_tool(mcp: FastMCP): """Create a standard ask_user tool that handles all elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context) -> str: - result = await ctx.elicit( - message=f"Tool wants to ask: {prompt}", - schema=AnswerSchema, - ) + async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str: + result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) if result.action == "accept" and result.data: return f"User answered: {result.data.answer}" @@ -37,9 +39,9 @@ async def ask_user(prompt: str, ctx: Context) -> str: async def call_tool_and_assert( mcp: FastMCP, - elicitation_callback, + elicitation_callback: ElicitationFnT, tool_name: str, - args: dict, + args: dict[str, Any], expected_text: str | None = None, text_contains: list[str] | None = None, ): @@ -69,7 +71,7 @@ async def test_stdio_elicitation(): create_ask_user_tool(mcp) # Create a custom handler for elicitation requests - async def elicitation_callback(context, params): + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) else: @@ -86,7 +88,7 @@ async def test_stdio_elicitation_decline(): mcp = FastMCP(name="StdioElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context, params): + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): return ElicitResult(action="decline") await call_tool_and_assert( @@ -101,7 +103,7 @@ async def test_elicitation_schema_validation(): def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") - async def tool(ctx: Context) -> str: + async def tool(ctx: Context[ServerSession, None]) -> str: try: await ctx.elicit(message="This should fail validation", schema=schema_class) return "Should not reach here" @@ -124,7 +126,7 @@ class InvalidNestedSchema(BaseModel): create_validation_tool("nested_model", InvalidNestedSchema) # Dummy callback (won't be called due to validation failure) - async def elicitation_callback(context, params): + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): return ElicitResult(action="accept", content={}) async with create_connected_server_and_client_session( @@ -153,7 +155,7 @@ class OptionalSchema(BaseModel): subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") @mcp.tool(description="Tool with optional fields") - async def optional_tool(ctx: Context) -> str: + async def optional_tool(ctx: Context[ServerSession, None]) -> str: result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) if result.action == "accept" and result.data: @@ -168,7 +170,7 @@ async def optional_tool(ctx: Context) -> str: return f"User {result.action}" # Test cases with different field combinations - test_cases = [ + test_cases: list[tuple[dict[str, Any], str]] = [ ( # All fields provided {"required_name": "John Doe", "optional_age": 30, "optional_email": "john@example.com", "subscribe": True}, @@ -183,7 +185,7 @@ async def optional_tool(ctx: Context) -> str: for content, expected in test_cases: - async def callback(context, params): + async def callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -194,16 +196,19 @@ class InvalidOptionalSchema(BaseModel): optional_list: list[str] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") - async def invalid_optional_tool(ctx: Context) -> str: + async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: try: await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) return "Should not reach here" except TypeError as e: return f"Validation failed: {str(e)}" + async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + return ElicitResult(action="accept", content={}) + await call_tool_and_assert( mcp, - lambda c, p: ElicitResult(action="accept", content={}), + elicitation_callback, "invalid_optional_tool", {}, text_contains=["Validation failed:", "optional_list"], diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py index 7027443da..830cf816b 100644 --- a/tests/server/fastmcp/test_func_metadata.py +++ b/tests/server/fastmcp/test_func_metadata.py @@ -1,3 +1,9 @@ +# NOTE: Those were added because we actually want to test wrong type annotations. +# pyright: reportUnknownParameterType=false +# pyright: reportMissingParameterType=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownLambdaType=false +from collections.abc import Callable from dataclasses import dataclass from typing import Annotated, Any, TypedDict @@ -58,7 +64,7 @@ def complex_arguments_fn( an_int_with_equals_field: int = Field(1, ge=0), int_annotated_with_default: Annotated[int, Field(description="hey")] = 5, ) -> str: - _ = ( + _: Any = ( an_int, must_be_none, must_be_none_dumb_annotation, @@ -240,7 +246,7 @@ def func_dict_int_key() -> dict[int, str]: @pytest.mark.anyio async def test_lambda_function(): """Test lambda function schema and validation""" - fn = lambda x, y=5: x # noqa: E731 + fn: Callable[[str, int], str] = lambda x, y=5: x # noqa: E731 meta = func_metadata(lambda x, y=5: x) # Test schema @@ -899,7 +905,7 @@ def test_structured_output_unserializable_type_error(): class ConfigWithCallable: name: str # Callable defaults are not JSON serializable and will trigger Pydantic warnings - callback: Any = lambda x: x * 2 + callback: Callable[[Any], Any] = lambda x: x * 2 def func_returning_config_with_callable() -> ConfigWithCallable: return ConfigWithCallable() @@ -955,7 +961,7 @@ def func_with_aliases() -> ModelWithAliases: # Check that the actual output uses aliases too result = ModelWithAliases(**{"first": "hello", "second": "world"}) - unstructured_content, structured_content = meta.convert_result(result) + _, structured_content = meta.convert_result(result) # The structured content should use aliases to match the schema assert "first" in structured_content @@ -967,7 +973,7 @@ def func_with_aliases() -> ModelWithAliases: # Also test the case where we have a model with defaults to ensure aliases work in all cases result_with_defaults = ModelWithAliases() # Uses default None values - unstructured_content_defaults, structured_content_defaults = meta.convert_result(result_with_defaults) + _, structured_content_defaults = meta.convert_result(result_with_defaults) # Even with defaults, should use aliases in output assert "first" in structured_content_defaults diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 377e4923b..83fa10806 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -4,6 +4,11 @@ These tests validate the proper functioning of FastMCP features using focused, single-feature servers across different transports (SSE and StreamableHTTP). """ +# TODO(Marcelo): The `examples` package is not being imported as package. We need to solve this. +# pyright: reportUnknownMemberType=false +# pyright: reportMissingImports=false +# pyright: reportUnknownVariableType=false +# pyright: reportUnknownArgumentType=false import json import multiprocessing @@ -13,6 +18,7 @@ import pytest import uvicorn +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl from examples.snippets.servers import ( @@ -29,17 +35,27 @@ ) from mcp.client.session import ClientSession from mcp.client.sse import sse_client -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder from mcp.types import ( + ClientResult, + CreateMessageRequestParams, CreateMessageResult, + ElicitRequestParams, ElicitResult, GetPromptResult, InitializeResult, LoggingMessageNotification, + LoggingMessageNotificationParams, + NotificationParams, ProgressNotification, + ProgressNotificationParams, ReadResourceResult, ResourceListChangedNotification, ServerNotification, + ServerRequest, TextContent, TextResourceContents, ToolListChangedNotification, @@ -50,12 +66,14 @@ class NotificationCollector: """Collects notifications from the server for testing.""" def __init__(self): - self.progress_notifications: list = [] - self.log_messages: list = [] - self.resource_notifications: list = [] - self.tool_notifications: list = [] - - async def handle_generic_notification(self, message) -> None: + self.progress_notifications: list[ProgressNotificationParams] = [] + self.log_messages: list[LoggingMessageNotificationParams] = [] + self.resource_notifications: list[NotificationParams | None] = [] + self.tool_notifications: list[NotificationParams | None] = [] + + async def handle_generic_notification( + self, message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception + ) -> None: """Handle any server notification and route to appropriate handler.""" if isinstance(message, ServerNotification): if isinstance(message.root, ProgressNotification): @@ -123,7 +141,7 @@ def run_server_with_transport(module_name: str, port: int, transport: str) -> No @pytest.fixture -def server_transport(request, server_port: int) -> Generator[str, None, None]: +def server_transport(request: pytest.FixtureRequest, server_port: int) -> Generator[str, None, None]: """Start server in a separate process with specified MCP instance and transport. Args: @@ -177,7 +195,14 @@ def create_client_for_transport(transport: str, server_url: str): raise ValueError(f"Invalid transport: {transport}") -def unpack_streams(client_streams): +def unpack_streams( + client_streams: tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] + | tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback, + ], +): """Unpack client streams handling different return values from SSE vs StreamableHTTP. SSE client returns (read_stream, write_stream) @@ -197,7 +222,9 @@ def unpack_streams(client_streams): # Callback functions for testing -async def sampling_callback(context, params) -> CreateMessageResult: +async def sampling_callback( + context: RequestContext[ClientSession, None], params: CreateMessageRequestParams +) -> CreateMessageResult: """Sampling callback for tests.""" return CreateMessageResult( role="assistant", @@ -209,7 +236,7 @@ async def sampling_callback(context, params) -> CreateMessageResult: ) -async def elicitation_callback(context, params): +async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: @@ -367,7 +394,7 @@ async def test_tool_progress(server_transport: str, server_url: str) -> None: transport = server_transport collector = NotificationCollector() - async def message_handler(message): + async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): await collector.handle_generic_notification(message) if isinstance(message, Exception): raise message @@ -508,7 +535,7 @@ async def test_notifications(server_transport: str, server_url: str) -> None: transport = server_transport collector = NotificationCollector() - async def message_handler(message): + async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): await collector.handle_generic_notification(message) if isinstance(message, Exception): raise message diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index a9e0d182a..a4e72d1e9 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -11,6 +11,7 @@ from mcp.server.fastmcp.prompts.base import Message, UserMessage from mcp.server.fastmcp.resources import FileResource, FunctionResource from mcp.server.fastmcp.utilities.types import Image +from mcp.server.session import ServerSession from mcp.shared.exceptions import McpError from mcp.shared.memory import ( create_connected_server_and_client_session as client_session, @@ -338,8 +339,10 @@ async def test_tool_mixed_list_with_image(self, tmp_path: Path): image_path = tmp_path / "test.png" image_path.write_bytes(b"test image data") - def mixed_list_fn() -> list: - return [ + # TODO(Marcelo): It seems if we add the proper type hint, it generates an invalid JSON schema. + # We need to fix this. + def mixed_list_fn() -> list: # type: ignore + return [ # type: ignore "text message", Image(image_path), {"key": "value"}, @@ -347,7 +350,7 @@ def mixed_list_fn() -> list: ] mcp = FastMCP() - mcp.add_tool(mixed_list_fn) + mcp.add_tool(mixed_list_fn) # type: ignore async with client_session(mcp._mcp_server) as client: result = await client.call_tool("mixed_list_fn", {}) assert len(result.content) == 4 @@ -655,7 +658,7 @@ async def test_resource_with_untyped_params(self): mcp = FastMCP() @mcp.resource("resource://{param}") - def get_data(param) -> str: + def get_data(param) -> str: # type: ignore return "Data" @pytest.mark.anyio @@ -748,7 +751,7 @@ async def test_context_detection(self): """Test that context parameters are properly detected.""" mcp = FastMCP() - def tool_with_context(x: int, ctx: Context) -> str: + def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: return f"Request {ctx.request_id}: {x}" tool = mcp._tool_manager.add_tool(tool_with_context) @@ -759,7 +762,7 @@ async def test_context_injection(self): """Test that context is properly injected into tool calls.""" mcp = FastMCP() - def tool_with_context(x: int, ctx: Context) -> str: + def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: assert ctx.request_id is not None return f"Request {ctx.request_id}: {x}" @@ -777,7 +780,7 @@ async def test_async_context(self): """Test that context works in async functions.""" mcp = FastMCP() - async def async_tool(x: int, ctx: Context) -> str: + async def async_tool(x: int, ctx: Context[ServerSession, None]) -> str: assert ctx.request_id is not None return f"Async request {ctx.request_id}: {x}" @@ -792,12 +795,10 @@ async def async_tool(x: int, ctx: Context) -> str: @pytest.mark.anyio async def test_context_logging(self): - import mcp.server.session - """Test that context logging methods work.""" mcp = FastMCP() - async def logging_tool(msg: str, ctx: Context) -> str: + async def logging_tool(msg: str, ctx: Context[ServerSession, None]) -> str: await ctx.debug("Debug message") await ctx.info("Info message") await ctx.warning("Warning message") @@ -866,7 +867,7 @@ def test_resource() -> str: return "resource data" @mcp.tool() - async def tool_with_resource(ctx: Context) -> str: + async def tool_with_resource(ctx: Context[ServerSession, None]) -> str: r_iter = await ctx.read_resource("test://data") r_list = list(r_iter) assert len(r_list) == 1 diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 27e16cc8e..8b6168275 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -34,7 +34,7 @@ def sum(a: int, b: int) -> int: assert tool.parameters["properties"]["a"]["type"] == "integer" assert tool.parameters["properties"]["b"]["type"] == "integer" - def test_init_with_tools(self, caplog): + def test_init_with_tools(self, caplog: pytest.LogCaptureFixture): def sum(a: int, b: int) -> int: return a + b @@ -89,7 +89,7 @@ class UserInput(BaseModel): name: str age: int - def create_user(user: UserInput, flag: bool) -> dict: + def create_user(user: UserInput, flag: bool) -> dict[str, Any]: """Create a new user.""" return {"id": 1, **user.model_dump()} @@ -145,15 +145,15 @@ def test_add_invalid_tool(self): def test_add_lambda(self): manager = ToolManager() - tool = manager.add_tool(lambda x: x, name="my_tool") + tool = manager.add_tool(lambda x: x, name="my_tool") # type: ignore[reportUnknownLambdaType] assert tool.name == "my_tool" def test_add_lambda_with_no_name(self): manager = ToolManager() with pytest.raises(ValueError, match="You must provide a name for lambda functions"): - manager.add_tool(lambda x: x) + manager.add_tool(lambda x: x) # type: ignore[reportUnknownLambdaType] - def test_warn_on_duplicate_tools(self, caplog): + def test_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): """Test warning on duplicate tools.""" def f(x: int) -> int: @@ -165,7 +165,7 @@ def f(x: int) -> int: manager.add_tool(f) assert "Tool already exists: f" in caplog.text - def test_disable_warn_on_duplicate_tools(self, caplog): + def test_disable_warn_on_duplicate_tools(self, caplog: pytest.LogCaptureFixture): """Test disabling warning on duplicate tools.""" def f(x: int) -> int: @@ -297,7 +297,7 @@ class Shrimp(BaseModel): shrimp: list[Shrimp] x: None - def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: + def name_shrimp(tank: MyShrimpTank, ctx: Context[ServerSessionT, None]) -> list[str]: return [x.name for x in tank.shrimp] manager = ToolManager() @@ -317,7 +317,7 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: class TestToolSchema: @pytest.mark.anyio async def test_context_arg_excluded_from_schema(self): - def something(a: int, ctx: Context) -> int: + def something(a: int, ctx: Context[ServerSessionT, None]) -> int: return a manager = ToolManager() @@ -334,7 +334,7 @@ def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" - def tool_with_context(x: int, ctx: Context) -> str: + def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: return str(x) manager = ToolManager() @@ -357,7 +357,7 @@ def tool_with_parametrized_context(x: int, ctx: Context[ServerSessionT, Lifespan async def test_context_injection(self): """Test that context is properly injected during tool execution.""" - def tool_with_context(x: int, ctx: Context) -> str: + def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: assert isinstance(ctx, Context) return str(x) @@ -373,7 +373,7 @@ def tool_with_context(x: int, ctx: Context) -> str: async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" - async def async_tool(x: int, ctx: Context) -> str: + async def async_tool(x: int, ctx: Context[ServerSessionT, None]) -> str: assert isinstance(ctx, Context) return str(x) @@ -389,7 +389,7 @@ async def async_tool(x: int, ctx: Context) -> str: async def test_context_optional(self): """Test that context is optional when calling tools.""" - def tool_with_context(x: int, ctx: Context | None = None) -> str: + def tool_with_context(x: int, ctx: Context[ServerSessionT, None] | None = None) -> str: return str(x) manager = ToolManager() @@ -402,7 +402,7 @@ def tool_with_context(x: int, ctx: Context | None = None) -> str: async def test_context_error_handling(self): """Test error handling when context injection fails.""" - def tool_with_context(x: int, ctx: Context) -> str: + def tool_with_context(x: int, ctx: Context[ServerSessionT, None]) -> str: raise ValueError("Test error") manager = ToolManager() @@ -552,7 +552,7 @@ def get_numbers() -> list[int]: async def test_tool_without_structured_output(self): """Test that tools work normally when structured_output=False.""" - def get_dict() -> dict: + def get_dict() -> dict[str, Any]: """Get a dict.""" return {"key": "value"} diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 44b9a924d..e7149826b 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -1,5 +1,7 @@ """Test that cancelled requests don't cause double responses.""" +from typing import Any + import anyio import pytest @@ -41,7 +43,7 @@ async def handle_list_tools() -> list[Tool]: ] @server.call_tool() - async def handle_call_tool(name: str, arguments: dict | None) -> list: + async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: nonlocal call_count, first_request_id if name == "test_tool": call_count += 1 diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index f0d154587..f0864667d 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -2,6 +2,8 @@ Tests for completion handler with context functionality. """ +from typing import Any + import pytest from mcp.server.lowlevel import Server @@ -21,7 +23,7 @@ async def test_completion_handler_receives_context(): server = Server("test-server") # Track what the handler receives - received_args = {} + received_args: dict[str, Any] = {} @server.completion() async def handle_completion( diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index a3ff59bc1..9d73fd47a 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -2,6 +2,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from typing import Any import anyio import pytest @@ -10,6 +11,7 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.types import ( ClientCapabilities, @@ -18,6 +20,8 @@ JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + JSONRPCResponse, + TextContent, ) @@ -35,29 +39,23 @@ async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]: finally: context["shutdown"] = True - server = Server("test", lifespan=test_lifespan) + server = Server[dict[str, bool]]("test", lifespan=test_lifespan) # Create memory streams for testing - send_stream1, receive_stream1 = anyio.create_memory_object_stream(100) - send_stream2, receive_stream2 = anyio.create_memory_object_stream(100) + send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) # Create a tool that accesses lifespan context @server.call_tool() - async def check_lifespan(name: str, arguments: dict) -> list: + async def check_lifespan(name: str, arguments: dict[str, Any]) -> list[TextContent]: ctx = server.request_context assert isinstance(ctx.lifespan_context, dict) assert ctx.lifespan_context["started"] assert not ctx.lifespan_context["shutdown"] - return [{"type": "text", "text": "true"}] + return [TextContent(type="text", text="true")] # Run server in background task - async with ( - anyio.create_task_group() as tg, - send_stream1, - receive_stream1, - send_stream2, - receive_stream2, - ): + async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: async def run_server(): await server.run( @@ -126,6 +124,8 @@ async def run_server(): # Get response and verify response = await receive_stream2.receive() response = response.message + assert isinstance(response, JSONRPCMessage) + assert isinstance(response.root, JSONRPCResponse) assert response.root.result["content"][0]["text"] == "true" # Cancel server task @@ -137,7 +137,7 @@ async def test_fastmcp_server_lifespan(): """Test that lifespan works in FastMCP server.""" @asynccontextmanager - async def test_lifespan(server: FastMCP) -> AsyncIterator[dict]: + async def test_lifespan(server: FastMCP) -> AsyncIterator[dict[str, bool]]: """Test lifespan context that tracks startup/shutdown.""" context = {"started": False, "shutdown": False} try: @@ -149,12 +149,12 @@ async def test_lifespan(server: FastMCP) -> AsyncIterator[dict]: server = FastMCP("test", lifespan=test_lifespan) # Create memory streams for testing - send_stream1, receive_stream1 = anyio.create_memory_object_stream(100) - send_stream2, receive_stream2 = anyio.create_memory_object_stream(100) + send_stream1, receive_stream1 = anyio.create_memory_object_stream[SessionMessage](100) + send_stream2, receive_stream2 = anyio.create_memory_object_stream[SessionMessage](100) # Add a tool that checks lifespan context @server.tool() - def check_lifespan(ctx: Context) -> bool: + def check_lifespan(ctx: Context[ServerSession, None]) -> bool: """Tool that checks lifespan context.""" assert isinstance(ctx.request_context.lifespan_context, dict) assert ctx.request_context.lifespan_context["started"] @@ -230,6 +230,8 @@ async def run_server(): # Get response and verify response = await receive_stream2.receive() response = response.message + assert isinstance(response, JSONRPCMessage) + assert isinstance(response.root, JSONRPCResponse) assert response.root.result["content"][0]["text"] == "true" # Cancel server task diff --git a/tests/server/test_lowlevel_input_validation.py b/tests/server/test_lowlevel_input_validation.py index 250159733..3f5fba3da 100644 --- a/tests/server/test_lowlevel_input_validation.py +++ b/tests/server/test_lowlevel_input_validation.py @@ -263,7 +263,7 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio -async def test_tool_not_in_list_logs_warning(caplog): +async def test_tool_not_in_list_logs_warning(caplog: pytest.LogCaptureFixture): """Test that calling a tool not in list_tools logs a warning and skips validation.""" tools = [ Tool( diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 154c3a368..89e807b29 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -1,3 +1,5 @@ +from typing import Any + import anyio import pytest @@ -16,8 +18,10 @@ CompletionContext, CompletionsCapability, InitializedNotification, + Prompt, PromptReference, PromptsCapability, + Resource, ResourcesCapability, ResourceTemplateReference, ServerCapabilities, @@ -80,7 +84,7 @@ async def run_server(): async def test_server_capabilities(): server = Server("test") notification_options = NotificationOptions() - experimental_capabilities = {} + experimental_capabilities: dict[str, Any] = {} # Initially no capabilities caps = server.get_capabilities(notification_options, experimental_capabilities) @@ -90,7 +94,7 @@ async def test_server_capabilities(): # Add a prompts handler @server.list_prompts() - async def list_prompts(): + async def list_prompts() -> list[Prompt]: return [] caps = server.get_capabilities(notification_options, experimental_capabilities) @@ -100,7 +104,7 @@ async def list_prompts(): # Add a resources handler @server.list_resources() - async def list_resources(): + async def list_resources() -> list[Resource]: return [] caps = server.get_capabilities(notification_options, experimental_capabilities) diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 2d1850b73..a1d1792f8 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -26,7 +26,7 @@ async def test_stdio_server(): read_stream, write_stream, ): - received_messages = [] + received_messages: list[JSONRPCMessage] = [] async with read_stream: async for message in read_stream: if isinstance(message, Exception): diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 9a4c695b8..7a8551e5c 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,13 +1,15 @@ """Tests for StreamableHTTPSessionManager.""" +from typing import Any from unittest.mock import AsyncMock, patch import anyio import pytest +from starlette.types import Message from mcp.server import streamable_http_manager from mcp.server.lowlevel import Server -from mcp.server.streamable_http import MCP_SESSION_ID_HEADER +from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager @@ -35,7 +37,7 @@ async def test_run_prevents_concurrent_calls(): app = Server("test-server") manager = StreamableHTTPSessionManager(app=app) - errors = [] + errors: list[Exception] = [] async def try_run(): try: @@ -67,7 +69,7 @@ async def test_handle_request_without_run_raises_error(): async def receive(): return {"type": "http.request", "body": b""} - async def send(message): + async def send(message: Message): pass # Should raise error because run() hasn't been called @@ -93,16 +95,16 @@ async def running_manager(): @pytest.mark.anyio -async def test_stateful_session_cleanup_on_graceful_exit(running_manager): +async def test_stateful_session_cleanup_on_graceful_exit(running_manager: tuple[StreamableHTTPSessionManager, Server]): manager, app = running_manager mock_mcp_run = AsyncMock(return_value=None) # This will be called by StreamableHTTPSessionManager's run_server -> self.app.run app.run = mock_mcp_run - sent_messages = [] + sent_messages: list[Message] = [] - async def mock_send(message): + async def mock_send(message: Message): sent_messages.append(message) scope = { @@ -148,15 +150,15 @@ async def mock_receive(): @pytest.mark.anyio -async def test_stateful_session_cleanup_on_exception(running_manager): +async def test_stateful_session_cleanup_on_exception(running_manager: tuple[StreamableHTTPSessionManager, Server]): manager, app = running_manager mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash")) app.run = mock_mcp_run - sent_messages = [] + sent_messages: list[Message] = [] - async def mock_send(message): + async def mock_send(message: Message): sent_messages.append(message) # If an exception occurs, the transport might try to send an error response # For this test, we mostly care that the session is established enough @@ -207,13 +209,13 @@ async def test_stateless_requests_memory_cleanup(): manager = StreamableHTTPSessionManager(app=app, stateless=True) # Track created transport instances - created_transports = [] + created_transports: list[StreamableHTTPServerTransport] = [] # Patch StreamableHTTPServerTransport constructor to track instances original_constructor = streamable_http_manager.StreamableHTTPServerTransport - def track_transport(*args, **kwargs): + def track_transport(*args: Any, **kwargs: Any) -> StreamableHTTPServerTransport: transport = original_constructor(*args, **kwargs) created_transports.append(transport) return transport @@ -224,9 +226,9 @@ def track_transport(*args, **kwargs): app.run = AsyncMock(return_value=None) # Send a simple request - sent_messages = [] + sent_messages: list[Message] = [] - async def mock_send(message): + async def mock_send(message: Message): sent_messages.append(message) scope = { diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index a0c32f556..16bd6cb93 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -4,13 +4,8 @@ from mcp.client.session import ClientSession from mcp.server import Server -from mcp.shared.memory import ( - create_connected_server_and_client_session, -) -from mcp.types import ( - EmptyResult, - Resource, -) +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import EmptyResult, Resource @pytest.fixture diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 93cc712b4..1e13031e6 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -11,11 +11,7 @@ from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.progress import progress -from mcp.shared.session import ( - BaseSession, - RequestResponder, - SessionMessage, -) +from mcp.shared.session import BaseSession, RequestResponder, SessionMessage @pytest.mark.anyio @@ -47,8 +43,8 @@ async def run_server(): raise e # Track progress updates - server_progress_updates = [] - client_progress_updates = [] + server_progress_updates: list[dict[str, Any]] = [] + client_progress_updates: list[dict[str, Any]] = [] # Progress tokens server_progress_token = "server_token_123" @@ -87,7 +83,7 @@ async def handle_list_tools() -> list[types.Tool]: # Register tool handler @server.call_tool() - async def handle_call_tool(name: str, arguments: dict | None) -> list: + async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: # Make sure we received a progress token if name == "test_tool": if arguments and "_meta" in arguments: @@ -124,7 +120,7 @@ async def handle_call_tool(name: str, arguments: dict | None) -> list: else: raise ValueError("Progress token not sent.") - return ["Tool executed successfully"] + return [types.TextContent(type="text", text="Tool executed successfully")] raise ValueError(f"Unknown tool: {name}") @@ -217,7 +213,7 @@ async def test_progress_context_manager(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) # Track progress updates - server_progress_updates = [] + server_progress_updates: list[dict[str, Any]] = [] server = Server(name="ProgressContextTestServer") @@ -230,12 +226,7 @@ async def handle_progress( message: str | None, ): server_progress_updates.append( - { - "token": progress_token, - "progress": progress, - "total": total, - "message": message, - } + {"token": progress_token, "progress": progress, "total": total, "message": message} ) # Run server session to receive progress updates @@ -288,13 +279,7 @@ async def handle_client_message( ) # cast for type checker - typed_context = cast( - RequestContext[ - BaseSession[Any, Any, Any, Any, Any], - Any, - ], - request_context, - ) + typed_context = cast(RequestContext[BaseSession[Any, Any, Any, Any, Any], Any], request_context) # Utilize progress context manager with progress(typed_context, total=100) as p: diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 864e0d1b4..c2c023c71 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from typing import Any import anyio import pytest @@ -7,16 +8,14 @@ from mcp.client.session import ClientSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError -from mcp.shared.memory import ( - create_client_server_memory_streams, - create_connected_server_and_client_session, -) +from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session from mcp.types import ( CancelledNotification, CancelledNotificationParams, ClientNotification, ClientRequest, EmptyResult, + TextContent, ) @@ -61,7 +60,7 @@ def make_server() -> Server: # Register the tool handler @server.call_tool() - async def handle_call_tool(name: str, arguments: dict | None) -> list: + async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[TextContent]: nonlocal request_id, ev_tool_called if name == "slow_tool": request_id = server.request_context.request_id @@ -83,7 +82,7 @@ async def handle_list_tools() -> list[types.Tool]: return server - async def make_request(client_session): + async def make_request(client_session: ClientSession): nonlocal ev_cancelled try: await client_session.send_request( @@ -134,14 +133,11 @@ async def test_connection_closed(): ev_closed = anyio.Event() ev_response = anyio.Event() - async with create_client_server_memory_streams() as ( - client_streams, - server_streams, - ): + async with create_client_server_memory_streams() as (client_streams, server_streams): client_read, client_write = client_streams server_read, server_write = server_streams - async def make_request(client_session): + async def make_request(client_session: ClientSession): """Send a request in a separate task""" nonlocal ev_response try: @@ -165,10 +161,7 @@ async def mock_server(): async with ( anyio.create_task_group() as tg, - ClientSession( - read_stream=client_read, - write_stream=client_write, - ) as client_session, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, ): tg.start_soon(make_request, client_session) tg.start_soon(mock_server) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 39ae13524..7b0d89cb4 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -3,6 +3,7 @@ import socket import time from collections.abc import AsyncGenerator, Generator +from typing import Any import anyio import httpx @@ -74,7 +75,7 @@ async def handle_list_tools() -> list[Tool]: ] @self.call_tool() - async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] @@ -147,7 +148,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() -async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, None]: +async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: """Create test client""" async with httpx.AsyncClient(base_url=server_url) as client: yield client @@ -194,7 +195,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.fixture -async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() @@ -305,7 +306,7 @@ def __init__(self): super().__init__("request_context_server") @self.call_tool() - async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: headers_info = {} context = self.request_context if context.request: @@ -435,7 +436,7 @@ async def test_request_context_propagation(context_server: None, server_url: str @pytest.mark.anyio async def test_request_context_isolation(context_server: None, server_url: str) -> None: """Test that request contexts are isolated between different SSE clients.""" - contexts = [] + contexts: list[dict[str, Any]] = [] # Create multiple clients with different headers for i in range(3): @@ -501,7 +502,7 @@ def test_sse_message_id_coercion(): ) def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result: str | type[Exception]): """Test that SseServerTransport properly validates and normalizes endpoints.""" - if isinstance(expected_result, type) and issubclass(expected_result, Exception): + if isinstance(expected_result, type): # Test invalid endpoints that should raise an exception with pytest.raises(expected_result, match="is not a relative path.*expecting a relative path"): SseServerTransport(endpoint) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3fea54f0b..ecbe6eb08 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -40,16 +40,9 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError -from mcp.shared.message import ( - ClientMessageMetadata, -) +from mcp.shared.message import ClientMessageMetadata from mcp.shared.session import RequestResponder -from mcp.types import ( - InitializeResult, - TextContent, - TextResourceContents, - Tool, -) +from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool # Test constants SERVER_NAME = "test_streamable_http_server" @@ -173,7 +166,7 @@ async def handle_list_tools() -> list[Tool]: ] @self.call_tool() - async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: ctx = self.request_context # When the tool is called, send a notification to test GET stream @@ -261,7 +254,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] -def create_app(is_json_response_enabled=False, event_store: EventStore | None = None) -> Starlette: +def create_app(is_json_response_enabled: bool = False, event_store: EventStore | None = None) -> Starlette: """Create a Starlette application for testing using the session manager. Args: @@ -294,7 +287,7 @@ def create_app(is_json_response_enabled=False, event_store: EventStore | None = return app -def run_server(port: int, is_json_response_enabled=False, event_store: EventStore | None = None) -> None: +def run_server(port: int, is_json_response_enabled: bool = False, event_store: EventStore | None = None) -> None: """Run the test server. Args: @@ -462,7 +455,7 @@ def json_server_url(json_server_port: int) -> str: # Basic request validation tests -def test_accept_header_validation(basic_server, basic_server_url): +def test_accept_header_validation(basic_server: None, basic_server_url: str): """Test that Accept header is properly validated.""" # Test without Accept header response = requests.post( @@ -474,7 +467,7 @@ def test_accept_header_validation(basic_server, basic_server_url): assert "Not Acceptable" in response.text -def test_content_type_validation(basic_server, basic_server_url): +def test_content_type_validation(basic_server: None, basic_server_url: str): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type response = requests.post( @@ -490,7 +483,7 @@ def test_content_type_validation(basic_server, basic_server_url): assert "Invalid Content-Type" in response.text -def test_json_validation(basic_server, basic_server_url): +def test_json_validation(basic_server: None, basic_server_url: str): """Test that JSON content is properly validated.""" # Test with invalid JSON response = requests.post( @@ -505,7 +498,7 @@ def test_json_validation(basic_server, basic_server_url): assert "Parse error" in response.text -def test_json_parsing(basic_server, basic_server_url): +def test_json_parsing(basic_server: None, basic_server_url: str): """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( @@ -520,7 +513,7 @@ def test_json_parsing(basic_server, basic_server_url): assert "Validation error" in response.text -def test_method_not_allowed(basic_server, basic_server_url): +def test_method_not_allowed(basic_server: None, basic_server_url: str): """Test that unsupported HTTP methods are rejected.""" # Test with unsupported method (PUT) response = requests.put( @@ -535,7 +528,7 @@ def test_method_not_allowed(basic_server, basic_server_url): assert "Method Not Allowed" in response.text -def test_session_validation(basic_server, basic_server_url): +def test_session_validation(basic_server: None, basic_server_url: str): """Test session ID validation.""" # session_id not used directly in this test @@ -610,7 +603,7 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server, basic_server_url): +def test_session_termination(basic_server: None, basic_server_url: str): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( f"{basic_server_url}/mcp", @@ -650,7 +643,7 @@ def test_session_termination(basic_server, basic_server_url): assert "Session has been terminated" in response.text -def test_response(basic_server, basic_server_url): +def test_response(basic_server: None, basic_server_url: str): """Test response handling for a valid request.""" mcp_url = f"{basic_server_url}/mcp" response = requests.post( @@ -685,7 +678,7 @@ def test_response(basic_server, basic_server_url): assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response(json_response_server, json_server_url): +def test_json_response(json_response_server: None, json_server_url: str): """Test response handling when is_json_response_enabled is True.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -700,7 +693,7 @@ def test_json_response(json_response_server, json_server_url): assert response.headers.get("Content-Type") == "application/json" -def test_get_sse_stream(basic_server, basic_server_url): +def test_get_sse_stream(basic_server: None, basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -760,7 +753,7 @@ def test_get_sse_stream(basic_server, basic_server_url): assert second_get.status_code == 409 -def test_get_validation(basic_server, basic_server_url): +def test_get_validation(basic_server: None, basic_server_url: str): """Test validation for GET requests.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -815,14 +808,14 @@ def test_get_validation(basic_server, basic_server_url): # Client-specific fixtures @pytest.fixture -async def http_client(basic_server, basic_server_url): +async def http_client(basic_server: None, basic_server_url: str): """Create test client matching the SSE test pattern.""" async with httpx.AsyncClient(base_url=basic_server_url) as client: yield client @pytest.fixture -async def initialized_client_session(basic_server, basic_server_url): +async def initialized_client_session(basic_server: None, basic_server_url: str): """Create initialized StreamableHTTP client session.""" async with streamablehttp_client(f"{basic_server_url}/mcp") as ( read_stream, @@ -838,7 +831,7 @@ async def initialized_client_session(basic_server, basic_server_url): @pytest.mark.anyio -async def test_streamablehttp_client_basic_connection(basic_server, basic_server_url): +async def test_streamablehttp_client_basic_connection(basic_server: None, basic_server_url: str): """Test basic client connection with initialization.""" async with streamablehttp_client(f"{basic_server_url}/mcp") as ( read_stream, @@ -856,16 +849,17 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server @pytest.mark.anyio -async def test_streamablehttp_client_resource_read(initialized_client_session): +async def test_streamablehttp_client_resource_read(initialized_client_session: ClientSession): """Test client resource read functionality.""" response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource")) assert len(response.contents) == 1 assert response.contents[0].uri == AnyUrl("foobar://test-resource") + assert isinstance(response.contents[0], TextResourceContents) assert response.contents[0].text == "Read test-resource" @pytest.mark.anyio -async def test_streamablehttp_client_tool_invocation(initialized_client_session): +async def test_streamablehttp_client_tool_invocation(initialized_client_session: ClientSession): """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() @@ -880,7 +874,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session) @pytest.mark.anyio -async def test_streamablehttp_client_error_handling(initialized_client_session): +async def test_streamablehttp_client_error_handling(initialized_client_session: ClientSession): """Test error handling in client.""" with pytest.raises(McpError) as exc_info: await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error")) @@ -889,7 +883,7 @@ async def test_streamablehttp_client_error_handling(initialized_client_session): @pytest.mark.anyio -async def test_streamablehttp_client_session_persistence(basic_server, basic_server_url): +async def test_streamablehttp_client_session_persistence(basic_server: None, basic_server_url: str): """Test that session ID persists across requests.""" async with streamablehttp_client(f"{basic_server_url}/mcp") as ( read_stream, @@ -917,7 +911,7 @@ async def test_streamablehttp_client_session_persistence(basic_server, basic_ser @pytest.mark.anyio -async def test_streamablehttp_client_json_response(json_response_server, json_server_url): +async def test_streamablehttp_client_json_response(json_response_server: None, json_server_url: str): """Test client with JSON response mode.""" async with streamablehttp_client(f"{json_server_url}/mcp") as ( read_stream, @@ -945,12 +939,12 @@ async def test_streamablehttp_client_json_response(json_response_server, json_se @pytest.mark.anyio -async def test_streamablehttp_client_get_stream(basic_server, basic_server_url): +async def test_streamablehttp_client_get_stream(basic_server: None, basic_server_url: str): """Test GET stream functionality for server-initiated messages.""" import mcp.types as types from mcp.shared.session import RequestResponder - notifications_received = [] + notifications_received: list[types.ServerNotification] = [] # Define message handler to capture notifications async def message_handler( @@ -986,7 +980,7 @@ async def message_handler( @pytest.mark.anyio -async def test_streamablehttp_client_session_termination(basic_server, basic_server_url): +async def test_streamablehttp_client_session_termination(basic_server: None, basic_server_url: str): """Test client session termination functionality.""" captured_session_id = None @@ -1008,7 +1002,7 @@ async def test_streamablehttp_client_session_termination(basic_server, basic_ser tools = await session.list_tools() assert len(tools.tools) == 6 - headers = {} + headers: dict[str, str] = {} if captured_session_id: headers[MCP_SESSION_ID_HEADER] = captured_session_id @@ -1027,7 +1021,9 @@ async def test_streamablehttp_client_session_termination(basic_server, basic_ser @pytest.mark.anyio -async def test_streamablehttp_client_session_termination_204(basic_server, basic_server_url, monkeypatch): +async def test_streamablehttp_client_session_termination_204( + basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch +): """Test client session termination functionality with a 204 response. This test patches the httpx client to return a 204 response for DELETEs. @@ -1037,7 +1033,7 @@ async def test_streamablehttp_client_session_termination_204(basic_server, basic original_delete = httpx.AsyncClient.delete # Mock the client's delete method to return a 204 - async def mock_delete(self, *args, **kwargs): + async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> httpx.Response: # Call the original method to get the real response response = await original_delete(self, *args, **kwargs) @@ -1072,7 +1068,7 @@ async def mock_delete(self, *args, **kwargs): tools = await session.list_tools() assert len(tools.tools) == 6 - headers = {} + headers: dict[str, str] = {} if captured_session_id: headers[MCP_SESSION_ID_HEADER] = captured_session_id @@ -1091,14 +1087,14 @@ async def mock_delete(self, *args, **kwargs): @pytest.mark.anyio -async def test_streamablehttp_client_resumption(event_server): +async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventStore, str]): """Test client session resumption using sync primitives for reliable coordination.""" _, server_url = event_server # Variables to track the state captured_session_id = None captured_resumption_token = None - captured_notifications = [] + captured_notifications: list[types.ServerNotification] = [] captured_protocol_version = None first_notification_received = False @@ -1170,7 +1166,7 @@ async def run_tool(): captured_notifications = [] # Now resume the session with the same mcp-session-id and protocol version - headers = {} + headers: dict[str, Any] = {} if captured_session_id: headers[MCP_SESSION_ID_HEADER] = captured_session_id if captured_protocol_version: @@ -1211,11 +1207,12 @@ async def run_tool(): # We should have received the remaining notifications assert len(captured_notifications) == 1 + assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) assert captured_notifications[0].root.params.data == "Second notification after lock" @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server, basic_server_url): +async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): """Test server-initiated sampling request through streamable HTTP transport.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False @@ -1298,7 +1295,7 @@ async def handle_list_tools() -> list[Tool]: ] @self.call_tool() - async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: ctx = self.request_context if name == "echo_headers": @@ -1306,16 +1303,11 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: headers_info = {} if ctx.request and isinstance(ctx.request, Request): headers_info = dict(ctx.request.headers) - return [ - TextContent( - type="text", - text=json.dumps(headers_info), - ) - ] + return [TextContent(type="text", text=json.dumps(headers_info))] elif name == "echo_context": # Return full context information - context_data = { + context_data: dict[str, Any] = { "request_id": args.get("request_id"), "headers": {}, "method": None, @@ -1430,7 +1422,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" - contexts = [] + contexts: list[dict[str, Any]] = [] # Create multiple clients with different headers for i in range(3): @@ -1462,7 +1454,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server, basic_server_url): +async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): """Test that client includes mcp-protocol-version header after initialization.""" async with streamablehttp_client(f"{basic_server_url}/mcp") as ( read_stream, @@ -1486,7 +1478,7 @@ async def test_client_includes_protocol_version_header_after_init(context_aware_ assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version -def test_server_validates_protocol_version_header(basic_server, basic_server_url): +def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): """Test that server returns 400 Bad Request version if header unsupported or invalid.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1544,7 +1536,7 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url assert response.status_code == 200 -def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_server_url): +def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1574,7 +1566,7 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_ @pytest.mark.anyio -async def test_client_crash_handled(basic_server, basic_server_url): +async def test_client_crash_handled(basic_server: None, basic_server_url: str): """Test that cases where the client crashes are handled gracefully.""" # Simulate bad client that crashes after init diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 5081f1d53..2d67eccdd 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -2,6 +2,7 @@ import socket import time from collections.abc import AsyncGenerator, Generator +from typing import Any import anyio import pytest @@ -9,6 +10,7 @@ from pydantic import AnyUrl from starlette.applications import Starlette from starlette.routing import WebSocketRoute +from starlette.websockets import WebSocket from mcp.client.session import ClientSession from mcp.client.websocket import websocket_client @@ -67,7 +69,7 @@ async def handle_list_tools() -> list[Tool]: ] @self.call_tool() - async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] @@ -76,7 +78,7 @@ def make_server_app() -> Starlette: """Create test Starlette app with WebSocket transport""" server = ServerTest() - async def handle_ws(websocket): + async def handle_ws(websocket: WebSocket): async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: await server.run(streams[0], streams[1], server.create_initialization_options()) @@ -133,7 +135,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() -async def initialized_ws_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: """Create and initialize a WebSocket client session""" async with websocket_client(server_url + "/ws") as streams: async with ClientSession(*streams) as session: diff --git a/tests/test_examples.py b/tests/test_examples.py index decffd810..59063f122 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,13 +1,16 @@ """Tests for example servers""" +# TODO(Marcelo): The `examples` directory needs to be importable as a package. +# pyright: reportMissingImports=false +# pyright: reportUnknownVariableType=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false import sys import pytest from pytest_examples import CodeExample, EvalExample, find_examples -from mcp.shared.memory import ( - create_connected_server_and_client_session as client_session, -) +from mcp.shared.memory import create_connected_server_and_client_session as client_session from mcp.types import TextContent, TextResourceContents @@ -42,7 +45,7 @@ async def test_complex_inputs(): @pytest.mark.anyio -async def test_desktop(monkeypatch): +async def test_desktop(monkeypatch: pytest.MonkeyPatch): """Test the desktop server""" from pathlib import Path @@ -52,7 +55,7 @@ async def test_desktop(monkeypatch): # Mock desktop directory listing mock_files = [Path("/fake/path/file1.txt"), Path("/fake/path/file2.txt")] - monkeypatch.setattr(Path, "iterdir", lambda self: mock_files) + monkeypatch.setattr(Path, "iterdir", lambda self: mock_files) # type: ignore[reportUnknownArgumentType] monkeypatch.setattr(Path, "home", lambda: Path("/fake/home")) async with client_session(mcp._mcp_server) as client: diff --git a/tests/test_types.py b/tests/test_types.py index a39d33412..d7f2ac831 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,11 +1,6 @@ import pytest -from mcp.types import ( - LATEST_PROTOCOL_VERSION, - ClientRequest, - JSONRPCMessage, - JSONRPCRequest, -) +from mcp.types import LATEST_PROTOCOL_VERSION, ClientRequest, JSONRPCMessage, JSONRPCRequest @pytest.mark.anyio diff --git a/uv.lock b/uv.lock index 7a34275ce..9cb602465 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" [manifest] @@ -627,7 +627,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, - { name = "pydantic", specifier = ">=2.8.0,<3.0.0" }, + { name = "pydantic", specifier = ">=2.11.0,<3.0.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, { name = "python-multipart", specifier = ">=0.0.9" },