-
Notifications
You must be signed in to change notification settings - Fork 2
Nexus transport for workflow callers #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
016c50b
5e43c16
a0d71d5
ac63af2
4eca5c9
1654319
4732489
861b563
ca16e6e
c6c5964
361b7cf
5244fb7
d59c500
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| __pycache__/ |
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: rename to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Discussed offline. Renamed to |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| import asyncio | ||
| from contextlib import asynccontextmanager | ||
| from typing import Any, AsyncGenerator | ||
|
|
||
| import anyio | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I realized
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added it to dependencies |
||
| import mcp.types as types | ||
| import pydantic | ||
| from mcp.shared.message import SessionMessage | ||
| from temporalio import workflow | ||
|
|
||
| from .service import MCPService | ||
|
|
||
|
|
||
| class WorkflowNexusTransport: | ||
| """ | ||
| Nexus MCP Transport for use in Temporal workflows. | ||
| This class provides a transport that proxies MCP requests to a Temporal Nexus service. It can be | ||
| used to make MCP calls via `mcp.ClientSession` from Temporal workflow code. | ||
| Example: | ||
| ```python | ||
| async with WorkflowNexusTransport("my-endpoint") as (read_stream, write_stream): | ||
| async with ClientSession(read_stream, write_stream) as session: | ||
| await session.initialize() | ||
| await session.list_tools() | ||
| await session.call_tool("my-service/my-operation", {"arg": "value"}) | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| endpoint: str, | ||
| ): | ||
| self.endpoint = endpoint | ||
|
|
||
| @asynccontextmanager | ||
| async def connect( | ||
| self, | ||
| ) -> AsyncGenerator[ | ||
| tuple[ | ||
| anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage], | ||
| anyio.streams.memory.MemoryObjectSendStream[SessionMessage], | ||
| ], | ||
| None, | ||
| ]: | ||
| client_write, transport_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] | ||
| transport_write, client_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated] | ||
|
|
||
| async def message_router() -> None: | ||
| try: | ||
| async for session_message in transport_read: | ||
| request = session_message.message.root | ||
| if not isinstance(request, types.JSONRPCRequest): | ||
| continue | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this intentionally ignored? Can you add a comment saying why?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added a comment saying that we ignore e.g. In practice, the test currently gets |
||
| result: types.Result | types.ErrorData | ||
| try: | ||
| match request: | ||
| case types.JSONRPCRequest(method="initialize"): | ||
| result = self._handle_initialize( | ||
| types.InitializeRequestParams.model_validate(request.params) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting, I would have expected this to already have come validated from the client. I wonder if there's something to learn here for the Nexus client / transport separation.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right, the client does construct those types explicitly, e.g. https://github.com/modelcontextprotocol/python-sdk/blob/d1ac8d68eb2d7ed139bdc2608b8b4e2ec4265be5/src/mcp/client/session.py#L293-L302 We could cast here instead of validating again, but I think I prefer validating again, e.g. so that our transport can be used safely in other contexts. |
||
| ) | ||
| case types.JSONRPCRequest(method="tools/list"): | ||
| result = await self._handle_list_tools() | ||
| case types.JSONRPCRequest(method="tools/call"): | ||
| result = await self._handle_call_tool( | ||
| types.CallToolRequestParams.model_validate(request.params) | ||
| ) | ||
| case _: | ||
| result = types.ErrorData( | ||
| code=types.METHOD_NOT_FOUND, message=f"Unknown method: {request.method}" | ||
| ) | ||
| except pydantic.ValidationError as e: | ||
| result = types.ErrorData(code=types.INVALID_PARAMS, message=f"Invalid request: {e}") | ||
|
|
||
| match result: | ||
| case types.Result(): | ||
| response = self._json_rpc_result_response(request, result) | ||
| case types.ErrorData(): | ||
| response = self._json_rpc_error_response(request, result) | ||
|
|
||
| await transport_write.send(SessionMessage(types.JSONRPCMessage(root=response))) | ||
|
|
||
| except anyio.ClosedResourceError: | ||
| pass | ||
| finally: | ||
| await transport_write.aclose() | ||
|
|
||
| router_task = asyncio.create_task(message_router()) | ||
|
|
||
| try: | ||
| yield client_read, client_write | ||
| finally: | ||
| await client_write.aclose() | ||
| router_task.cancel() | ||
| try: | ||
| await router_task | ||
| except asyncio.CancelledError: | ||
| pass | ||
| await transport_read.aclose() | ||
|
|
||
| def _handle_initialize(self, params: types.InitializeRequestParams) -> types.InitializeResult: | ||
| # TODO: MCPService should implement this | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Want to open an issue? Or just add it, it should be easy..
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Opened #2 |
||
| return types.InitializeResult( | ||
| protocolVersion="2024-11-05", | ||
| capabilities=types.ServerCapabilities(tools=types.ToolsCapability()), | ||
| serverInfo=types.Implementation( | ||
| name="nexus-mcp-transport", | ||
| version="0.1.0", | ||
| ), | ||
| ) | ||
|
|
||
| async def _handle_list_tools(self) -> types.ListToolsResult: | ||
| nexus_client = workflow.create_nexus_client( | ||
| endpoint=self.endpoint, | ||
| service=MCPService, | ||
| ) | ||
| tools = await nexus_client.execute_operation(MCPService.list_tools, None) | ||
| return types.ListToolsResult(tools=tools) | ||
|
|
||
| async def _handle_call_tool(self, params: types.CallToolRequestParams) -> types.CallToolResult: | ||
| service, _, operation = params.name.partition("/") | ||
| nexus_client = workflow.create_nexus_client( | ||
| endpoint=self.endpoint, | ||
| service=service, | ||
| ) | ||
| result: Any = await nexus_client.execute_operation( | ||
| operation, | ||
| params.arguments or {}, | ||
| ) | ||
| if isinstance(result, dict): | ||
| return types.CallToolResult(content=[], structuredContent=result) | ||
| else: | ||
| return types.CallToolResult(content=[types.TextContent(type="text", text=str(result))]) | ||
|
|
||
| def _json_rpc_error_response(self, request: types.JSONRPCRequest, error: types.ErrorData) -> types.JSONRPCResponse: | ||
| return types.JSONRPCResponse.model_validate( | ||
| { | ||
| "jsonrpc": "2.0", | ||
| "id": request.id, | ||
| "error": error.model_dump(), | ||
| } | ||
| ) | ||
|
|
||
| def _json_rpc_result_response(self, request: types.JSONRPCRequest, result: types.Result) -> types.JSONRPCResponse: | ||
| return types.JSONRPCResponse.model_validate( | ||
| { | ||
| "jsonrpc": "2.0", | ||
| "id": request.id, | ||
| "result": result.model_dump(), | ||
| } | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,6 @@ | |
|
|
||
| class ToolListInput(BaseModel): | ||
| endpoint: str | ||
| pass | ||
|
|
||
|
|
||
| class ToolCallInput(BaseModel): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,4 +41,4 @@ python_version = "3.13" | |
| strict = true | ||
|
|
||
| [tool.uv.sources] | ||
| temporalio = { git = "https://github.com/temporalio/sdk-python" } | ||
| temporalio = { git = "https://github.com/temporalio/sdk-python", branch = "dan-9997-python-mcp-nexus" } | ||
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| import uuid | ||
| from dataclasses import dataclass | ||
|
|
||
| import pytest | ||
| from mcp import ClientSession | ||
| from temporalio import workflow | ||
| from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget | ||
| from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest | ||
| from temporalio.contrib.pydantic import pydantic_data_converter | ||
| from temporalio.testing import WorkflowEnvironment | ||
| from temporalio.worker import Worker | ||
|
|
||
| from nexusmcp.nexus_transport import WorkflowNexusTransport | ||
|
|
||
| from .service import TestServiceHandler, mcp_service | ||
|
|
||
|
|
||
| @dataclass | ||
| class MCPCallerWorkflowInput: | ||
| endpoint: str | ||
|
|
||
|
|
||
| @workflow.defn(sandboxed=False) | ||
| class MCPCallerWorkflow: | ||
| @workflow.run | ||
| async def run(self, input: MCPCallerWorkflowInput) -> None: | ||
| transport = WorkflowNexusTransport(input.endpoint) | ||
| async with transport.connect() as (read_stream, write_stream): | ||
| async with ClientSession(read_stream, write_stream) as session: | ||
| await session.initialize() | ||
|
|
||
| list_tools_result = await session.list_tools() | ||
| assert len(list_tools_result.tools) == 2 | ||
| assert list_tools_result.tools[0].name == "modified-service-name/modified-op-name" | ||
| assert list_tools_result.tools[1].name == "modified-service-name/op2" | ||
|
|
||
| call_result = await session.call_tool("modified-service-name/modified-op-name", {"name": "World"}) | ||
| assert call_result.structuredContent == {"message": "Hello, World"} | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_workflow_caller() -> None: | ||
| endpoint_name = "endpoint" | ||
| task_queue = "handler-queue" | ||
|
|
||
| async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env: | ||
| await env.client.operator_service.create_nexus_endpoint( | ||
| CreateNexusEndpointRequest( | ||
| spec=EndpointSpec( | ||
| name=endpoint_name, | ||
| target=EndpointTarget( | ||
| worker=EndpointTarget.Worker( | ||
| namespace=env.client.namespace, | ||
| task_queue=task_queue, | ||
| ) | ||
| ), | ||
| ) | ||
| ) | ||
| ) | ||
|
|
||
| async with Worker( | ||
| env.client, | ||
| task_queue=task_queue, | ||
| workflows=[MCPCallerWorkflow], | ||
| nexus_service_handlers=[mcp_service, TestServiceHandler()], | ||
| ): | ||
| await env.client.execute_workflow( | ||
| MCPCallerWorkflow.run, | ||
| arg=MCPCallerWorkflowInput(endpoint=endpoint_name), | ||
| id=str(uuid.uuid4()), | ||
| task_queue=task_queue, | ||
| ) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.