-
Notifications
You must be signed in to change notification settings - Fork 2
Nexus transport for workflow callers #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
016c50b
5e43c16
a0d71d5
ac63af2
4eca5c9
1654319
4732489
861b563
ca16e6e
c6c5964
361b7cf
5244fb7
d59c500
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| __pycache__/ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,11 +99,12 @@ class CalculatorHandler: | |
| return CalculateResponse(result=result) | ||
| ``` | ||
|
|
||
| ### 4. Create a Nexus endpoint | ||
| ### 4. Create a Nexus namespace and endpoint | ||
|
|
||
| **Local dev or self hosted deployment** | ||
|
|
||
| ```bash | ||
| temporal operator namespace create --namespace my-handler-namespace | ||
| temporal operator nexus endpoint create \ | ||
| --name mcp-gateway \ | ||
| --target-namespace my-handler-namespace \ | ||
|
|
@@ -117,6 +118,8 @@ temporal operator nexus endpoint create \ | |
| ### 3. Set Up the Temporal Worker with the Nexus handlers at `worker.py` | ||
|
|
||
| ```python | ||
| import asyncio | ||
|
|
||
| from temporalio.client import Client | ||
| from temporalio.worker import Worker | ||
| from .service_handler import mcp_service_handler, CalculatorHandler | ||
|
|
@@ -163,7 +166,7 @@ if __name__ == "__main__": | |
| asyncio.run(main()) | ||
| ``` | ||
|
|
||
| ### 4. Configure Your MCP Client | ||
| ### 5. Configure Your MCP Client | ||
|
|
||
| Add to your MCP client configuration (e.g., Claude Desktop): | ||
|
|
||
|
|
@@ -178,6 +181,57 @@ Add to your MCP client configuration (e.g., Claude Desktop): | |
| } | ||
| ``` | ||
|
|
||
| ### 6. Make MCP calls from a Temporal workflow | ||
dandavison marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| ```python | ||
| import asyncio | ||
| import uuid | ||
|
|
||
| from mcp import ClientSession | ||
| from nexusmcp import WorkflowTransport | ||
| from pydantic import BaseModel | ||
| from temporalio import workflow | ||
| from temporalio.client import Client | ||
| from temporalio.contrib.pydantic import pydantic_data_converter | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I forgot to add that on the handler side, want to do that before merging? Otherwise, I can take that on.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure -- I've added it to all the client connections in the README and tests. |
||
| from temporalio.worker import Worker | ||
|
|
||
|
|
||
| class AgentWorkflowInput(BaseModel): | ||
| endpoint: str | ||
|
|
||
|
|
||
| # The workflow must have the sandbox disabled | ||
| @workflow.defn(sandboxed=False) | ||
| class AgentWorkflow: | ||
| @workflow.run | ||
| async def run(self, input: AgentWorkflowInput): | ||
| transport = WorkflowTransport(input.endpoint) | ||
| async with transport.connect() as (read_stream, write_stream): | ||
| async with ClientSession(read_stream, write_stream) as session: | ||
| await session.initialize() | ||
| list_tools_result = await session.list_tools() | ||
| print(f"available tools: {list_tools_result}") | ||
|
|
||
|
|
||
| async def main(): | ||
| client = await Client.connect( | ||
| "localhost:7233", | ||
| data_converter=pydantic_data_converter, | ||
| ) | ||
|
|
||
| async with Worker( | ||
| client, | ||
| task_queue="agent-workflow", | ||
| workflows=[AgentWorkflow], | ||
| ) as worker: | ||
| await client.execute_workflow( | ||
| AgentWorkflow.run, | ||
| AgentWorkflowInput(endpoint="mcp-gateway"), | ||
| id=str(uuid.uuid4()), | ||
| task_queue=worker.task_queue, | ||
| ) | ||
| ``` | ||
|
|
||
| ## Usage Examples | ||
|
|
||
| ### Tool Filtering | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,6 @@ | |
|
|
||
| class ToolListInput(BaseModel): | ||
| endpoint: str | ||
| pass | ||
|
|
||
|
|
||
| class ToolCallInput(BaseModel): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| import asyncio | ||
| from contextlib import asynccontextmanager | ||
| from typing import Any, AsyncGenerator | ||
|
|
||
| import anyio | ||
| import mcp.types as types | ||
| import pydantic | ||
| from mcp.shared.message import SessionMessage | ||
| from temporalio import workflow | ||
|
|
||
| from .service import MCPService | ||
|
|
||
|
|
||
| class WorkflowTransport: | ||
| """ | ||
| An MCP Transport for use in Temporal workflows. | ||
|
|
||
| This class provides a transport that proxies MCP requests from a Temporal Workflow to a Temporal | ||
| Nexus service. It can be used to make MCP calls via `mcp.ClientSession` from Temporal workflow | ||
| code. | ||
|
|
||
| Example: | ||
| ```python async with WorkflowNexusTransport("my-endpoint") as (read_stream, write_stream): | ||
| async with ClientSession(read_stream, write_stream) as session: | ||
| await session.initialize() await session.list_tools() await | ||
| session.call_tool("my-service/my-operation", {"arg": "value"}) | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| endpoint: str, | ||
| ): | ||
| self.endpoint = endpoint | ||
|
|
||
| @asynccontextmanager | ||
| async def connect( | ||
| self, | ||
| ) -> AsyncGenerator[ | ||
| tuple[ | ||
| anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage], | ||
| anyio.streams.memory.MemoryObjectSendStream[SessionMessage], | ||
| ], | ||
| None, | ||
| ]: | ||
| client_write, transport_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] | ||
| transport_write, client_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] | ||
|
|
||
| async def message_router() -> None: | ||
| try: | ||
| async for session_message in transport_read: | ||
| request = session_message.message.root | ||
| if not isinstance(request, types.JSONRPCRequest): | ||
| # Ignore e.g. types.JSONRPCNotification | ||
| continue | ||
| result: types.Result | types.ErrorData | ||
| try: | ||
| match request: | ||
| case types.JSONRPCRequest(method="initialize"): | ||
| result = self._handle_initialize( | ||
| types.InitializeRequestParams.model_validate(request.params) | ||
| ) | ||
| case types.JSONRPCRequest(method="tools/list"): | ||
| result = await self._handle_list_tools() | ||
| case types.JSONRPCRequest(method="tools/call"): | ||
| result = await self._handle_call_tool( | ||
| types.CallToolRequestParams.model_validate(request.params) | ||
| ) | ||
| case _: | ||
| result = types.ErrorData( | ||
| code=types.METHOD_NOT_FOUND, message=f"Unknown method: {request.method}" | ||
| ) | ||
| except pydantic.ValidationError as e: | ||
| result = types.ErrorData(code=types.INVALID_PARAMS, message=f"Invalid request: {e}") | ||
|
|
||
| match result: | ||
| case types.Result(): | ||
| response = self._json_rpc_result_response(request, result) | ||
| case types.ErrorData(): | ||
| response = self._json_rpc_error_response(request, result) | ||
|
|
||
| await transport_write.send(SessionMessage(types.JSONRPCMessage(root=response))) | ||
|
|
||
| except anyio.ClosedResourceError: | ||
| pass | ||
| finally: | ||
| await transport_write.aclose() | ||
|
|
||
| router_task = asyncio.create_task(message_router()) | ||
|
|
||
| try: | ||
| yield client_read, client_write | ||
| finally: | ||
| await client_write.aclose() | ||
| router_task.cancel() | ||
| try: | ||
| await router_task | ||
| except asyncio.CancelledError: | ||
| pass | ||
| await transport_read.aclose() | ||
|
|
||
| def _handle_initialize(self, params: types.InitializeRequestParams) -> types.InitializeResult: | ||
| # TODO: MCPService should implement this | ||
| return types.InitializeResult( | ||
| protocolVersion="2024-11-05", | ||
| capabilities=types.ServerCapabilities(tools=types.ToolsCapability()), | ||
| serverInfo=types.Implementation( | ||
| name="nexus-mcp-transport", | ||
| version="0.1.0", | ||
| ), | ||
| ) | ||
|
|
||
| async def _handle_list_tools(self) -> types.ListToolsResult: | ||
| nexus_client = workflow.create_nexus_client( | ||
| endpoint=self.endpoint, | ||
| service=MCPService, | ||
| ) | ||
| tools = await nexus_client.execute_operation(MCPService.list_tools, None) | ||
| return types.ListToolsResult(tools=tools) | ||
|
|
||
| async def _handle_call_tool(self, params: types.CallToolRequestParams) -> types.CallToolResult: | ||
| service, _, operation = params.name.partition("/") | ||
| nexus_client = workflow.create_nexus_client( | ||
| endpoint=self.endpoint, | ||
| service=service, | ||
| ) | ||
| result: Any = await nexus_client.execute_operation( | ||
| operation, | ||
| params.arguments or {}, | ||
| ) | ||
| if isinstance(result, dict): | ||
| return types.CallToolResult(content=[], structuredContent=result) | ||
| else: | ||
| return types.CallToolResult(content=[types.TextContent(type="text", text=str(result))]) | ||
|
|
||
| def _json_rpc_error_response(self, request: types.JSONRPCRequest, error: types.ErrorData) -> types.JSONRPCResponse: | ||
| return types.JSONRPCResponse.model_validate( | ||
| { | ||
| "jsonrpc": "2.0", | ||
| "id": request.id, | ||
| "error": error.model_dump(), | ||
| } | ||
| ) | ||
|
|
||
| def _json_rpc_result_response(self, request: types.JSONRPCRequest, result: types.Result) -> types.JSONRPCResponse: | ||
| return types.JSONRPCResponse.model_validate( | ||
| { | ||
| "jsonrpc": "2.0", | ||
| "id": request.id, | ||
| "result": result.model_dump(), | ||
| } | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ description = "Add your description here" | |
| readme = "README.md" | ||
| requires-python = ">=3.13" | ||
| dependencies = [ | ||
| "anyio>=4.10.0", | ||
| "mcp>=1.13.0", | ||
| "nexus-rpc>=1.1.0", | ||
| "pydantic>=2.11.7", | ||
|
|
@@ -41,4 +42,4 @@ python_version = "3.13" | |
| strict = true | ||
|
|
||
| [tool.uv.sources] | ||
| temporalio = { git = "https://github.com/temporalio/sdk-python" } | ||
| temporalio = { git = "https://github.com/temporalio/sdk-python", branch = "dan-9997-python-mcp-nexus" } | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| import uuid | ||
| from dataclasses import dataclass | ||
|
|
||
| import pytest | ||
| from mcp import ClientSession | ||
| from temporalio import workflow | ||
| from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget | ||
| from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest | ||
| from temporalio.contrib.pydantic import pydantic_data_converter | ||
| from temporalio.testing import WorkflowEnvironment | ||
| from temporalio.worker import Worker | ||
|
|
||
| from nexusmcp import WorkflowTransport | ||
|
|
||
| from .service import TestServiceHandler, mcp_service | ||
|
|
||
|
|
||
| @dataclass | ||
| class MCPCallerWorkflowInput: | ||
| endpoint: str | ||
|
|
||
|
|
||
| # sandbox disabled due to use of ThreadLocal by sniffio | ||
| # TODO: make this unnecessary | ||
| @workflow.defn(sandboxed=False) | ||
| class MCPCallerWorkflow: | ||
| @workflow.run | ||
| async def run(self, input: MCPCallerWorkflowInput) -> None: | ||
| transport = WorkflowTransport(input.endpoint) | ||
| async with transport.connect() as (read_stream, write_stream): | ||
| async with ClientSession(read_stream, write_stream) as session: | ||
| await session.initialize() | ||
|
|
||
| list_tools_result = await session.list_tools() | ||
| assert len(list_tools_result.tools) == 2 | ||
| assert list_tools_result.tools[0].name == "modified-service-name/modified-op-name" | ||
| assert list_tools_result.tools[1].name == "modified-service-name/op2" | ||
|
|
||
| call_result = await session.call_tool("modified-service-name/modified-op-name", {"name": "World"}) | ||
| assert call_result.structuredContent == {"message": "Hello, World"} | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_workflow_caller() -> None: | ||
| endpoint_name = "endpoint" | ||
| task_queue = "handler-queue" | ||
|
|
||
| async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env: | ||
| await env.client.operator_service.create_nexus_endpoint( | ||
| CreateNexusEndpointRequest( | ||
| spec=EndpointSpec( | ||
| name=endpoint_name, | ||
| target=EndpointTarget( | ||
| worker=EndpointTarget.Worker( | ||
| namespace=env.client.namespace, | ||
| task_queue=task_queue, | ||
| ) | ||
| ), | ||
| ) | ||
| ) | ||
| ) | ||
|
|
||
| async with Worker( | ||
| env.client, | ||
| task_queue=task_queue, | ||
| workflows=[MCPCallerWorkflow], | ||
| nexus_service_handlers=[mcp_service, TestServiceHandler()], | ||
| ): | ||
| await env.client.execute_workflow( | ||
| MCPCallerWorkflow.run, | ||
| arg=MCPCallerWorkflowInput(endpoint=endpoint_name), | ||
| id=str(uuid.uuid4()), | ||
| task_queue=task_queue, | ||
| ) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.