Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
import logging
import os
import sys
from urllib.parse import ParseResult, parse_qs, urlparse
from typing import Any, cast
from urllib.parse import parse_qs, urlparse

import httpx
from mcp import ClientSession
Expand All @@ -39,12 +40,12 @@
PrivateKeyJWTOAuthProvider,
SignedJWTParameters,
)
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.streamable_http import streamable_http_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyUrl


def get_conformance_context() -> dict:
def get_conformance_context() -> dict[str, Any]:
"""Load conformance test context from MCP_CONFORMANCE_CONTEXT environment variable."""
context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT")
if not context_json:
Expand Down Expand Up @@ -116,9 +117,9 @@ async def handle_redirect(self, authorization_url: str) -> None:

# Check for redirect response
if response.status_code in (301, 302, 303, 307, 308):
location = response.headers.get("location")
location = cast(str, response.headers.get("location"))
if location:
redirect_url: ParseResult = urlparse(location)
redirect_url = urlparse(location)
query_params: dict[str, list[str]] = parse_qs(redirect_url.query)

if "code" in query_params:
Expand Down Expand Up @@ -259,12 +260,8 @@ async def run_client_credentials_basic_client(server_url: str) -> None:
async def _run_session(server_url: str, oauth_auth: OAuthClientProvider) -> None:
"""Common session logic for all OAuth flows."""
# Connect using streamable HTTP transport with OAuth
async with streamablehttp_client(
url=server_url,
auth=oauth_auth,
timeout=30.0,
sse_read_timeout=60.0,
) as (read_stream, write_stream, _):
client = httpx.AsyncClient(auth=oauth_auth, timeout=30.0)
async with streamable_http_client(url=server_url, http_client=client) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the session
await session.initialize()
Expand Down
38 changes: 29 additions & 9 deletions examples/clients/simple-auth-client/mcp_simple_auth_client/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,26 @@

"""

from __future__ import annotations as _annotations

import asyncio
import os
import socketserver
import threading
import time
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from typing import Any, Callable
from urllib.parse import parse_qs, urlparse

import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamable_http_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from mcp.shared.message import SessionMessage


class InMemoryTokenStorage(TokenStorage):
Expand All @@ -46,7 +51,13 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
class CallbackHandler(BaseHTTPRequestHandler):
"""Simple HTTP handler to capture OAuth callback."""

def __init__(self, request, client_address, server, callback_data):
def __init__(
self,
request: Any,
client_address: tuple[str, int],
server: socketserver.BaseServer,
callback_data: dict[str, Any],
):
"""Initialize with callback data storage."""
self.callback_data = callback_data
super().__init__(request, client_address, server)
Expand Down Expand Up @@ -91,15 +102,14 @@ def do_GET(self):
self.send_response(404)
self.end_headers()

def log_message(self, format, *args):
def log_message(self, format: str, *args: Any):
"""Suppress default logging."""
pass


class CallbackServer:
"""Simple server to handle OAuth callbacks."""

def __init__(self, port=3000):
def __init__(self, port: int = 3000):
self.port = port
self.server = None
self.thread = None
Expand All @@ -110,7 +120,12 @@ def _create_handler_with_data(self):
callback_data = self.callback_data

class DataCallbackHandler(CallbackHandler):
def __init__(self, request, client_address, server):
def __init__(
self,
request: BaseHTTPRequestHandler,
client_address: tuple[str, int],
server: socketserver.BaseServer,
):
super().__init__(request, client_address, server, callback_data)

return DataCallbackHandler
Expand All @@ -131,7 +146,7 @@ def stop(self):
if self.thread:
self.thread.join(timeout=1)

def wait_for_callback(self, timeout=300):
def wait_for_callback(self, timeout: int = 300):
"""Wait for OAuth callback with timeout."""
start_time = time.time()
while time.time() - start_time < timeout:
Expand Down Expand Up @@ -225,7 +240,12 @@ async def _default_redirect_handler(authorization_url: str) -> None:

traceback.print_exc()

async def _run_session(self, read_stream, write_stream, get_session_id):
async def _run_session(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
get_session_id: Callable[[], str | None] | None = None,
):
"""Run the MCP session with the given streams."""
print("🤝 Initializing MCP session...")
async with ClientSession(read_stream, write_stream) as session:
Expand Down Expand Up @@ -314,7 +334,7 @@ async def interactive_loop(self):
continue

# Parse arguments (simple JSON-like format)
arguments = {}
arguments: dict[str, Any] = {}
if len(parts) > 2:
import json

Expand Down
18 changes: 10 additions & 8 deletions examples/clients/simple-chatbot/mcp_simple_chatbot/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import json
import logging
Expand Down Expand Up @@ -93,7 +95,7 @@ async def initialize(self) -> None:
await self.cleanup()
raise

async def list_tools(self) -> list[Any]:
async def list_tools(self) -> list[Tool]:
"""List available tools from the server.

