Skip to content

Commit 016c50b

Browse files
committed
Workflow caller
1 parent 5e8d80f commit 016c50b

File tree

7 files changed

+214
-6
lines changed

7 files changed

+214
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__/

nexusmcp/inbound_gateway.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class InboundGateway:
2929
3030
This gateway acts as an adapter between the Model Context Protocol (MCP) server and Temporal Nexus Operations,
3131
enabling tool calls to be executed reliably through Temporal's workflow engine. It handles both tool listing and
32-
tool call requests by delegating them to corresponding a set of Temporal Nexus Services in a given endpoint.
32+
tool call requests by delegating them to a corresponding set of Temporal Nexus Services in a given endpoint.
3333
"""
3434

3535
_client: Client
@@ -93,7 +93,7 @@ async def _handle_call_tool(self, name: str, arguments: dict[str, Any]) -> Any:
9393
{"a": 5, "b": 3}
9494
)
9595
"""
96-
service, operation = name.split("/", maxsplit=1)
96+
service, _, operation = name.partition("/")
9797
if not service or not operation:
9898
raise ValueError(f"Invalid tool name: {name}, must be in the format 'service/operation'")
9999
return await self._client.execute_workflow(

nexusmcp/nexus_transport.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import asyncio
2+
from contextlib import asynccontextmanager
3+
from typing import Any, AsyncGenerator
4+
5+
import anyio
6+
import mcp.types as types
7+
import pydantic
8+
from mcp.shared.message import SessionMessage
9+
from temporalio import workflow
10+
11+
from .service import MCPService
12+
13+
14+
class NexusTransport:
15+
def __init__(
16+
self,
17+
endpoint: str,
18+
):
19+
self.endpoint = endpoint
20+
21+
@asynccontextmanager
22+
async def connect(
23+
self,
24+
) -> AsyncGenerator[
25+
tuple[
26+
anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage],
27+
anyio.streams.memory.MemoryObjectSendStream[SessionMessage],
28+
],
29+
None,
30+
]:
31+
client_write, transport_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated]
32+
transport_write, client_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated]
33+
34+
async def message_router() -> None:
35+
try:
36+
async for session_message in transport_read:
37+
request = session_message.message.root
38+
if not isinstance(request, types.JSONRPCRequest):
39+
continue
40+
result: types.Result | types.ErrorData
41+
try:
42+
match request:
43+
case types.JSONRPCRequest(method="initialize"):
44+
result = self._handle_initialize(
45+
types.InitializeRequestParams.model_validate(request.params)
46+
)
47+
case types.JSONRPCRequest(method="tools/list"):
48+
result = await self._handle_list_tools()
49+
case types.JSONRPCRequest(method="tools/call"):
50+
result = await self._handle_call_tool(
51+
types.CallToolRequestParams.model_validate(request.params)
52+
)
53+
case _:
54+
result = types.ErrorData(
55+
code=types.METHOD_NOT_FOUND, message=f"Unknown method: {request.method}"
56+
)
57+
except pydantic.ValidationError as e:
58+
result = types.ErrorData(code=types.INVALID_PARAMS, message=f"Invalid request: {e}")
59+
60+
match result:
61+
case types.Result():
62+
response = self._json_rpc_result_response(request, result)
63+
case types.ErrorData():
64+
response = self._json_rpc_error_response(request, result)
65+
66+
await transport_write.send(SessionMessage(types.JSONRPCMessage(root=response)))
67+
68+
except anyio.ClosedResourceError:
69+
pass
70+
finally:
71+
await transport_write.aclose()
72+
73+
router_task = asyncio.create_task(message_router())
74+
75+
try:
76+
yield client_read, client_write
77+
finally:
78+
await client_write.aclose()
79+
router_task.cancel()
80+
try:
81+
await router_task
82+
except asyncio.CancelledError:
83+
pass
84+
await transport_read.aclose()
85+
86+
def _handle_initialize(self, params: types.InitializeRequestParams) -> types.InitializeResult:
87+
# TODO: MCPService should implement this
88+
return types.InitializeResult(
89+
protocolVersion="2024-11-05",
90+
capabilities=types.ServerCapabilities(tools=types.ToolsCapability()),
91+
serverInfo=types.Implementation(
92+
name="nexus-mcp-transport",
93+
version="0.1.0",
94+
),
95+
)
96+
97+
async def _handle_list_tools(self) -> types.ListToolsResult:
98+
nexus_client = workflow.create_nexus_client(
99+
endpoint=self.endpoint,
100+
service=MCPService,
101+
)
102+
tools = await nexus_client.execute_operation(MCPService.list_tools, None)
103+
return types.ListToolsResult(tools=tools)
104+
105+
async def _handle_call_tool(self, params: types.CallToolRequestParams) -> types.CallToolResult:
106+
service, _, operation = params.name.partition("/")
107+
nexus_client = workflow.create_nexus_client(
108+
endpoint=self.endpoint,
109+
service=service,
110+
)
111+
result: Any = await nexus_client.execute_operation(
112+
operation,
113+
params.arguments or {},
114+
)
115+
if isinstance(result, dict):
116+
return types.CallToolResult(content=[], structuredContent=result)
117+
else:
118+
return types.CallToolResult(content=[types.TextContent(type="text", text=str(result))])
119+
120+
def _json_rpc_error_response(self, request: types.JSONRPCRequest, error: types.ErrorData) -> types.JSONRPCResponse:
121+
return types.JSONRPCResponse.model_validate(
122+
{
123+
"jsonrpc": "2.0",
124+
"id": request.id,
125+
"error": error.model_dump(),
126+
}
127+
)
128+
129+
def _json_rpc_result_response(self, request: types.JSONRPCRequest, result: types.Result) -> types.JSONRPCResponse:
130+
return types.JSONRPCResponse.model_validate(
131+
{
132+
"jsonrpc": "2.0",
133+
"id": request.id,
134+
"result": result.model_dump(),
135+
}
136+
)

