diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e482470 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +*.egg-info/ \ No newline at end of file diff --git a/README.md b/README.md index a1599b4..8ec0515 100644 --- a/README.md +++ b/README.md @@ -99,11 +99,12 @@ class CalculatorHandler: return CalculateResponse(result=result) ``` -### 4. Create a Nexus endpoint +### 4. Create a namespace and Nexus endpoint **Local dev or self hosted deployment** ```bash +temporal operator namespace create --namespace my-handler-namespace temporal operator nexus endpoint create \ --name mcp-gateway \ --target-namespace my-handler-namespace \ @@ -117,13 +118,20 @@ temporal operator nexus endpoint create \ ### 3. Set Up the Temporal Worker with the Nexus handlers at `worker.py` ```python +import asyncio + from temporalio.client import Client +from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.worker import Worker from .service_handler import mcp_service_handler, CalculatorHandler async def main(): # Connect to Temporal (replace host and namespace as needed). - client = await Client.connect("localhost:7233", namespace="my-handler-namespace") + client = await Client.connect( + "localhost:7233", + namespace="my-handler-namespace", + data_converter=pydantic_data_converter, + ) async with Worker( client, @@ -136,17 +144,26 @@ async def main(): ### 4. Set Up the MCP Gateway +```bash +temporal operator namespace create --namespace my-caller-namespace +``` + ```python import asyncio from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions from temporalio.client import Client +from temporalio.contrib.pydantic import pydantic_data_converter from nexusmcp import InboundGateway async def main(): server = Server("nexus-mcp-demo") # Connect to Temporal (replace host and namespace as needed). - client = await Client.connect("localhost:7233", namespace="my-caller-namespace") + client = await Client.connect( + "localhost:7233", + namespace="my-caller-namespace", + data_converter=pydantic_data_converter, + ) # Create the MCP gateway gateway = InboundGateway( @@ -163,7 +180,7 @@ if __name__ == "__main__": asyncio.run(main()) ``` -### 4. Configure Your MCP Client +### 5. Configure Your MCP Client Add to your MCP client configuration (e.g., Claude Desktop): @@ -178,6 +195,57 @@ Add to your MCP client configuration (e.g., Claude Desktop): } ``` +### 6. Make MCP calls from a Temporal Workflow + +```python +import asyncio +import uuid + +from mcp import ClientSession +from nexusmcp import WorkflowTransport +from pydantic import BaseModel +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.worker import Worker + + +class AgentWorkflowInput(BaseModel): + endpoint: str + + +# The workflow must have the sandbox disabled +@workflow.defn(sandboxed=False) +class AgentWorkflow: + @workflow.run + async def run(self, input: AgentWorkflowInput): + transport = WorkflowTransport(input.endpoint) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + list_tools_result = await session.list_tools() + print(f"available tools: {list_tools_result}") + + +async def main(): + client = await Client.connect( + "localhost:7233", + data_converter=pydantic_data_converter, + ) + + async with Worker( + client, + task_queue="agent-workflow", + workflows=[AgentWorkflow], + ) as worker: + await client.execute_workflow( + AgentWorkflow.run, + AgentWorkflowInput(endpoint="mcp-gateway"), + id=str(uuid.uuid4()), + task_queue=worker.task_queue, + ) +``` + ## Usage Examples ### Tool Filtering diff --git a/nexusmcp/__init__.py b/nexusmcp/__init__.py index f60cb7b..de60797 100644 --- a/nexusmcp/__init__.py +++ b/nexusmcp/__init__.py @@ -4,5 +4,6 @@ from .inbound_gateway import InboundGateway from .service import MCPService from .service_handler import MCPServiceHandler, exclude + from .workflow_transport import WorkflowTransport -__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude"] +__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude", "WorkflowTransport"] diff --git a/nexusmcp/inbound_gateway.py b/nexusmcp/inbound_gateway.py index 19142f6..cfdf260 100644 --- a/nexusmcp/inbound_gateway.py +++ b/nexusmcp/inbound_gateway.py @@ -29,7 +29,7 @@ class InboundGateway: This gateway acts as an adapter between the Model Context Protocol (MCP) server and Temporal Nexus Operations, enabling tool calls to be executed reliably through Temporal's workflow engine. It handles both tool listing and - tool call requests by delegating them to corresponding a set of Temporal Nexus Services in a given endpoint. + tool call requests by delegating them to a corresponding set of Temporal Nexus Services in a given endpoint. """ _client: Client @@ -93,7 +93,7 @@ async def _handle_call_tool(self, name: str, arguments: dict[str, Any]) -> Any: {"a": 5, "b": 3} ) """ - service, operation = name.split("/", maxsplit=1) + service, _, operation = name.partition("/") if not service or not operation: raise ValueError(f"Invalid tool name: {name}, must be in the format 'service/operation'") return await self._client.execute_workflow( diff --git a/nexusmcp/proxy_workflow.py b/nexusmcp/proxy_workflow.py index b8e1478..f379c38 100644 --- a/nexusmcp/proxy_workflow.py +++ b/nexusmcp/proxy_workflow.py @@ -12,7 +12,6 @@ class ToolListInput(BaseModel): endpoint: str - pass class ToolCallInput(BaseModel): diff --git a/nexusmcp/workflow_transport.py b/nexusmcp/workflow_transport.py new file mode 100644 index 0000000..446d503 --- /dev/null +++ b/nexusmcp/workflow_transport.py @@ -0,0 +1,152 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator + +import anyio +import mcp.types as types +import pydantic +from mcp.shared.message import SessionMessage +from temporalio import workflow + +from .service import MCPService + + +class WorkflowTransport: + """ + An MCP Transport for use in Temporal workflows. + + This class provides a transport that proxies MCP requests from a Temporal Workflow to a Temporal + Nexus service. It can be used to make MCP calls via `mcp.ClientSession` from Temporal workflow + code. + + Example: + ```python async with WorkflowNexusTransport("my-endpoint") as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() await session.list_tools() await + session.call_tool("my-service/my-operation", {"arg": "value"}) + ``` + """ + + def __init__( + self, + endpoint: str, + ): + self.endpoint = endpoint + + @asynccontextmanager + async def connect( + self, + ) -> AsyncGenerator[ + tuple[ + anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage], + anyio.streams.memory.MemoryObjectSendStream[SessionMessage], + ], + None, + ]: + client_write, transport_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] + transport_write, client_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] + + async def message_router() -> None: + try: + async for session_message in transport_read: + request = session_message.message.root + if not isinstance(request, types.JSONRPCRequest): + # Ignore e.g. types.JSONRPCNotification + continue + result: types.Result | types.ErrorData + try: + match request: + case types.JSONRPCRequest(method="initialize"): + result = self._handle_initialize( + types.InitializeRequestParams.model_validate(request.params) + ) + case types.JSONRPCRequest(method="tools/list"): + result = await self._handle_list_tools() + case types.JSONRPCRequest(method="tools/call"): + result = await self._handle_call_tool( + types.CallToolRequestParams.model_validate(request.params) + ) + case _: + result = types.ErrorData( + code=types.METHOD_NOT_FOUND, message=f"Unknown method: {request.method}" + ) + except pydantic.ValidationError as e: + result = types.ErrorData(code=types.INVALID_PARAMS, message=f"Invalid request: {e}") + + match result: + case types.Result(): + response = self._json_rpc_result_response(request, result) + case types.ErrorData(): + response = self._json_rpc_error_response(request, result) + + await transport_write.send(SessionMessage(types.JSONRPCMessage(root=response))) + + except anyio.ClosedResourceError: + pass + finally: + await transport_write.aclose() + + router_task = asyncio.create_task(message_router()) + + try: + yield client_read, client_write + finally: + await client_write.aclose() + router_task.cancel() + try: + await router_task + except asyncio.CancelledError: + pass + await transport_read.aclose() + + def _handle_initialize(self, params: types.InitializeRequestParams) -> types.InitializeResult: + # TODO: MCPService should implement this + return types.InitializeResult( + protocolVersion="2024-11-05", + capabilities=types.ServerCapabilities(tools=types.ToolsCapability()), + serverInfo=types.Implementation( + name="nexus-mcp-transport", + version="0.1.0", + ), + ) + + async def _handle_list_tools(self) -> types.ListToolsResult: + nexus_client = workflow.create_nexus_client( + endpoint=self.endpoint, + service=MCPService, + ) + tools = await nexus_client.execute_operation(MCPService.list_tools, None) + return types.ListToolsResult(tools=tools) + + async def _handle_call_tool(self, params: types.CallToolRequestParams) -> types.CallToolResult: + service, _, operation = params.name.partition("/") + nexus_client = workflow.create_nexus_client( + endpoint=self.endpoint, + service=service, + ) + result: Any = await nexus_client.execute_operation( + operation, + params.arguments or {}, + ) + if isinstance(result, dict): + return types.CallToolResult(content=[], structuredContent=result) + else: + return types.CallToolResult(content=[types.TextContent(type="text", text=str(result))]) + + def _json_rpc_error_response(self, request: types.JSONRPCRequest, error: types.ErrorData) -> types.JSONRPCResponse: + return types.JSONRPCResponse.model_validate( + { + "jsonrpc": "2.0", + "id": request.id, + "error": error.model_dump(), + } + ) + + def _json_rpc_result_response(self, request: types.JSONRPCRequest, result: types.Result) -> types.JSONRPCResponse: + return types.JSONRPCResponse.model_validate( + { + "jsonrpc": "2.0", + "id": request.id, + "result": result.model_dump(), + } + ) diff --git a/pyproject.toml b/pyproject.toml index 6adf21c..83b5a86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.13" dependencies = [ + "anyio>=4.10.0", "mcp>=1.13.0", "nexus-rpc>=1.1.0", "pydantic>=2.11.7", diff --git a/tests/test_inbound_gateway.py b/tests/test_inbound_gateway.py index 58eac94..16385e7 100644 --- a/tests/test_inbound_gateway.py +++ b/tests/test_inbound_gateway.py @@ -1,13 +1,14 @@ import asyncio import anyio -from mcp.shared.message import SessionMessage import pytest from mcp import ClientSession from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest +from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker @@ -23,7 +24,7 @@ async def test_inbound_gateway() -> None: endpoint_name = "endpoint" task_queue = "handler-queue" - async with await WorkflowEnvironment.start_local() as env: + async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env: await env.client.operator_service.create_nexus_endpoint( CreateNexusEndpointRequest( spec=EndpointSpec( diff --git a/tests/test_workflow_caller.py b/tests/test_workflow_caller.py new file mode 100644 index 0000000..0522d95 --- /dev/null +++ b/tests/test_workflow_caller.py @@ -0,0 +1,74 @@ +import uuid +from dataclasses import dataclass + +import pytest +from mcp import ClientSession +from temporalio import workflow +from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget +from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker + +from nexusmcp import WorkflowTransport + +from .service import TestServiceHandler, mcp_service + + +@dataclass +class MCPCallerWorkflowInput: + endpoint: str + + +# sandbox disabled due to use of ThreadLocal by sniffio +# TODO: make this unnecessary +@workflow.defn(sandboxed=False) +class MCPCallerWorkflow: + @workflow.run + async def run(self, input: MCPCallerWorkflowInput) -> None: + transport = WorkflowTransport(input.endpoint) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + list_tools_result = await session.list_tools() + assert len(list_tools_result.tools) == 2 + assert list_tools_result.tools[0].name == "modified-service-name/modified-op-name" + assert list_tools_result.tools[1].name == "modified-service-name/op2" + + call_result = await session.call_tool("modified-service-name/modified-op-name", {"name": "World"}) + assert call_result.structuredContent == {"message": "Hello, World"} + + +@pytest.mark.asyncio +async def test_workflow_caller() -> None: + endpoint_name = "endpoint" + task_queue = "handler-queue" + + async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env: + await env.client.operator_service.create_nexus_endpoint( + CreateNexusEndpointRequest( + spec=EndpointSpec( + name=endpoint_name, + target=EndpointTarget( + worker=EndpointTarget.Worker( + namespace=env.client.namespace, + task_queue=task_queue, + ) + ), + ) + ) + ) + + async with Worker( + env.client, + task_queue=task_queue, + workflows=[MCPCallerWorkflow], + nexus_service_handlers=[mcp_service, TestServiceHandler()], + ): + await env.client.execute_workflow( + MCPCallerWorkflow.run, + arg=MCPCallerWorkflowInput(endpoint=endpoint_name), + id=str(uuid.uuid4()), + task_queue=task_queue, + ) diff --git a/uv.lock b/uv.lock index 81783bd..96e0e4e 100644 --- a/uv.lock +++ b/uv.lock @@ -216,6 +216,7 @@ name = "nexus-mcp" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "anyio" }, { name = "mcp" }, { name = "nexus-rpc" }, { name = "pydantic" }, @@ -232,10 +233,11 @@ dev = [ [package.metadata] requires-dist = [ + { name = "anyio", specifier = ">=4.10.0" }, { name = "mcp", specifier = ">=1.13.0" }, { name = "nexus-rpc", specifier = ">=1.1.0" }, { name = "pydantic", specifier = ">=2.11.7" }, - { name = "temporalio", git = "https://github.com/temporalio/sdk-python" }, + { name = "temporalio", git = "https://github.com/temporalio/sdk-python?branch=dan-9997-python-mcp-nexus" }, ] [package.metadata.requires-dev] @@ -565,7 +567,7 @@ wheels = [ [[package]] name = "temporalio" version = "1.15.0" -source = { git = "https://github.com/temporalio/sdk-python#e1016bcdbdebe613b1123573a839062cfd2ca5a1" } +source = { git = "https://github.com/temporalio/sdk-python?branch=dan-9997-python-mcp-nexus#a0897ff914d27db6c8955e4955fadc67a5ba1dd9" } dependencies = [ { name = "nexus-rpc" }, { name = "protobuf" },