Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__pycache__/
58 changes: 56 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ class CalculatorHandler:
return CalculateResponse(result=result)
```

### 4. Create a Nexus endpoint
### 4. Create a Nexus namespace and endpoint

**Local dev or self hosted deployment**

```bash
temporal operator namespace create --namespace my-handler-namespace
temporal operator nexus endpoint create \
--name mcp-gateway \
--target-namespace my-handler-namespace \
Expand All @@ -117,6 +118,8 @@ temporal operator nexus endpoint create \
### 3. Set Up the Temporal Worker with the Nexus handlers at `worker.py`

```python
import asyncio

from temporalio.client import Client
from temporalio.worker import Worker
from .service_handler import mcp_service_handler, CalculatorHandler
Expand Down Expand Up @@ -163,7 +166,7 @@ if __name__ == "__main__":
asyncio.run(main())
```

### 4. Configure Your MCP Client
### 5. Configure Your MCP Client

Add to your MCP client configuration (e.g., Claude Desktop):

Expand All @@ -178,6 +181,57 @@ Add to your MCP client configuration (e.g., Claude Desktop):
}
```

### 6. Make MCP calls from a Temporal workflow

```python
import asyncio
import uuid

from mcp import ClientSession
from nexusmcp import WorkflowTransport
from pydantic import BaseModel
from temporalio import workflow
from temporalio.client import Client
from temporalio.contrib.pydantic import pydantic_data_converter
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I forgot to add that on the handler side, want to do that before merging? Otherwise, I can take that on.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure -- I've added it to all the client connections in the README and tests.

from temporalio.worker import Worker


class AgentWorkflowInput(BaseModel):
endpoint: str


# The workflow must have the sandbox disabled
@workflow.defn(sandboxed=False)
class AgentWorkflow:
@workflow.run
async def run(self, input: AgentWorkflowInput):
transport = WorkflowTransport(input.endpoint)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
list_tools_result = await session.list_tools()
print(f"available tools: {list_tools_result}")


async def main():
client = await Client.connect(
"localhost:7233",
data_converter=pydantic_data_converter,
)