nexusmcp/proxy_workflow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
class ToolListInput(BaseModel):
1414
endpoint: str
15-
pass
1615

1716

1817
class ToolCallInput(BaseModel):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,4 @@ python_version = "3.13"
4141
strict = true
4242

4343
[tool.uv.sources]
44-
temporalio = { git = "https://github.com/temporalio/sdk-python" }
44+
temporalio = { git = "https://github.com/temporalio/sdk-python", branch = "dan-9997-python-mcp-nexus" }

tests/test_workflow_caller.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import uuid
2+
from dataclasses import dataclass
3+
4+
import pytest
5+
from mcp import ClientSession
6+
from temporalio import workflow
7+
from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget
8+
from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest
9+
from temporalio.contrib.pydantic import pydantic_data_converter
10+
from temporalio.testing import WorkflowEnvironment
11+
from temporalio.worker import Worker
12+
13+
from nexusmcp.nexus_transport import NexusTransport
14+
15+
from .service import TestServiceHandler, mcp_service
16+
17+
18+
@dataclass
19+
class MCPCallerWorkflowInput:
20+
endpoint: str
21+
22+
23+
@workflow.defn(sandboxed=False)
24+
class MCPCallerWorkflow:
25+
@workflow.run
26+
async def run(self, input: MCPCallerWorkflowInput) -> None:
27+
transport = NexusTransport(input.endpoint)
28+
async with transport.connect() as (read_stream, write_stream):
29+
async with ClientSession(read_stream, write_stream) as session:
30+
await session.initialize()
31+
32+
list_tools_result = await session.list_tools()
33+
assert len(list_tools_result.tools) == 2
34+
assert list_tools_result.tools[0].name == "modified-service-name/modified-op-name"
35+
assert list_tools_result.tools[1].name == "modified-service-name/op2"
36+
37+
call_result = await session.call_tool("modified-service-name/modified-op-name", {"name": "World"})
38+
assert call_result.structuredContent == {"message": "Hello, World"}
39+
40+
41+
@pytest.mark.asyncio
42+
async def test_workflow_caller() -> None:
43+
endpoint_name = "endpoint"
44+
task_queue = "handler-queue"
45+
46+
async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env:
47+
await env.client.operator_service.create_nexus_endpoint(
48+
CreateNexusEndpointRequest(
49+
spec=EndpointSpec(
50+
name=endpoint_name,
51+
target=EndpointTarget(
52+
worker=EndpointTarget.Worker(
53+
namespace=env.client.namespace,
54+
task_queue=task_queue,
55+
)
56+
),
57+
)
58+
)
59+
)
60+
61+
async with Worker(
62+
env.client,
63+
task_queue=task_queue,
64+
workflows=[MCPCallerWorkflow],
65+
nexus_service_handlers=[mcp_service, TestServiceHandler()],
66+
):
67+
await env.client.execute_workflow(
68+
MCPCallerWorkflow.run,
69+
arg=MCPCallerWorkflowInput(endpoint=endpoint_name),
70+
id=str(uuid.uuid4()),
71+
task_queue=task_queue,
72+
)

uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)