|
| 1 | +import asyncio |
| 2 | +from contextlib import asynccontextmanager |
| 3 | +from typing import Any, AsyncGenerator |
| 4 | + |
| 5 | +import anyio |
| 6 | +import mcp.types as types |
| 7 | +import pydantic |
| 8 | +from mcp.shared.message import SessionMessage |
| 9 | +from temporalio import workflow |
| 10 | + |
| 11 | +from .service import MCPService |
| 12 | + |
| 13 | + |
| 14 | +class WorkflowTransport: |
| 15 | + """ |
| 16 | + An MCP Transport for use in Temporal workflows. |
| 17 | +
|
| 18 | + This class provides a transport that proxies MCP requests from a Temporal Workflow to a Temporal |
| 19 | + Nexus service. It can be used to make MCP calls via `mcp.ClientSession` from Temporal workflow |
| 20 | + code. |
| 21 | +
|
| 22 | + Example: |
| 23 | + ```python async with WorkflowNexusTransport("my-endpoint") as (read_stream, write_stream): |
| 24 | + async with ClientSession(read_stream, write_stream) as session: |
| 25 | + await session.initialize() await session.list_tools() await |
| 26 | + session.call_tool("my-service/my-operation", {"arg": "value"}) |
| 27 | + ``` |
| 28 | + """ |
| 29 | + |
| 30 | + def __init__( |
| 31 | + self, |
| 32 | + endpoint: str, |
| 33 | + ): |
| 34 | + self.endpoint = endpoint |
| 35 | + |
| 36 | + @asynccontextmanager |
| 37 | + async def connect( |
| 38 | + self, |
| 39 | + ) -> AsyncGenerator[ |
| 40 | + tuple[ |
| 41 | + anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage], |
| 42 | + anyio.streams.memory.MemoryObjectSendStream[SessionMessage], |
| 43 | + ], |
| 44 | + None, |
| 45 | + ]: |
| 46 | + client_write, transport_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] |
| 47 | + transport_write, client_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] |
| 48 | + |
| 49 | + async def message_router() -> None: |
| 50 | + try: |
| 51 | + async for session_message in transport_read: |
| 52 | + request = session_message.message.root |
| 53 | + if not isinstance(request, types.JSONRPCRequest): |
| 54 | + # Ignore e.g. types.JSONRPCNotification |
| 55 | + continue |
| 56 | + result: types.Result | types.ErrorData |
| 57 | + try: |
| 58 | + match request: |
| 59 | + case types.JSONRPCRequest(method="initialize"): |
| 60 | + result = self._handle_initialize( |
| 61 | + types.InitializeRequestParams.model_validate(request.params) |
| 62 | + ) |
| 63 | + case types.JSONRPCRequest(method="tools/list"): |
| 64 | + result = await self._handle_list_tools() |
| 65 | + case types.JSONRPCRequest(method="tools/call"): |
| 66 | + result = await self._handle_call_tool( |
| 67 | + types.CallToolRequestParams.model_validate(request.params) |
| 68 | + ) |
| 69 | + case _: |
| 70 | + result = types.ErrorData( |
| 71 | + code=types.METHOD_NOT_FOUND, message=f"Unknown method: {request.method}" |
| 72 | + ) |
| 73 | + except pydantic.ValidationError as e: |
| 74 | + result = types.ErrorData(code=types.INVALID_PARAMS, message=f"Invalid request: {e}") |
| 75 | + |
| 76 | + match result: |
| 77 | + case types.Result(): |
| 78 | + response = self._json_rpc_result_response(request, result) |
| 79 | + case types.ErrorData(): |
| 80 | + response = self._json_rpc_error_response(request, result) |
| 81 | + |
| 82 | + await transport_write.send(SessionMessage(types.JSONRPCMessage(root=response))) |
| 83 | + |
| 84 | + except anyio.ClosedResourceError: |
| 85 | + pass |
| 86 | + finally: |
| 87 | + await transport_write.aclose() |
| 88 | + |
| 89 | + router_task = asyncio.create_task(message_router()) |
| 90 | + |
| 91 | + try: |
| 92 | + yield client_read, client_write |
| 93 | + finally: |
| 94 | + await client_write.aclose() |
| 95 | + router_task.cancel() |
| 96 | + try: |
| 97 | + await router_task |
| 98 | + except asyncio.CancelledError: |
| 99 | + pass |
| 100 | + await transport_read.aclose() |
| 101 | + |
| 102 | + def _handle_initialize(self, params: types.InitializeRequestParams) -> types.InitializeResult: |
| 103 | + # TODO: MCPService should implement this |
| 104 | + return types.InitializeResult( |
| 105 | + protocolVersion="2024-11-05", |
| 106 | + capabilities=types.ServerCapabilities(tools=types.ToolsCapability()), |
| 107 | + serverInfo=types.Implementation( |
| 108 | + name="nexus-mcp-transport", |
| 109 | + version="0.1.0", |
| 110 | + ), |
| 111 | + ) |
| 112 | + |
| 113 | + async def _handle_list_tools(self) -> types.ListToolsResult: |
| 114 | + nexus_client = workflow.create_nexus_client( |
| 115 | + endpoint=self.endpoint, |
| 116 | + service=MCPService, |
| 117 | + ) |
| 118 | + tools = await nexus_client.execute_operation(MCPService.list_tools, None) |
| 119 | + return types.ListToolsResult(tools=tools) |
| 120 | + |
| 121 | + async def _handle_call_tool(self, params: types.CallToolRequestParams) -> types.CallToolResult: |
| 122 | + service, _, operation = params.name.partition("/") |
| 123 | + nexus_client = workflow.create_nexus_client( |
| 124 | + endpoint=self.endpoint, |
| 125 | + service=service, |
| 126 | + ) |
| 127 | + result: Any = await nexus_client.execute_operation( |
| 128 | + operation, |
| 129 | + params.arguments or {}, |
| 130 | + ) |
| 131 | + if isinstance(result, dict): |
| 132 | + return types.CallToolResult(content=[], structuredContent=result) |
| 133 | + else: |
| 134 | + return types.CallToolResult(content=[types.TextContent(type="text", text=str(result))]) |
| 135 | + |
| 136 | + def _json_rpc_error_response(self, request: types.JSONRPCRequest, error: types.ErrorData) -> types.JSONRPCResponse: |
| 137 | + return types.JSONRPCResponse.model_validate( |
| 138 | + { |
| 139 | + "jsonrpc": "2.0", |
| 140 | + "id": request.id, |
| 141 | + "error": error.model_dump(), |
| 142 | + } |
| 143 | + ) |
| 144 | + |
| 145 | + def _json_rpc_result_response(self, request: types.JSONRPCRequest, result: types.Result) -> types.JSONRPCResponse: |
| 146 | + return types.JSONRPCResponse.model_validate( |
| 147 | + { |
| 148 | + "jsonrpc": "2.0", |
| 149 | + "id": request.id, |
| 150 | + "result": result.model_dump(), |
| 151 | + } |
| 152 | + ) |
0 commit comments