Returns:
Expand All @@ -106,10 +108,10 @@ async def list_tools(self) -> list[Any]:
raise RuntimeError(f"Server {self.name} not initialized")

tools_response = await self.session.list_tools()
tools = []
tools: list[Tool] = []

for item in tools_response:
if isinstance(item, tuple) and item[0] == "tools":
if item[0] == "tools":
tools.extend(Tool(tool.name, tool.description, tool.inputSchema, tool.title) for tool in item[1])

return tools
Expand Down Expand Up @@ -189,7 +191,7 @@ def format_for_llm(self) -> str:
Returns:
A formatted string describing the tool.
"""
args_desc = []
args_desc: list[str] = []
if "properties" in self.input_schema:
for param_name, param_info in self.input_schema["properties"].items():
arg_desc = f"- {param_name}: {param_info.get('description', 'No description')}"
Expand Down Expand Up @@ -311,9 +313,9 @@ def _clean_json_string(json_string: str) -> str:
result = await server.execute_tool(tool_call["tool"], tool_call["arguments"])

if isinstance(result, dict) and "progress" in result:
progress = result["progress"]
total = result["total"]
percentage = (progress / total) * 100
progress = result["progress"] # type: ignore
total = result["total"] # type: ignore
percentage = (progress / total) * 100 # type: ignore
logging.info(f"Progress: {progress}/{total} ({percentage:.1f}%)")

return f"Tool execution result: {result}"
Expand All @@ -338,7 +340,7 @@ async def start(self) -> None:
await self.cleanup_servers()
return

all_tools = []
all_tools: list[Tool] = []
for server in self.servers:
tools = await server.list_tools()
all_tools.extend(tools)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import click
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.streamable_http import streamable_http_client
from mcp.types import CallToolResult, TextContent


async def run(url: str) -> None:
async with streamablehttp_client(url) as (read, write, _):
async with streamable_http_client(url) as (read, write, _):
async with ClientSession(read, write) as session:
await session.initialize()

Expand All @@ -28,12 +28,13 @@ async def run(url: str) -> None:
task_id = result.task.taskId
print(f"Task created: {task_id}")

status = None
# Poll until done (respects server's pollInterval hint)
async for status in session.experimental.poll_task(task_id):
print(f" Status: {status.status} - {status.statusMessage or ''}")

# Check final status
if status.status != "completed":
if status and status.status != "completed":
print(f"Task ended with status: {status.status}")
return

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import click
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.streamable_http import streamable_http_client
from mcp.shared.context import RequestContext
from mcp.types import (
CallToolResult,
Expand Down Expand Up @@ -73,7 +73,7 @@ def get_text(result: CallToolResult) -> str:


async def run(url: str) -> None:
async with streamablehttp_client(url) as (read, write, _):
async with streamable_http_client(url) as (read, write, _):
async with ClientSession(
read,
write,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import click
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.streamable_http import streamable_http_client

logger = logging.getLogger(__name__)

Expand All @@ -34,7 +34,7 @@ async def run_demo(url: str, items: int, checkpoint_every: int) -> None:
print(f"Processing {items} items with checkpoints every {checkpoint_every}")
print(f"{'=' * 60}\n")

async with streamablehttp_client(url) as (read_stream, write_stream, _):
async with streamable_http_client(url) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
# Initialize the connection
print("Initializing connection...")
Expand Down
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,13 @@ packages = ["src/mcp"]

[tool.pyright]
typeCheckingMode = "strict"
include = ["src/mcp", "tests", "examples/servers", "examples/snippets"]
include = [
"src/mcp",
"tests",
"examples/servers",
"examples/snippets",
"examples/clients",
]
venvPath = "."
venv = ".venv"
# The FastAPI style of using decorators in tests gives a `reportUnusedFunction` error.
Expand All @@ -102,7 +108,9 @@ venv = ".venv"
# those private functions instead of testing the private functions directly. It makes it easier to maintain the code source
# and refactor code that is not public.
executionEnvironments = [
{ root = "tests", extraPaths = ["."], reportUnusedFunction = false, reportPrivateUsage = false },
{ root = "tests", extraPaths = [
".",
], reportUnusedFunction = false, reportPrivateUsage = false },
{ root = "examples/servers", reportUnusedFunction = false },
]

Expand Down