Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__pycache__/
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ class CalculatorHandler:
return CalculateResponse(result=result)
```

### 4. Create a Nexus endpoint
### 4. Create a Nexus namespace and endpoint

**Local dev or self hosted deployment**

```bash
temporal operator namespace create --namespace my-handler-namespace
temporal operator nexus endpoint create \
--name mcp-gateway \
--target-namespace my-handler-namespace \
Expand All @@ -117,6 +118,8 @@ temporal operator nexus endpoint create \
### 3. Set Up the Temporal Worker with the Nexus handlers at `worker.py`

```python
import asyncio

from temporalio.client import Client
from temporalio.worker import Worker
from .service_handler import mcp_service_handler, CalculatorHandler
Expand Down
3 changes: 2 additions & 1 deletion nexusmcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

with workflow.unsafe.imports_passed_through():
from .inbound_gateway import InboundGateway
from .nexus_transport import WorkflowNexusTransport
from .service import MCPService
from .service_handler import MCPServiceHandler, exclude

__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude"]
__all__ = ["MCPService", "MCPServiceHandler", "InboundGateway", "exclude", "WorkflowNexusTransport"]
4 changes: 2 additions & 2 deletions nexusmcp/inbound_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class InboundGateway:

This gateway acts as an adapter between the Model Context Protocol (MCP) server and Temporal Nexus Operations,
enabling tool calls to be executed reliably through Temporal's workflow engine. It handles both tool listing and
tool call requests by delegating them to corresponding a set of Temporal Nexus Services in a given endpoint.
tool call requests by delegating them to a corresponding set of Temporal Nexus Services in a given endpoint.
"""

_client: Client
Expand Down Expand Up @@ -93,7 +93,7 @@ async def _handle_call_tool(self, name: str, arguments: dict[str, Any]) -> Any:
{"a": 5, "b": 3}
)
"""
service, operation = name.split("/", maxsplit=1)
service, _, operation = name.partition("/")
if not service or not operation:
raise ValueError(f"Invalid tool name: {name}, must be in the format 'service/operation'")
return await self._client.execute_workflow(
Expand Down
152 changes: 152 additions & 0 deletions nexusmcp/nexus_transport.py
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename to transport.py since this is already in the nexusmcp package.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline. Renamed to workflow_transport.WorkflowTransport, the idea being that this is transport over "workflow protocol" (or "workflow task protocol"), although those terms are not currently standard.

Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import asyncio
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator

import anyio
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized anyio wasn't in the project dependencies but it was already in use in tests before your PR. I wonder if there's a linter that can catch that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added it to dependencies

import mcp.types as types
import pydantic
from mcp.shared.message import SessionMessage
from temporalio import workflow

from .service import MCPService


class WorkflowNexusTransport:
"""
Nexus MCP Transport for use in Temporal workflows.
This class provides a transport that proxies MCP requests to a Temporal Nexus service. It can be
used to make MCP calls via `mcp.ClientSession` from Temporal workflow code.
Example:
```python
async with WorkflowNexusTransport("my-endpoint") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
await session.list_tools()
await session.call_tool("my-service/my-operation", {"arg": "value"})
```
"""

def __init__(
self,
endpoint: str,
):
self.endpoint = endpoint

@asynccontextmanager
async def connect(
self,
) -> AsyncGenerator[
tuple[
anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage],
anyio.streams.memory.MemoryObjectSendStream[SessionMessage],
],
None,
]:
client_write, transport_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated]
transport_write, client_read = anyio.create_memory_object_stream(0) # type: ignore[var-annotated]

async def message_router() -> None:
try:
async for session_message in transport_read:
request = session_message.message.root
if not isinstance(request, types.JSONRPCRequest):
continue
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intentionally ignored? Can you add a comment saying why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment saying that we ignore e.g. types.JSONRPCNotification

In practice, the test currently gets JSONRPCNotification method='notifications/initialized'

result: types.Result | types.ErrorData
try:
match request:
case types.JSONRPCRequest(method="initialize"):
result = self._handle_initialize(
types.InitializeRequestParams.model_validate(request.params)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, I would have expected this to already have come validated from the client. I wonder if there's something to learn here for the Nexus client / transport separation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, the client does construct those types explicitly, e.g. https://github.com/modelcontextprotocol/python-sdk/blob/d1ac8d68eb2d7ed139bdc2608b8b4e2ec4265be5/src/mcp/client/session.py#L293-L302

We could cast here instead of validating again, but I think I prefer validating again, e.g. so that our transport can be used safely in other contexts.

)
case types.JSONRPCRequest(method="tools/list"):
result = await self._handle_list_tools()
case types.JSONRPCRequest(method="tools/call"):
result = await self._handle_call_tool(
types.CallToolRequestParams.model_validate(request.params)
)
case _:
result = types.ErrorData(
code=types.METHOD_NOT_FOUND, message=f"Unknown method: {request.method}"
)
except pydantic.ValidationError as e:
result = types.ErrorData(code=types.INVALID_PARAMS, message=f"Invalid request: {e}")

match result:
case types.Result():
response = self._json_rpc_result_response(request, result)
case types.ErrorData():
response = self._json_rpc_error_response(request, result)

await transport_write.send(SessionMessage(types.JSONRPCMessage(root=response)))

except anyio.ClosedResourceError:
pass
finally:
await transport_write.aclose()

router_task = asyncio.create_task(message_router())

try:
yield client_read, client_write
finally:
await client_write.aclose()
router_task.cancel()
try:
await router_task
except asyncio.CancelledError:
pass
await transport_read.aclose()

def _handle_initialize(self, params: types.InitializeRequestParams) -> types.InitializeResult:
# TODO: MCPService should implement this
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want to open an issue? Or just add it, it should be easy..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #2

return types.InitializeResult(
protocolVersion="2024-11-05",
capabilities=types.ServerCapabilities(tools=types.ToolsCapability()),
serverInfo=types.Implementation(
name="nexus-mcp-transport",
version="0.1.0",
),
)

async def _handle_list_tools(self) -> types.ListToolsResult:
nexus_client = workflow.create_nexus_client(
endpoint=self.endpoint,
service=MCPService,
)
tools = await nexus_client.execute_operation(MCPService.list_tools, None)
return types.ListToolsResult(tools=tools)

async def _handle_call_tool(self, params: types.CallToolRequestParams) -> types.CallToolResult:
service, _, operation = params.name.partition("/")
nexus_client = workflow.create_nexus_client(
endpoint=self.endpoint,
service=service,
)
result: Any = await nexus_client.execute_operation(
operation,
params.arguments or {},
)
if isinstance(result, dict):
return types.CallToolResult(content=[], structuredContent=result)
else:
return types.CallToolResult(content=[types.TextContent(type="text", text=str(result))])

def _json_rpc_error_response(self, request: types.JSONRPCRequest, error: types.ErrorData) -> types.JSONRPCResponse:
return types.JSONRPCResponse.model_validate(
{
"jsonrpc": "2.0",
"id": request.id,
"error": error.model_dump(),
}
)

def _json_rpc_result_response(self, request: types.JSONRPCRequest, result: types.Result) -> types.JSONRPCResponse:
return types.JSONRPCResponse.model_validate(
{
"jsonrpc": "2.0",
"id": request.id,
"result": result.model_dump(),
}
)
1 change: 0 additions & 1 deletion nexusmcp/proxy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

class ToolListInput(BaseModel):
endpoint: str
pass


class ToolCallInput(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ python_version = "3.13"
strict = true

[tool.uv.sources]
temporalio = { git = "https://github.com/temporalio/sdk-python" }
temporalio = { git = "https://github.com/temporalio/sdk-python", branch = "dan-9997-python-mcp-nexus" }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Temporal SDK needs a small modification: the event loop needs a stub implementation of get_task_factory

72 changes: 72 additions & 0 deletions tests/test_workflow_caller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import uuid
from dataclasses import dataclass

import pytest
from mcp import ClientSession
from temporalio import workflow
from temporalio.api.nexus.v1 import EndpointSpec, EndpointTarget
from temporalio.api.operatorservice.v1 import CreateNexusEndpointRequest
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker

from nexusmcp.nexus_transport import WorkflowNexusTransport

from .service import TestServiceHandler, mcp_service


@dataclass
class MCPCallerWorkflowInput:
endpoint: str


@workflow.defn(sandboxed=False)
class MCPCallerWorkflow:
@workflow.run
async def run(self, input: MCPCallerWorkflowInput) -> None:
transport = WorkflowNexusTransport(input.endpoint)
async with transport.connect() as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()

list_tools_result = await session.list_tools()
assert len(list_tools_result.tools) == 2
assert list_tools_result.tools[0].name == "modified-service-name/modified-op-name"
assert list_tools_result.tools[1].name == "modified-service-name/op2"

call_result = await session.call_tool("modified-service-name/modified-op-name", {"name": "World"})
assert call_result.structuredContent == {"message": "Hello, World"}


@pytest.mark.asyncio
async def test_workflow_caller() -> None:
endpoint_name = "endpoint"
task_queue = "handler-queue"

async with await WorkflowEnvironment.start_local(data_converter=pydantic_data_converter) as env:
await env.client.operator_service.create_nexus_endpoint(
CreateNexusEndpointRequest(
spec=EndpointSpec(
name=endpoint_name,
target=EndpointTarget(
worker=EndpointTarget.Worker(
namespace=env.client.namespace,
task_queue=task_queue,
)
),
)
)
)

async with Worker(
env.client,
task_queue=task_queue,
workflows=[MCPCallerWorkflow],
nexus_service_handlers=[mcp_service, TestServiceHandler()],
):
await env.client.execute_workflow(
MCPCallerWorkflow.run,
arg=MCPCallerWorkflowInput(endpoint=endpoint_name),
id=str(uuid.uuid4()),
task_queue=task_queue,
)
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.