async with Worker(
client,
task_queue="agent-workflow",
workflows=[AgentWorkflow],
) as worker:
await client.execute_workflow(
AgentWorkflow.run,
AgentWorkflowInput(endpoint="mcp-gateway"),
id=str(uuid.uuid4()),
task_queue=worker.task_queue,
)
```

## Usage Examples

### Tool Filtering
Expand Down
3 changes: 2 additions & 1 deletion nexusmcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from .inbound_gateway import InboundGateway
from .service import MCPService
from .service_handler import MCPServiceHandler, exclude
from .workflow_transport import WorkflowTransport

__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude"]
__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude", "WorkflowTransport"]
4 changes: 2 additions & 2 deletions nexusmcp/inbound_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class InboundGateway:

This gateway acts as an adapter between the Model Context Protocol (MCP) server and Temporal Nexus Operations,
enabling tool calls to be executed reliably through Temporal's workflow engine. It handles both tool listing and
tool call requests by delegating them to corresponding a set of Temporal Nexus Services in a given endpoint.
tool call requests by delegating them to a corresponding set of Temporal Nexus Services in a given endpoint.
"""

_client: Client
Expand Down Expand Up @@ -93,7 +93,7 @@ async def _handle_call_tool(self, name: str, arguments: dict[str, Any]) -> Any:
{"a": 5, "b": 3}
)
"""
service, operation = name.split("/", maxsplit=1)
service, _, operation = name.partition("/")
if not service or not operation:
raise ValueError(f"Invalid tool name: {name}, must be in the format 'service/operation'")
return await self._client.execute_workflow(
Expand Down
1 change: 0 additions & 1 deletion nexusmcp/proxy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

class ToolListInput(BaseModel):
endpoint: str
pass


class ToolCallInput(BaseModel):
Expand Down
152 changes: 152 additions & 0 deletions nexusmcp/workflow_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import asyncio
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator

import anyio
import mcp.types as types
import pydantic
from mcp.shared.message import SessionMessage
from temporalio import workflow

from .service import MCPService


class WorkflowTransport:
"""
An MCP Transport for use in Temporal workflows.

This class provides a transport that proxies MCP requests from a Temporal Workflow to a Temporal
Nexus service. It can be used to make MCP calls via `mcp.ClientSession` from Temporal workflow
code.

Example:
```python async with WorkflowNexusTransport("my-endpoint") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize() await session.list_tools() await
session.call_tool("my-service/my-operation", {"arg": "value"})
```
"""

def __init__(
self,
endpoint: str,
):
self.endpoint = endpoint

@asynccontextmanager
async def connect(
self,
) -> AsyncGenerator[
tuple[
anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage],
anyio.streams.memory.MemoryObjectSendStream[SessionMessage],
],
None,
]:
client_write, transport_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated]
transport_write, client_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated]

async def message_router() -> None:
try:
async for session_message in transport_read:
request = session_message.message.root
if not isinstance(request, types.JSONRPCRequest):
# Ignore e.g. types.JSONRPCNotification
continue
result: types.Result | types.ErrorData
try:
match request:
case types.JSONRPCRequest(method="initialize"):
result = self._handle_initialize(
types.InitializeRequestParams.model_validate(request.params)
)
case types.JSONRPCRequest(method="tools/list"):
result = await self._handle_list_tools()
case types.JSONRPCRequest(method="tools/call"):
result = await self._handle_call_tool(
types.CallToolRequestParams.model_validate(request.params)
)
case _:
result = types.ErrorData(
code=types.METHOD_NOT_FOUND, message=f"Unknown method: {request.method}"
)
except pydantic.ValidationError as e:
result = types.ErrorData(code=types.INVALID_PARAMS, message=f"Invalid request: {e}")

match result:
case types.Result():
response = self._json_rpc_result_response(request, result)
case types.ErrorData():
response = self._json_rpc_error_response(request, result)

await transport_write.send(SessionMessage(types.JSONRPCMessage(root=response)))

except anyio.ClosedResourceError:
pass
finally:
await transport_write.aclose()

router_task = asyncio.create_task(message_router())

try:
yield client_read, client_write
finally:
await client_write.aclose()
router_task.cancel()
try:
await router_task
except asyncio.CancelledError:
pass
await transport_read.aclose()

def _handle_initialize(self, params: types.InitializeRequestParams) -> types.InitializeResult:
# TODO: MCPService should implement this
return types.InitializeResult(
protocolVersion="2024-11-05",
capabilities=types.ServerCapabilities(tools=types.ToolsCapability()),
serverInfo=types.Implementation(
name="nexus-mcp-transport",
version="0.1.0",
),
)

async def _handle_list_tools(self) -> types.ListToolsResult:
nexus_client = workflow.create_nexus_client(
endpoint=self.endpoint,
service=MCPService,
)
tools = await nexus_client.execute_operation(MCPService.list_tools, None)
return types.ListToolsResult(tools=tools)

async def _handle_call_tool(self, params: types.CallToolRequestParams) -> types.CallToolResult:
service, _, operation = params.name.partition("/")
nexus_client = workflow.create_nexus_client(
endpoint=self.endpoint,
service=service,
)
result: Any = await nexus_client.execute_operation(
operation,
params.arguments or {},
)
if isinstance(result, dict):
return types.CallToolResult(content=[], structuredContent=result)
else:
return types.CallToolResult(content=[types.TextContent(type="text", text=str(result))])

def _json_rpc_error_response(self, request: types.JSONRPCRequest, error: types.ErrorData) -> types.JSONRPCResponse:
return types.JSONRPCResponse.model_validate(
{
"jsonrpc": "2.0",
"id": request.id,
"error": error.model_dump(),
}
)

def _json_rpc_result_response(self, request: types.JSONRPCRequest, result: types.Result) -> types.JSONRPCResponse:
return types.JSONRPCResponse.model_validate(
{
"jsonrpc": "2.0",
"id": request.id,
"result": result.model_dump(),
}
)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"anyio>=4.10.0",
"mcp>=1.13.0",
"nexus-rpc>=1.1.0",
"pydantic>=2.11.7",
Expand Down Expand Up @@ -41,4 +42,4 @@ python_version = "3.13"
strict = true

[tool.uv.sources]
temporalio = { git = "https://github.com/temporalio/sdk-python" }
temporalio = { git = "https://github.com/temporalio/sdk-python", branch = "dan-9997-python-mcp-nexus" }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Temporal SDK needs a small modification: the event loop needs a stub implementation of get_task_factory

74 changes: 74 additions & 0 deletions tests/test_workflow_caller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import uuid
from dataclasses import dataclass

import pytest
from mcp import ClientSession
from temporalio import workflow
from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget
from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker

from nexusmcp import WorkflowTransport

from .service import TestServiceHandler, mcp_service


@dataclass
class MCPCallerWorkflowInput:
endpoint: str


# sandbox disabled due to use of ThreadLocal by sniffio
# TODO: make this unnecessary
@workflow.defn(sandboxed=False)
class MCPCallerWorkflow:
@workflow.run
async def run(self, input: MCPCallerWorkflowInput) -> None:
transport = WorkflowTransport(input.endpoint)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()

list_tools_result = await session.list_tools()
assert len(list_tools_result.tools) == 2
assert list_tools_result.tools[0].name == "modified-service-name/modified-op-name"
assert list_tools_result.tools[1].name == "modified-service-name/op2"

call_result = await session.call_tool("modified-service-name/modified-op-name", {"name": "World"})
assert call_result.structuredContent == {"message": "Hello, World"}


@pytest.mark.asyncio
async def test_workflow_caller() -> None:
endpoint_name = "endpoint"
task_queue = "handler-queue"

async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env:
await env.client.operator_service.create_nexus_endpoint(
CreateNexusEndpointRequest(
spec=EndpointSpec(
name=endpoint_name,
target=EndpointTarget(
worker=EndpointTarget.Worker(
namespace=env.client.namespace,
task_queue=task_queue,
)
),
)
)
)

async with Worker(
env.client,
task_queue=task_queue,
workflows=[MCPCallerWorkflow],
nexus_service_handlers=[mcp_service, TestServiceHandler()],
):
await env.client.execute_workflow(
MCPCallerWorkflow.run,
arg=MCPCallerWorkflowInput(endpoint=endpoint_name),
id=str(uuid.uuid4()),
task_queue=task_queue,
)
6 changes: 4 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.