Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/mcp_agent/executor/temporal/session_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,35 @@ async def request(
return {"error": "missing_execution_id"}

if _in_workflow_runtime():
# In workflow context, dispatch an async request and wait for a signal
# indicating the response payload. The activity returns a unique signal
# name that the server will use to signal the workflow with the result.
from temporalio import workflow as _twf # type: ignore

act = self._context.task_registry.get_activity("mcp_relay_request")
return await self._executor.execute(
signal_name = await self._executor.execute(
act,
True, # make_async_call
exec_id,
method,
params or {},
)
Comment on lines +125 to 131
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Propagate activity errors returned by executor.execute.

TemporalExecutor.execute may return a BaseException; handle it to avoid waiting on a bogus signal_name.

-            signal_name = await self._executor.execute(
+            result = await self._executor.execute(
                 act,
                 True,  # make_async_call
                 exec_id,
                 method,
                 params or {},
             )
+            if isinstance(result, BaseException):
+                raise result
+            signal_name = result  # expected to be str
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
signal_name = await self._executor.execute(
act,
True, # make_async_call
exec_id,
method,
params or {},
)
result = await self._executor.execute(
act,
True, # make_async_call
exec_id,
method,
params or {},
)
if isinstance(result, BaseException):
raise result
signal_name = result # expected to be str
🤖 Prompt for AI Agents
In src/mcp_agent/executor/temporal/session_proxy.py around lines 125 to 131, the
value returned by self._executor.execute may be a BaseException and the code
currently treats it as a valid signal_name; detect when the result is an
exception and re-raise it (or otherwise propagate it) instead of proceeding to
await/use it as a signal name. Update the code to check isinstance(result,
BaseException) right after the await and raise the exception if so, ensuring
calling context observes the error rather than waiting on a bogus signal_name.


# Wait for the response via workflow signal
info = _twf.info()
payload = await self._context.executor.wait_for_signal( # type: ignore[attr-defined]
signal_name,
workflow_id=info.workflow_id,
run_id=info.run_id,
signal_description=f"Waiting for async response to {method}",
# Timeout can be controlled by Temporal workflow/activity timeouts
)
return payload
return await self._system_activities.relay_request(
exec_id, method, params or {}
False, # synchronous call path
exec_id,
method,
params or {},
)

async def send_notification(
Expand Down
33 changes: 31 additions & 2 deletions src/mcp_agent/executor/temporal/system_activities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict
import anyio
import os
import uuid

from temporalio import activity

Expand Down Expand Up @@ -90,11 +91,39 @@ async def relay_notify(

@activity.defn(name="mcp_relay_request")
async def relay_request(
self, execution_id: str, method: str, params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
self,
make_async_call: bool,
execution_id: str,
method: str,
params: Dict[str, Any] | None = None,
) -> Any:
"""
Relay a server->client request via the gateway.

- If make_async_call is False: performs a synchronous RPC and returns the JSON result.
- If make_async_call is True: kicks off an async request on the server that will signal
the workflow with the result; returns a unique signal_name that the workflow should wait on.
"""
gateway_url = getattr(self.context, "gateway_url", None)
gateway_token = getattr(self.context, "gateway_token", None)

if make_async_call:
# Create a unique signal name for this request
signal_name = f"mcp_rpc_{method}_{uuid.uuid4().hex}"
await request_via_proxy(
make_async_call=True,
execution_id=execution_id,
method=method,
params=params or {},
signal_name=signal_name,
gateway_url=gateway_url,
gateway_token=gateway_token,
)
return signal_name

# Synchronous path
return await request_via_proxy(
make_async_call=False,
execution_id=execution_id,
method=method,
params=params or {},
Expand Down
89 changes: 86 additions & 3 deletions src/mcp_agent/mcp/client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,78 @@ async def notify_via_proxy(
return bool(resp.get("ok", True))


async def request_via_proxy(
async def _request_via_proxy_impl(
execution_id: str,
method: str,
params: Dict[str, Any] | None = None,
*,
gateway_url: Optional[str] = None,
gateway_token: Optional[str] = None,
) -> Dict[str, Any]:
make_async_call: Optional[bool] = None,
signal_name: Optional[str] = None,
) -> Dict[str, Any] | None:
"""
Relay a server->client request via the gateway.

- If make_async_call is falsy/None: perform synchronous HTTP RPC and return the JSON result.
- If make_async_call is True: trigger an async request on the server that will signal the
workflow with the result, then return None (the workflow should wait on signal_name).
"""
base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None)
url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request"
headers: Dict[str, str] = {}
tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN")
if tok:
headers["X-MCP-Gateway-Token"] = tok
headers["Authorization"] = f"Bearer {tok}"

if bool(make_async_call):
# Determine workflow_id from Temporal activity context if not provided
try:
from temporalio import activity as _ta # type: ignore
if _ta.in_activity():
wf_id = _ta.info().workflow_id
else:
wf_id = None
except Exception:
wf_id = None

if not wf_id:
# Without workflow_id, we cannot route the signal back to the workflow
return {"error": "not_in_workflow_or_activity"}

if not signal_name:
return {"error": "missing_signal_name"}

url = f"{base}/internal/session/by-run/{quote(wf_id, safe='')}/{quote(execution_id, safe='')}/async-request"
# Fire-and-forget style: return immediately after enqueuing on server
timeout_str = os.environ.get("MCP_GATEWAY_REQUEST_TIMEOUT")
if timeout_str is None:
timeout = httpx.Timeout(None)
else:
try:
timeout = float(str(timeout_str).strip())
except Exception:
timeout = httpx.Timeout(None)
try:
async with httpx.AsyncClient(timeout=timeout) as client:
r = await client.post(
url,
json={
"method": method,
"params": params or {},
"signal_name": signal_name,
},
headers=headers,
)
except httpx.RequestError:
return {"error": "request_failed"}
if r.status_code >= 400:
return {"error": r.text}
# No payload is expected for async path beyond ack
return None

# Synchronous request path
url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request"
# Requests require a response; default to no HTTP timeout.
# Configure with MCP_GATEWAY_REQUEST_TIMEOUT (seconds). If unset or <= 0, no timeout is applied.
timeout_str = os.environ.get("MCP_GATEWAY_REQUEST_TIMEOUT")
Expand Down Expand Up @@ -192,3 +249,29 @@ async def request_via_proxy(
return r.json() if r.content else {"error": "invalid_response"}
except ValueError:
return {"error": "invalid_response"}


# Backward-compatible wrapper accepting positional or keyword args
async def request_via_proxy(*args, **kwargs) -> Dict[str, Any] | None:
"""Backward-compatible wrapper for request_via_proxy.

Supports both positional (execution_id, method, params) and keyword-only usage,
and forwards optional async parameters when provided as keywords.
"""
if args:
# Extract legacy positional args
execution_id = args[0] if len(args) > 0 else kwargs.get("execution_id")
method = args[1] if len(args) > 1 else kwargs.get("method")
params = args[2] if len(args) > 2 else kwargs.get("params")
# Remaining arguments must be passed as keywords (gateway_url, gateway_token, make_async_call, signal_name)
return await _request_via_proxy_impl(
execution_id=execution_id,
method=method,
params=params,
gateway_url=kwargs.get("gateway_url"),
gateway_token=kwargs.get("gateway_token"),
make_async_call=kwargs.get("make_async_call"),
signal_name=kwargs.get("signal_name"),
)
# Pure keyword usage
return await _request_via_proxy_impl(**kwargs)
Loading
Loading