diff --git a/README.md b/README.md index 0ca039ae3..81638087b 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ - [Advanced Usage](#advanced-usage) - [Low-Level Server](#low-level-server) - [Writing MCP Clients](#writing-mcp-clients) + - [Custom Requests](#custom-requests) - [MCP Primitives](#mcp-primitives) - [Server Capabilities](#server-capabilities) - [Documentation](#documentation) @@ -621,6 +622,90 @@ if __name__ == "__main__": asyncio.run(run()) ``` +### Custom Requests + +The MCP sdk can be extended with custom requests to support use cases outside [Model Context Protocol specification](https://spec.modelcontextprotocol.io) + +*warning:* This capability is opt-in and must be explicitly declared in `experimental_capabilities` in the server/client's capabilities + +Example of MCP server side custom requests processing: + +```python +import asyncio as aio +from typing import Literal + +import anyio + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.server.lowlevel import Server +from mcp.shared.memory import create_client_server_memory_streams + +## define custom request type + + +class AddOneParams(types.RequestParams): + value: int + + +class AddOneRequest(types.CustomRequest[AddOneParams, Literal["add_one"]]): + method: Literal["add_one"] = "add_one" + params: AddOneParams + + +class AddOneResult(types.CustomResult): + result: int + + +async def run_all(): + async with anyio.create_task_group() as tg: + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + server = Server("my-add-one-server") + + ## handle custom request type + @server.handle_custom_request(AddOneRequest) + async def handle_add_one_request(req: AddOneRequest) -> AddOneResult: + return AddOneResult(result=req.params.value + 1) + + tg.start_soon( + lambda: server.run( + server_read, + server_write, + server.create_initialization_options( + experimental_capabilities={"custom_requests": {}}, + ), + raise_exceptions=True, + ) + ) + + async with ClientSession( + read_stream=client_read, + write_stream=client_write, + experimental_capabilities={"custom_requests": {}}, + ) as client_session: + await client_session.initialize() + + ## send custom request type + req = AddOneRequest(params=AddOneParams(value=1)) + res = await client_session.send_custom_request( + req, response_type=AddOneResult + ) + print(res) + + tg.cancel_scope.cancel() + + +if __name__ == "__main__": + aio.run(run_all()) +``` + + ### MCP Primitives The MCP protocol defines three core primitives that servers can implement: diff --git a/examples/custom_requests/ttl.py b/examples/custom_requests/ttl.py new file mode 100644 index 000000000..544bc45bd --- /dev/null +++ b/examples/custom_requests/ttl.py @@ -0,0 +1,139 @@ +#!/usr/bin/env -S uv run --script +# /// script +# dependencies = [ +# "mcp", +# ] +# [tool.uv.sources] +# mcp = { path = "/workspace" } +# /// + +## +## The goal of this example is to demonstrate a workflow where +## users can define their own message types for MCP and how to +## process then client and/or server side. +## +## In this concrete example we demonstrate a new set of message types +## such that the client sends a request to the server and the server +## sends a response back to the client and back and forth until a TTL +## is reached. +## +## This is meant to demonstrate a possible future where MCP is used +## more bidirectionally as defined by a user. +## + + +import asyncio as aio +from typing import Any, Literal + +import anyio + +import mcp.types as types +from mcp.client.session import ClientSession, CustomRequestHandlerFnT +from mcp.server.lowlevel import Server +from mcp.shared.context import RequestContext +from mcp.shared.memory import create_client_server_memory_streams + +EXPERIMENTAL_CAPABILITIES: dict[str, dict[str, Any]] = {"custom_requests": {}} + +## Define a simple ttl protocol, sending a request to/from the client/server +## back and forth until a TTL is reached. + + +class TTLParams(types.RequestParams): + ttl: int + + +class TTLRequest(types.CustomRequest[TTLParams, Literal["ttl"]]): + method: Literal["ttl"] = "ttl" + params: TTLParams + + +class TTLPayloadResult(types.CustomResult): + message: str + + +async def run_all(): + async with anyio.create_task_group() as tg: + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + ## MCP Server code + server = Server("my-custom-server") + + @server.handle_custom_request(TTLRequest) + async def handle_ttl_request(req: TTLRequest) -> TTLPayloadResult: + print(f"SERVER: RECEIVED REQUEST WITH TTL={req.params.ttl}") + if req.params.ttl > 0: + tg.start_soon( + server.request_context.session.send_custom_request, + TTLRequest( + params=TTLParams( + ttl=req.params.ttl - 1, + ) + ), + TTLPayloadResult, + ) + return TTLPayloadResult(message=f"Recieved ttl {req.params.ttl}!") + + tg.start_soon( + lambda: server.run( + server_read, + server_write, + server.create_initialization_options( + experimental_capabilities=EXPERIMENTAL_CAPABILITIES, + ), + raise_exceptions=True, + ) + ) + + ## MCP Client code + + class TTLPayloadResponder( + CustomRequestHandlerFnT[TTLRequest, TTLPayloadResult] + ): + async def __call__( + self, + context: RequestContext["ClientSession", Any], + message: TTLRequest, + ) -> TTLPayloadResult | types.ErrorData: + print(f"CLIENT: RECEIVED REQUEST WITH TTL={message.params.ttl}") + if message.params.ttl > 0: + tg.start_soon( + context.session.send_custom_request, + TTLRequest( + params=TTLParams( + ttl=message.params.ttl - 1, + ) + ), + TTLPayloadResult, + ) + return TTLPayloadResult( + message=f"Recieved ttl {message.params.ttl}!" + ) + + async with ClientSession( + read_stream=client_read, + write_stream=client_write, + experimental_capabilities=EXPERIMENTAL_CAPABILITIES, + custom_request_handlers={ + "ttl": TTLPayloadResponder(), + }, + ) as client_session: + await client_session.initialize() + + req = TTLRequest(params=TTLParams(ttl=8)) + print(f"Sending: {req}") + await client_session.send_custom_request( + req, response_type=TTLPayloadResult + ) + await anyio.sleep(1) + + tg.cancel_scope.cancel() + + +if __name__ == "__main__": + aio.run(run_all()) diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 0d3c372ce..d26a4b9ff 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -12,6 +12,9 @@ CompleteRequest, CreateMessageRequest, CreateMessageResult, + CustomRequest, + CustomRequestWrapper, + CustomRequestWrapperParams, ErrorData, GetPromptRequest, GetPromptResult, @@ -66,6 +69,9 @@ "CreateMessageRequest", "CreateMessageResult", "ErrorData", + "CustomRequest", + "CustomRequestWrapper", + "CustomRequestWrapperParams", "GetPromptRequest", "GetPromptResult", "Implementation", diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index e29797d17..06b3b28ac 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from datetime import timedelta -from typing import Any, Protocol +from typing import Any, Generic, Protocol, TypeVar, get_args import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -11,6 +12,8 @@ from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") +CustomResultT = TypeVar("CustomResultT", bound=types.CustomResult, covariant=True) +CustomRequestMethodT = TypeVar("CustomRequestMethodT", bound=str) class SamplingFnT(Protocol): @@ -43,10 +46,30 @@ async def __call__( ) -> None: ... +class CustomRequestHandlerFnT(Protocol, Generic[types.CustomRequestT, CustomResultT]): + async def __call__( + self, + context: RequestContext["ClientSession", Any], + message: types.CustomRequestT, + ) -> CustomResultT | types.ErrorData: ... + + +async def _default_custom_request_handler( + context: RequestContext["ClientSession", Any], + message: types.CustomRequest[dict[str, Any] | None, str], +) -> types.CustomResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message=f"Custom request method {message.method} not supported", + ) + + async def _default_message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: ( + RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception + ), ) -> None: await anyio.lowlevel.checkpoint() @@ -100,6 +123,17 @@ def __init__( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, + custom_request_handlers: ( + Mapping[ + str, + CustomRequestHandlerFnT[ + Any, + types.CustomResult, + ], + ] + | None + ) = None, + experimental_capabilities: dict[str, dict[str, Any]] | None = None, ) -> None: super().__init__( read_stream, @@ -113,6 +147,14 @@ def __init__( self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler + self._custom_request_handlers: Mapping[ + str, + CustomRequestHandlerFnT[ + Any, + types.CustomResult, + ], + ] = custom_request_handlers or {} + self._experimental_capabilities = experimental_capabilities or {} async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() @@ -131,7 +173,7 @@ async def initialize(self) -> types.InitializeResult: protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=types.ClientCapabilities( sampling=sampling, - experimental=None, + experimental=self._experimental_capabilities, roots=roots, ), clientInfo=self._client_info, @@ -183,6 +225,39 @@ async def send_progress_notification( ) ) + async def send_custom_request( + self, + request: types.CustomRequest[types.RequestParamsT, types.MethodT], + response_type: type[CustomResultT], + ) -> CustomResultT: + """Send a custom request.""" + if self._experimental_capabilities.get("custom_requests", None) is None: + raise RuntimeError( + "experimental capability 'custom_requests' must be set in the" + " client capabilities to send custom requests." + ) + request_params = ( + request.params.model_dump(by_alias=True, mode="json", exclude_none=True) + if isinstance(request.params, types.BaseModel) + else request.params + ) + inner_request = types.CustomRequest[dict[str, Any] | None, str]( + method=request.method, + params=request_params, + ) + result = await self.send_request( + types.ClientRequest( + types.CustomRequestWrapper( + method="custom/request", + params=types.CustomRequestWrapperParams( + inner=inner_request, + ), + ) + ), + types.CustomResultWrapper, + ) + return response_type.model_validate(result.payload) + async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: """Send a logging/setLevel request.""" return await self.send_request( @@ -361,11 +436,54 @@ async def _received_request( types.ClientResult(root=types.EmptyResult()) ) + case types.CustomRequestWrapper( + params=types.CustomRequestWrapperParams(inner=custom_request) + ): + with responder: + custom_request_handler = ( + self._custom_request_handlers.get( + custom_request.method, + _default_custom_request_handler, + ) + if self._experimental_capabilities.get("custom_requests", None) + is not None + else _default_custom_request_handler + ) + + ## TODO: Find a better way to get the type of the custom request + ## from the custom request handler. + orig_base = custom_request_handler.__orig_bases__[0] # type: ignore + (CustomRequestType, *_rest) = get_args(orig_base) + parsed_custom_request = CustomRequestType.model_validate( + custom_request.model_dump( + by_alias=True, + exclude_none=True, + mode="json", + ) + ) + + custom_request_response = await custom_request_handler( + ctx, + parsed_custom_request, + ) + client_response = ClientResponse.validate_python( + types.CustomResultWrapper( + payload=custom_request_response.model_dump( + by_alias=True, + exclude_none=True, + mode="json", + ) + ) + ) + await responder.respond(client_response) + async def _handle_incoming( self, - req: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + req: ( + RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception + ), ) -> None: """Handle incoming messages by forwarding to the message handler.""" await self._message_handler(req) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dbaff3051..ff2a13a76 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -71,7 +71,7 @@ async def main(): import warnings from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Literal, TypeVar, get_args, get_origin, get_type_hints import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -140,6 +140,10 @@ def __init__( ] = { types.PingRequest: _ping_handler, } + self.custom_request_handlers: dict[ + str, # method literal for each custom request type + Callable[..., Awaitable[types.ServerResult]], + ] = {} self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() logger.debug(f"Initializing server '{name}'") @@ -458,9 +462,11 @@ async def handler(req: types.CompleteRequest): completion = await func(req.params.ref, req.params.argument) return types.ServerResult( types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, hasMore=None), + completion=( + completion + if completion is not None + else types.Completion(values=[], total=None, hasMore=None) + ), ) ) @@ -469,6 +475,62 @@ async def handler(req: types.CompleteRequest): return decorator + def handle_custom_request(self, request_type: type[types.CustomRequestT]): + assert issubclass(request_type, types.CustomRequest), ( + f"Custom request type {request_type} must be a subclass of " + f"{types.CustomRequest}" + ) + + # Extract the method literal string from the custom request type + method_literal = None + try: + type_hints = get_type_hints(request_type) + if "method" in type_hints: + method_annotation = type_hints["method"] + if get_origin(method_annotation) is Literal: + args = get_args(method_annotation) + if args: + method_literal = args[0] + logger.debug(f"Extracted method literal: {method_literal}") + except Exception: + logger.debug(f"Failed to extract method literal from {request_type}") + + if method_literal is None: + logger.warning(f"Could not extract method literal from {request_type}") + raise ValueError(f"Could not extract method literal from {request_type}") + + def decorator( + func: Callable[ + [ + types.CustomRequestT, + ], + Awaitable[types.CustomResult], + ], + ): + logger.debug( + f"Registering handler for {request_type} under method " + f"literal {method_literal}" + ) + + async def handler(req: types.CustomRequestT): + result = await func( + request_type.model_validate( + req.model_dump( + by_alias=True, + exclude_none=True, + mode="json", + ) + ) + ) + return types.ServerResult( + types.CustomResultWrapper(payload=result.model_dump()) + ) + + self.custom_request_handlers[method_literal] = handler + return func + + return decorator + async def run( self, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], @@ -480,6 +542,16 @@ async def run( # in-process servers. raise_exceptions: bool = False, ): + if self.custom_request_handlers and ( + initialization_options.capabilities.experimental is None + or initialization_options.capabilities.experimental.get("custom_requests") + is None + ): + raise RuntimeError( + "server has custom request handlers but experimental capability " + "'custom_requests' is not set in the server capabilities." + ) + async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) session = await stack.enter_async_context( @@ -500,9 +572,11 @@ async def run( async def _handle_message( self, - message: RequestResponder[types.ClientRequest, types.ServerResult] - | types.ClientNotification - | Exception, + message: ( + RequestResponder[types.ClientRequest, types.ServerResult] + | types.ClientNotification + | Exception + ), session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, @@ -532,8 +606,32 @@ async def _handle_request( raise_exceptions: bool, ): logger.info(f"Processing request of type {type(req).__name__}") + + handler = None if type(req) in self.request_handlers: handler = self.request_handlers[type(req)] + elif isinstance(req, types.CustomRequestWrapper): + custom_request_method = req.params.inner.method + try: + handler = self.custom_request_handlers[custom_request_method] + req = req.params.inner + except KeyError: + logger.debug( + f"Custom request method {custom_request_method} does not match any " + f"custom request handler" + ) + pass + except Exception as err: + logger.error(f"Error handling custom request: {err}") + pass + if not handler: + await message.respond( + types.ErrorData( + code=types.METHOD_NOT_FOUND, + message="Method not found. oops", + ) + ) + else: logger.debug(f"Dispatching request of type {type(req).__name__}") token = None @@ -561,13 +659,6 @@ async def _handle_request( request_ctx.reset(token) await message.respond(response) - else: - await message.respond( - types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="Method not found", - ) - ) logger.debug("Response sent") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 568ecd4b9..9c09bc81d 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -60,6 +60,7 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") +CustomResultT = TypeVar("CustomResultT", bound=types.CustomResult) ServerRequestResponder = ( RequestResponder[types.ClientRequest, types.ServerResult] @@ -307,6 +308,44 @@ async def send_prompt_list_changed(self) -> None: ) ) + async def send_custom_request( + self, + request: types.CustomRequest[types.RequestParamsT, types.MethodT], + response_type: type[CustomResultT], + ) -> CustomResultT: + """Send a custom request.""" + if ( + self._init_options.capabilities.experimental is None + or self._init_options.capabilities.experimental.get("custom_requests") + is None + ): + raise RuntimeError( + "experimental capability 'custom_requests' must be set in the" + " server capabilities to send custom requests." + ) + + request_params = ( + request.params.model_dump(by_alias=True, mode="json", exclude_none=True) + if isinstance(request.params, types.BaseModel) + else request.params + ) + inner_request = types.CustomRequest[dict[str, Any] | None, str]( + method=request.method, + params=request_params, + ) + result = await self.send_request( + types.ServerRequest( + types.CustomRequestWrapper( + method="custom/request", + params=types.CustomRequestWrapperParams( + inner=inner_request, + ), + ) + ), + types.CustomResultWrapper, + ) + return response_type.model_validate(result.payload) + async def _handle_incoming(self, req: ServerRequestResponder) -> None: await self._incoming_message_stream_writer.send(req) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 05fd3ce37..5ab4e953c 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -312,9 +312,11 @@ async def _receive_loop(self) -> None: responder = RequestResponder( request_id=message.root.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, + request_meta=( + validated_request.root.params.meta + if validated_request.root.params + else None + ), request=validated_request, session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), @@ -386,9 +388,11 @@ async def send_progress_notification( async def _handle_incoming( self, - req: RequestResponder[ReceiveRequestT, SendResultT] - | ReceiveNotificationT - | Exception, + req: ( + RequestResponder[ReceiveRequestT, SendResultT] + | ReceiveNotificationT + | Exception + ), ) -> None: """A generic handler for incoming messages. Overwritten by subclasses.""" pass diff --git a/src/mcp/types.py b/src/mcp/types.py index bd71d51f0..fa82be883 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -8,7 +8,13 @@ TypeVar, ) -from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel +from pydantic import ( + BaseModel, + ConfigDict, + Field, + FileUrl, + RootModel, +) from pydantic.networks import AnyUrl, UrlConstraints """ @@ -79,6 +85,42 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]): model_config = ConfigDict(extra="allow") +class CustomRequest(Request[RequestParamsT, MethodT]): + """Base class for custom requests.""" + + ... + + +CustomRequestT = TypeVar( + "CustomRequestT", + bound=CustomRequest[Any, Any], + contravariant=True, +) + + +class CustomRequestWrapperParams(RequestParams): + """ + Parameters for the custom request wrapper. + """ + + inner: CustomRequest[dict[str, Any] | None, str] + """ + The custom request to be wrapped. + """ + + +class CustomRequestWrapper( + Request[CustomRequestWrapperParams, Literal["custom/request"]] +): + """ + This request is used when sending custom user defined requests from + the client to the server or vice versa. + """ + + method: Literal["custom/request"] + params: CustomRequestWrapperParams + + class PaginatedRequest(Request[RequestParamsT, MethodT]): cursor: Cursor | None = None """ @@ -107,6 +149,18 @@ class Result(BaseModel): """ +class CustomResult(Result): + """Base class for custom results.""" + + pass + + +class CustomResultWrapper(Result): + """Wrapper for custom results.""" + + payload: dict[str, Any] + + class PaginatedResult(Result): nextCursor: Cursor | None = None """ @@ -1075,6 +1129,7 @@ class ClientRequest( | UnsubscribeRequest | CallToolRequest | ListToolsRequest + | CustomRequestWrapper ] ): pass @@ -1091,11 +1146,17 @@ class ClientNotification( pass -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]): +class ClientResult( + RootModel[EmptyResult | CreateMessageResult | ListRootsResult | CustomResultWrapper] +): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]): +class ServerRequest( + RootModel[ + PingRequest | CreateMessageRequest | ListRootsRequest | CustomRequestWrapper + ] +): pass @@ -1125,6 +1186,7 @@ class ServerResult( | ReadResourceResult | CallToolResult | ListToolsResult + | CustomResultWrapper ] ): pass