diff --git a/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py index ba8679e3ac..9066d49af2 100644 --- a/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py +++ b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py @@ -29,7 +29,8 @@ import logging import os import sys -from urllib.parse import ParseResult, parse_qs, urlparse +from typing import Any, cast +from urllib.parse import parse_qs, urlparse import httpx from mcp import ClientSession @@ -39,12 +40,12 @@ PrivateKeyJWTOAuthProvider, SignedJWTParameters, ) -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from pydantic import AnyUrl -def get_conformance_context() -> dict: +def get_conformance_context() -> dict[str, Any]: """Load conformance test context from MCP_CONFORMANCE_CONTEXT environment variable.""" context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT") if not context_json: @@ -116,9 +117,9 @@ async def handle_redirect(self, authorization_url: str) -> None: # Check for redirect response if response.status_code in (301, 302, 303, 307, 308): - location = response.headers.get("location") + location = cast(str, response.headers.get("location")) if location: - redirect_url: ParseResult = urlparse(location) + redirect_url = urlparse(location) query_params: dict[str, list[str]] = parse_qs(redirect_url.query) if "code" in query_params: @@ -259,12 +260,8 @@ async def run_client_credentials_basic_client(server_url: str) -> None: async def _run_session(server_url: str, oauth_auth: OAuthClientProvider) -> None: """Common session logic for all OAuth flows.""" # Connect using streamable HTTP transport with OAuth - async with streamablehttp_client( - url=server_url, - auth=oauth_auth, - timeout=30.0, - sse_read_timeout=60.0, - ) as (read_stream, write_stream, _): + client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0) + async with streamable_http_client(url=server_url, http_client=client) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # Initialize the session await session.initialize() diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 0223b72394..e95d36aa61 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -6,21 +6,26 @@ """ +from __future__ import annotations as _annotations + import asyncio import os +import socketserver import threading import time import webbrowser from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Any +from typing import Any, Callable from urllib.parse import parse_qs, urlparse import httpx +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from mcp.shared.message import SessionMessage class InMemoryTokenStorage(TokenStorage): @@ -46,7 +51,13 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None class CallbackHandler(BaseHTTPRequestHandler): """Simple HTTP handler to capture OAuth callback.""" - def __init__(self, request, client_address, server, callback_data): + def __init__( + self, + request: Any, + client_address: tuple[str, int], + server: socketserver.BaseServer, + callback_data: dict[str, Any], + ): """Initialize with callback data storage.""" self.callback_data = callback_data super().__init__(request, client_address, server) @@ -91,15 +102,14 @@ def do_GET(self): self.send_response(404) self.end_headers() - def log_message(self, format, *args): + def log_message(self, format: str, *args: Any): """Suppress default logging.""" - pass class CallbackServer: """Simple server to handle OAuth callbacks.""" - def __init__(self, port=3000): + def __init__(self, port: int = 3000): self.port = port self.server = None self.thread = None @@ -110,7 +120,12 @@ def _create_handler_with_data(self): callback_data = self.callback_data class DataCallbackHandler(CallbackHandler): - def __init__(self, request, client_address, server): + def __init__( + self, + request: BaseHTTPRequestHandler, + client_address: tuple[str, int], + server: socketserver.BaseServer, + ): super().__init__(request, client_address, server, callback_data) return DataCallbackHandler @@ -131,7 +146,7 @@ def stop(self): if self.thread: self.thread.join(timeout=1) - def wait_for_callback(self, timeout=300): + def wait_for_callback(self, timeout: int = 300): """Wait for OAuth callback with timeout.""" start_time = time.time() while time.time() - start_time < timeout: @@ -225,7 +240,12 @@ async def _default_redirect_handler(authorization_url: str) -> None: traceback.print_exc() - async def _run_session(self, read_stream, write_stream, get_session_id): + async def _run_session( + self, + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], + get_session_id: Callable[[], str | None] | None = None, + ): """Run the MCP session with the given streams.""" print("🤝 Initializing MCP session...") async with ClientSession(read_stream, write_stream) as session: @@ -314,7 +334,7 @@ async def interactive_loop(self): continue # Parse arguments (simple JSON-like format) - arguments = {} + arguments: dict[str, Any] = {} if len(parts) > 2: import json diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 78a81a4d9f..cec26aae0e 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import json import logging @@ -93,7 +95,7 @@ async def initialize(self) -> None: await self.cleanup() raise - async def list_tools(self) -> list[Any]: + async def list_tools(self) -> list[Tool]: """List available tools from the server. Returns: @@ -106,10 +108,10 @@ async def list_tools(self) -> list[Any]: raise RuntimeError(f"Server {self.name} not initialized") tools_response = await self.session.list_tools() - tools = [] + tools: list[Tool] = [] for item in tools_response: - if isinstance(item, tuple) and item[0] == "tools": + if item[0] == "tools": tools.extend(Tool(tool.name, tool.description, tool.inputSchema, tool.title) for tool in item[1]) return tools @@ -189,7 +191,7 @@ def format_for_llm(self) -> str: Returns: A formatted string describing the tool. """ - args_desc = [] + args_desc: list[str] = [] if "properties" in self.input_schema: for param_name, param_info in self.input_schema["properties"].items(): arg_desc = f"- {param_name}: {param_info.get('description', 'No description')}" @@ -311,9 +313,9 @@ def _clean_json_string(json_string: str) -> str: result = await server.execute_tool(tool_call["tool"], tool_call["arguments"]) if isinstance(result, dict) and "progress" in result: - progress = result["progress"] - total = result["total"] - percentage = (progress / total) * 100 + progress = result["progress"] # type: ignore + total = result["total"] # type: ignore + percentage = (progress / total) * 100 # type: ignore logging.info(f"Progress: {progress}/{total} ({percentage:.1f}%)") return f"Tool execution result: {result}" @@ -338,7 +340,7 @@ async def start(self) -> None: await self.cleanup_servers() return - all_tools = [] + all_tools: list[Tool] = [] for server in self.servers: tools = await server.list_tools() all_tools.extend(tools) diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/main.py b/examples/clients/simple-task-client/mcp_simple_task_client/main.py index 12691162ab..5050900208 100644 --- a/examples/clients/simple-task-client/mcp_simple_task_client/main.py +++ b/examples/clients/simple-task-client/mcp_simple_task_client/main.py @@ -4,12 +4,12 @@ import click from mcp import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.types import CallToolResult, TextContent async def run(url: str) -> None: - async with streamablehttp_client(url) as (read, write, _): + async with streamable_http_client(url) as (read, write, _): async with ClientSession(read, write) as session: await session.initialize() @@ -28,12 +28,13 @@ async def run(url: str) -> None: task_id = result.task.taskId print(f"Task created: {task_id}") + status = None # Poll until done (respects server's pollInterval hint) async for status in session.experimental.poll_task(task_id): print(f" Status: {status.status} - {status.statusMessage or ''}") # Check final status - if status.status != "completed": + if status and status.status != "completed": print(f"Task ended with status: {status.status}") return diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py index a8a47dc57c..d99604ddf1 100644 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py +++ b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py @@ -11,7 +11,7 @@ import click from mcp import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client from mcp.shared.context import RequestContext from mcp.types import ( CallToolResult, @@ -73,7 +73,7 @@ def get_text(result: CallToolResult) -> str: async def run(url: str) -> None: - async with streamablehttp_client(url) as (read, write, _): + async with streamable_http_client(url) as (read, write, _): async with ClientSession( read, write, diff --git a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py index 1defd8eaa4..c00e20f2af 100644 --- a/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py +++ b/examples/clients/sse-polling-client/mcp_sse_polling_client/main.py @@ -20,7 +20,7 @@ import click from mcp import ClientSession -from mcp.client.streamable_http import streamablehttp_client +from mcp.client.streamable_http import streamable_http_client logger = logging.getLogger(__name__) @@ -34,7 +34,7 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None: print(f"Processing {items} items with checkpoints every {checkpoint_every}") print(f"{'=' * 60}\n") - async with streamablehttp_client(url) as (read_stream, write_stream, _): + async with streamable_http_client(url) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: # Initialize the connection print("Initializing connection...") diff --git a/pyproject.toml b/pyproject.toml index 078a1dfdcb..b29f4810b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,13 @@ packages = ["src/mcp"] [tool.pyright] typeCheckingMode = "strict" -include = ["src/mcp", "tests", "examples/servers", "examples/snippets"] +include = [ + "src/mcp", + "tests", + "examples/servers", + "examples/snippets", + "examples/clients", +] venvPath = "." venv = ".venv" # The FastAPI style of using decorators in tests gives a `reportUnusedFunction` error. @@ -102,7 +108,9 @@ venv = ".venv" # those private functions instead of testing the private functions directly. It makes it easier to maintain the code source # and refactor code that is not public. executionEnvironments = [ - { root = "tests", extraPaths = ["."], reportUnusedFunction = false, reportPrivateUsage = false }, + { root = "tests", extraPaths = [ + ".", + ], reportUnusedFunction = false, reportPrivateUsage = false }, { root = "examples/servers", reportUnusedFunction = false }, ]