Skip to content

Commit 50477fd

Browse files
authored
Nexus transport for workflow callers (#1)
This PR makes it possible to make MCP calls from a Temporal workflow to a Nexus MCP server, via a custom transport for the standard MCP SDK client. The typical use case is an AI agent implemented as a Temporal workflow that wants to make use of tools that are themselves backed by Temporal workflows.
1 parent 5e8d80f commit 50477fd

File tree

10 files changed

+312
-12
lines changed

10 files changed

+312
-12
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
__pycache__/
2+
*.egg-info/

README.md

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,12 @@ class CalculatorHandler:
9999
return CalculateResponse(result=result)
100100
```
101101

102-
### 4. Create a Nexus endpoint
102+
### 4. Create a namespace and Nexus endpoint
103103

104104
**Local dev or self hosted deployment**
105105

106106
```bash
107+
temporal operator namespace create --namespace my-handler-namespace
107108
temporal operator nexus endpoint create \
108109
--name mcp-gateway \
109110
--target-namespace my-handler-namespace \
@@ -117,13 +118,20 @@ temporal operator nexus endpoint create \
117118
### 3. Set Up the Temporal Worker with the Nexus handlers at `worker.py`
118119

119120
```python
121+
import asyncio
122+
120123
from temporalio.client import Client
124+
from temporalio.contrib.pydantic import pydantic_data_converter
121125
from temporalio.worker import Worker
122126
from .service_handler import mcp_service_handler, CalculatorHandler
123127

124128
async def main():
125129
# Connect to Temporal (replace host and namespace as needed).
126-
client = await Client.connect("localhost:7233", namespace="my-handler-namespace")
130+
client = await Client.connect(
131+
"localhost:7233",
132+
namespace="my-handler-namespace",
133+
data_converter=pydantic_data_converter,
134+
)
127135

128136
async with Worker(
129137
client,
@@ -136,17 +144,26 @@ async def main():
136144

137145
### 4. Set Up the MCP Gateway
138146

147+
```bash
148+
temporal operator namespace create --namespace my-caller-namespace
149+
```
150+
139151
```python
140152
import asyncio
141153
from mcp.server.lowlevel import NotificationOptions, Server
142154
from mcp.server.models import InitializationOptions
143155
from temporalio.client import Client
156+
from temporalio.contrib.pydantic import pydantic_data_converter
144157
from nexusmcp import InboundGateway
145158

146159
async def main():
147160
server = Server("nexus-mcp-demo")
148161
# Connect to Temporal (replace host and namespace as needed).
149-
client = await Client.connect("localhost:7233", namespace="my-caller-namespace")
162+
client = await Client.connect(
163+
"localhost:7233",
164+
namespace="my-caller-namespace",
165+
data_converter=pydantic_data_converter,
166+
)
150167

151168
# Create the MCP gateway
152169
gateway = InboundGateway(
@@ -163,7 +180,7 @@ if __name__ == "__main__":
163180
asyncio.run(main())
164181
```
165182

166-
### 4. Configure Your MCP Client
183+
### 5. Configure Your MCP Client
167184

168185
Add to your MCP client configuration (e.g., Claude Desktop):
169186

@@ -178,6 +195,57 @@ Add to your MCP client configuration (e.g., Claude Desktop):
178195
}
179196
```
180197

198+
### 6. Make MCP calls from a Temporal Workflow
199+
200+
```python
201+
import asyncio
202+
import uuid
203+
204+
from mcp import ClientSession
205+
from nexusmcp import WorkflowTransport
206+
from pydantic import BaseModel
207+
from temporalio import workflow
208+
from temporalio.client import Client
209+
from temporalio.contrib.pydantic import pydantic_data_converter
210+
from temporalio.worker import Worker
211+
212+
213+
class AgentWorkflowInput(BaseModel):
214+
endpoint: str
215+
216+
217+
# The workflow must have the sandbox disabled
218+
@workflow.defn(sandboxed=False)
219+
class AgentWorkflow:
220+
@workflow.run
221+
async def run(self, input: AgentWorkflowInput):
222+
transport = WorkflowTransport(input.endpoint)
223+
async with transport.connect() as (read_stream, write_stream):
224+
async with ClientSession(read_stream, write_stream) as session:
225+
await session.initialize()
226+
list_tools_result = await session.list_tools()
227+
print(f"available tools: {list_tools_result}")
228+
229+
230+
async def main():
231+
client = await Client.connect(
232+
"localhost:7233",
233+
data_converter=pydantic_data_converter,
234+
)
235+
236+
async with Worker(
237+
client,
238+
task_queue="agent-workflow",
239+
workflows=[AgentWorkflow],
240+
) as worker:
241+
await client.execute_workflow(
242+
AgentWorkflow.run,
243+
AgentWorkflowInput(endpoint="mcp-gateway"),
244+
id=str(uuid.uuid4()),
245+
task_queue=worker.task_queue,
246+
)
247+
```
248+
181249
## Usage Examples
182250

183251
### Tool Filtering

nexusmcp/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44
from .inbound_gateway import InboundGateway
55
from .service import MCPService
66
from .service_handler import MCPServiceHandler, exclude
7+
from .workflow_transport import WorkflowTransport
78

8-
__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude"]
9+
__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude", "WorkflowTransport"]

nexusmcp/inbound_gateway.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class InboundGateway:
2929
3030
This gateway acts as an adapter between the Model Context Protocol (MCP) server and Temporal Nexus Operations,
3131
enabling tool calls to be executed reliably through Temporal's workflow engine. It handles both tool listing and
32-
tool call requests by delegating them to corresponding a set of Temporal Nexus Services in a given endpoint.
32+
tool call requests by delegating them to a corresponding set of Temporal Nexus Services in a given endpoint.
3333
"""
3434

3535
_client: Client
@@ -93,7 +93,7 @@ async def _handle_call_tool(self, name: str, arguments: dict[str, Any]) -> Any:
9393
{"a": 5, "b": 3}
9494
)
9595
"""
96-
service, operation = name.split("/", maxsplit=1)
96+
service, _, operation = name.partition("/")
9797
if not service or not operation:
9898
raise ValueError(f"Invalid tool name: {name}, must be in the format 'service/operation'")
9999
return await self._client.execute_workflow(

nexusmcp/proxy_workflow.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
class ToolListInput(BaseModel):
1414
endpoint: str
15-
pass
1615

1716

1817
class ToolCallInput(BaseModel):

nexusmcp/workflow_transport.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import asyncio
2+
from contextlib import asynccontextmanager
3+
from typing import Any, AsyncGenerator
4+
5+
import anyio
6+
import mcp.types as types
7+
import pydantic
8+
from mcp.shared.message import SessionMessage
9+
from temporalio import workflow
10+
11+
from .service import MCPService
12+
13+
14+
class WorkflowTransport:
15+
"""
16+
An MCP Transport for use in Temporal workflows.
17+
18+
This class provides a transport that proxies MCP requests from a Temporal Workflow to a Temporal
19+
Nexus service. It can be used to make MCP calls via `mcp.ClientSession` from Temporal workflow
20+
code.
21+
22+
Example:
23+
```python async with WorkflowNexusTransport("my-endpoint") as (read_stream, write_stream):
24+
async with ClientSession(read_stream, write_stream) as session:
25+
await session.initialize() await session.list_tools() await
26+
session.call_tool("my-service/my-operation", {"arg": "value"})
27+
```
28+
"""
29+
30+
def __init__(
31+
self,
32+
endpoint: str,
33+
):
34+
self.endpoint = endpoint
35+
36+
@asynccontextmanager
37+
async def connect(
38+
self,
39+
) -> AsyncGenerator[
40+
tuple[
41+
anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage],
42+
anyio.streams.memory.MemoryObjectSendStream[SessionMessage],
43+
],
44+
None,
45+
]:
46+
client_write, transport_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated]
47+
transport_write, client_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated]
48+
49+
async def message_router() -> None:
50+
try:
51+
async for session_message in transport_read:
52+
request = session_message.message.root
53+
if not isinstance(request, types.JSONRPCRequest):
54+
# Ignore e.g. types.JSONRPCNotification
55+
continue
56+
result: types.Result | types.ErrorData
57+
try:
58+
match request:
59+
case types.JSONRPCRequest(method="initialize"):
60+
result = self._handle_initialize(
61+
types.InitializeRequestParams.model_validate(request.params)
62+
)
63+
case types.JSONRPCRequest(method="tools/list"):
64+
result = await self._handle_list_tools()
65+
case types.JSONRPCRequest(method="tools/call"):
66+
result = await self._handle_call_tool(
67+
types.CallToolRequestParams.model_validate(request.params)
68+
)
69+
case _:
70+
result = types.ErrorData(
71+
code=types.METHOD_NOT_FOUND, message=f"Unknown method: {request.method}"
72+
)
73+
except pydantic.ValidationError as e:
74+
result = types.ErrorData(code=types.INVALID_PARAMS, message=f"Invalid request: {e}")
75+
76+
match result:
77+
case types.Result():
78+
response = self._json_rpc_result_response(request, result)
79+
case types.ErrorData():
80+
response = self._json_rpc_error_response(request, result)
81+
82+
await transport_write.send(SessionMessage(types.JSONRPCMessage(root=response)))
83+
84+
except anyio.ClosedResourceError:
85+
pass
86+
finally:
87+
await transport_write.aclose()
88+
89+
router_task = asyncio.create_task(message_router())
90+
91+
try:
92+
yield client_read, client_write
93+
finally:
94+
await client_write.aclose()
95+
router_task.cancel()
96+
try:
97+
await router_task
98+
except asyncio.CancelledError:
99+
pass
100+
await transport_read.aclose()
101+
102+
def _handle_initialize(self, params: types.InitializeRequestParams) -> types.InitializeResult:
103+
# TODO: MCPService should implement this
104+
return types.InitializeResult(
105+
protocolVersion="2024-11-05",
106+
capabilities=types.ServerCapabilities(tools=types.ToolsCapability()),
107+
serverInfo=types.Implementation(
108+
name="nexus-mcp-transport",
109+
version="0.1.0",
110+
),
111+
)
112+
113+
async def _handle_list_tools(self) -> types.ListToolsResult:
114+
nexus_client = workflow.create_nexus_client(
115+
endpoint=self.endpoint,
116+
service=MCPService,
117+
)
118+
tools = await nexus_client.execute_operation(MCPService.list_tools, None)
119+
return types.ListToolsResult(tools=tools)
120+
121+
async def _handle_call_tool(self, params: types.CallToolRequestParams) -> types.CallToolResult:
122+
service, _, operation = params.name.partition("/")
123+
nexus_client = workflow.create_nexus_client(
124+
endpoint=self.endpoint,
125+
service=service,
126+
)
127+
result: Any = await nexus_client.execute_operation(
128+
operation,
129+
params.arguments or {},
130+
)
131+
if isinstance(result, dict):
132+
return types.CallToolResult(content=[], structuredContent=result)
133+
else:
134+
return types.CallToolResult(content=[types.TextContent(type="text", text=str(result))])
135+
136+
def _json_rpc_error_response(self, request: types.JSONRPCRequest, error: types.ErrorData) -> types.JSONRPCResponse:
137+
return types.JSONRPCResponse.model_validate(
138+
{
139+
"jsonrpc": "2.0",
140+
"id": request.id,
141+
"error": error.model_dump(),
142+
}
143+
)
144+
145+
def _json_rpc_result_response(self, request: types.JSONRPCRequest, result: types.Result) -> types.JSONRPCResponse:
146+
return types.JSONRPCResponse.model_validate(
147+
{
148+
"jsonrpc": "2.0",
149+
"id": request.id,
150+
"result": result.model_dump(),
151+
}
152+
)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ description = "Add your description here"
55
readme = "README.md"
66
requires-python = ">=3.13"
77
dependencies = [
8+
"anyio>=4.10.0",
89
"mcp>=1.13.0",
910
"nexus-rpc>=1.1.0",
1011
"pydantic>=2.11.7",

tests/test_inbound_gateway.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import asyncio
22

33
import anyio
4-
from mcp.shared.message import SessionMessage
54
import pytest
65
from mcp import ClientSession
76
from mcp.server.lowlevel import NotificationOptions, Server
87
from mcp.server.models import InitializationOptions
8+
from mcp.shared.message import SessionMessage
99
from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget
1010
from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest
11+
from temporalio.contrib.pydantic import pydantic_data_converter
1112
from temporalio.testing import WorkflowEnvironment
1213
from temporalio.worker import Worker
1314

@@ -23,7 +24,7 @@ async def test_inbound_gateway() -> None:
2324
endpoint_name = "endpoint"
2425
task_queue = "handler-queue"
2526

26-
async with await WorkflowEnvironment.start_local() as env:
27+
async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env:
2728
await env.client.operator_service.create_nexus_endpoint(
2829
CreateNexusEndpointRequest(
2930
spec=EndpointSpec(

0 commit comments

Comments
 (0)