From 6cfc614fac705b73f06227076887bbf58f2b9325 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 19 Sep 2025 02:05:04 -0400 Subject: [PATCH] Get elicitation and sampling to work with signal mailbox --- .../executor/temporal/session_proxy.py | 24 +- .../executor/temporal/system_activities.py | 33 ++- src/mcp_agent/mcp/client_proxy.py | 89 ++++++- src/mcp_agent/server/app_server.py | 244 ++++++++++++++++-- 4 files changed, 361 insertions(+), 29 deletions(-) diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py index ea4a6e809..1e46f2afe 100644 --- a/src/mcp_agent/executor/temporal/session_proxy.py +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -116,15 +116,35 @@ async def request( return {"error": "missing_execution_id"} if _in_workflow_runtime(): + # In workflow context, dispatch an async request and wait for a signal + # indicating the response payload. The activity returns a unique signal + # name that the server will use to signal the workflow with the result. + from temporalio import workflow as _twf # type: ignore + act = self._context.task_registry.get_activity("mcp_relay_request") - return await self._executor.execute( + signal_name = await self._executor.execute( act, + True, # make_async_call exec_id, method, params or {}, ) + + # Wait for the response via workflow signal + info = _twf.info() + payload = await self._context.executor.wait_for_signal( # type: ignore[attr-defined] + signal_name, + workflow_id=info.workflow_id, + run_id=info.run_id, + signal_description=f"Waiting for async response to {method}", + # Timeout can be controlled by Temporal workflow/activity timeouts + ) + return payload return await self._system_activities.relay_request( - exec_id, method, params or {} + False, # synchronous call path + exec_id, + method, + params or {}, ) async def send_notification( diff --git a/src/mcp_agent/executor/temporal/system_activities.py b/src/mcp_agent/executor/temporal/system_activities.py index aff8c7f12..82b34683b 100644 --- a/src/mcp_agent/executor/temporal/system_activities.py +++ b/src/mcp_agent/executor/temporal/system_activities.py @@ -1,6 +1,7 @@ from typing import Any, Dict import anyio import os +import uuid from temporalio import activity @@ -90,11 +91,39 @@ async def relay_notify( @activity.defn(name="mcp_relay_request") async def relay_request( - self, execution_id: str, method: str, params: Dict[str, Any] | None = None - ) -> Dict[str, Any]: + self, + make_async_call: bool, + execution_id: str, + method: str, + params: Dict[str, Any] | None = None, + ) -> Any: + """ + Relay a server->client request via the gateway. + + - If make_async_call is False: performs a synchronous RPC and returns the JSON result. + - If make_async_call is True: kicks off an async request on the server that will signal + the workflow with the result; returns a unique signal_name that the workflow should wait on. + """ gateway_url = getattr(self.context, "gateway_url", None) gateway_token = getattr(self.context, "gateway_token", None) + + if make_async_call: + # Create a unique signal name for this request + signal_name = f"mcp_rpc_{method}_{uuid.uuid4().hex}" + await request_via_proxy( + make_async_call=True, + execution_id=execution_id, + method=method, + params=params or {}, + signal_name=signal_name, + gateway_url=gateway_url, + gateway_token=gateway_token, + ) + return signal_name + + # Synchronous path return await request_via_proxy( + make_async_call=False, execution_id=execution_id, method=method, params=params or {}, diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index 5f4394e93..b8ab39a46 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -148,21 +148,78 @@ async def notify_via_proxy( return bool(resp.get("ok", True)) -async def request_via_proxy( +async def _request_via_proxy_impl( execution_id: str, method: str, params: Dict[str, Any] | None = None, *, gateway_url: Optional[str] = None, gateway_token: Optional[str] = None, -) -> Dict[str, Any]: + make_async_call: Optional[bool] = None, + signal_name: Optional[str] = None, +) -> Dict[str, Any] | None: + """ + Relay a server->client request via the gateway. + + - If make_async_call is falsy/None: perform synchronous HTTP RPC and return the JSON result. + - If make_async_call is True: trigger an async request on the server that will signal the + workflow with the result, then return None (the workflow should wait on signal_name). + """ base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) - url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request" headers: Dict[str, str] = {} tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok headers["Authorization"] = f"Bearer {tok}" + + if bool(make_async_call): + # Determine workflow_id from Temporal activity context if not provided + try: + from temporalio import activity as _ta # type: ignore + if _ta.in_activity(): + wf_id = _ta.info().workflow_id + else: + wf_id = None + except Exception: + wf_id = None + + if not wf_id: + # Without workflow_id, we cannot route the signal back to the workflow + return {"error": "not_in_workflow_or_activity"} + + if not signal_name: + return {"error": "missing_signal_name"} + + url = f"{base}/internal/session/by-run/{quote(wf_id, safe='')}/{quote(execution_id, safe='')}/async-request" + # Fire-and-forget style: return immediately after enqueuing on server + timeout_str = os.environ.get("MCP_GATEWAY_REQUEST_TIMEOUT") + if timeout_str is None: + timeout = httpx.Timeout(None) + else: + try: + timeout = float(str(timeout_str).strip()) + except Exception: + timeout = httpx.Timeout(None) + try: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, + json={ + "method": method, + "params": params or {}, + "signal_name": signal_name, + }, + headers=headers, + ) + except httpx.RequestError: + return {"error": "request_failed"} + if r.status_code >= 400: + return {"error": r.text} + # No payload is expected for async path beyond ack + return None + + # Synchronous request path + url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request" # Requests require a response; default to no HTTP timeout. # Configure with MCP_GATEWAY_REQUEST_TIMEOUT (seconds). If unset or <= 0, no timeout is applied. timeout_str = os.environ.get("MCP_GATEWAY_REQUEST_TIMEOUT") @@ -192,3 +249,29 @@ async def request_via_proxy( return r.json() if r.content else {"error": "invalid_response"} except ValueError: return {"error": "invalid_response"} + + +# Backward-compatible wrapper accepting positional or keyword args +async def request_via_proxy(*args, **kwargs) -> Dict[str, Any] | None: + """Backward-compatible wrapper for request_via_proxy. + + Supports both positional (execution_id, method, params) and keyword-only usage, + and forwards optional async parameters when provided as keywords. + """ + if args: + # Extract legacy positional args + execution_id = args[0] if len(args) > 0 else kwargs.get("execution_id") + method = args[1] if len(args) > 1 else kwargs.get("method") + params = args[2] if len(args) > 2 else kwargs.get("params") + # Remaining arguments must be passed as keywords (gateway_url, gateway_token, make_async_call, signal_name) + return await _request_via_proxy_impl( + execution_id=execution_id, + method=method, + params=params, + gateway_url=kwargs.get("gateway_url"), + gateway_token=kwargs.get("gateway_token"), + make_async_call=kwargs.get("make_async_call"), + signal_name=kwargs.get("signal_name"), + ) + # Pure keyword usage + return await _request_via_proxy_impl(**kwargs) diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 01bea8326..3d3df1538 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -358,6 +358,25 @@ def _get_fallback_upstream_session() -> Any | None: return None return None + # Helper function for shared authentication across internal endpoints + def _check_gateway_auth(request: Request) -> JSONResponse | None: + gw_token = os.environ.get("MCP_GATEWAY_TOKEN") + if not gw_token: + return None + bearer = request.headers.get("Authorization", "") + bearer_token = ( + bearer.split(" ", 1)[1] + if bearer.lower().startswith("bearer ") + else "" + ) + header_tok = request.headers.get("X-MCP-Gateway-Token", "") + if not ( + secrets.compare_digest(header_tok, gw_token) + or secrets.compare_digest(bearer_token, gw_token) + ): + return JSONResponse({"error": "unauthorized"}, status_code=401) + return None + @mcp_server.custom_route( "/internal/session/by-run/{execution_id}/notify", methods=["POST"], @@ -370,22 +389,9 @@ async def _relay_notify(request: Request): params = body.get("params") or {} # Optional shared-secret auth - gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token: - bearer = request.headers.get("Authorization", "") - bearer_token = ( - bearer.split(" ", 1)[1] - if bearer.lower().startswith("bearer ") - else "" - ) - header_tok = request.headers.get("X-MCP-Gateway-Token", "") - if not ( - secrets.compare_digest(header_tok, gw_token) - or secrets.compare_digest(bearer_token, gw_token) - ): - return JSONResponse( - {"ok": False, "error": "unauthorized"}, status_code=401 - ) + auth_err = _check_gateway_auth(request) + if auth_err: + return auth_err # Optional idempotency handling idempotency_key = params.get("idempotency_key") @@ -526,6 +532,108 @@ async def _relay_notify(request: Request): {"ok": False, "error": str(e_mapped)}, status_code=500 ) + # Helper functions for handling requests through a session; shared by sync and async endpoints + async def _handle_request_via_rpc( + session, method: str, params: Dict[str, Any], execution_id: str, log_prefix: str = "request" + ) -> Any | None: + rpc = getattr(session, "rpc", None) + if rpc and hasattr(rpc, "request"): + result = await rpc.request(method, params) + try: + logger.debug( + f"[{log_prefix}] delivered via session_id={id(session)} (generic '{method}')" + ) + except Exception: + pass + return result + return None + + async def _handle_specific_request( + session, method: str, params: Dict[str, Any], log_prefix: str = "request" + ) -> Any: + from mcp.types import ( + CreateMessageRequest, + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequest, + ElicitRequestParams, + ElicitResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + EmptyResult, + ServerRequest, + ) + if method == "sampling/createMessage": + req = ServerRequest( + CreateMessageRequest( + method="sampling/createMessage", + params=CreateMessageRequestParams(**params), + ) + ) + result = await session.send_request( # type: ignore[attr-defined] + request=req, result_type=CreateMessageResult + ) + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + elif method == "elicitation/create": + req = ServerRequest( + ElicitRequest( + method="elicitation/create", + params=ElicitRequestParams(**params), + ) + ) + result = await session.send_request( # type: ignore[attr-defined] + request=req, result_type=ElicitResult + ) + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + elif method == "roots/list": + req = ServerRequest(ListRootsRequest(method="roots/list")) + result = await session.send_request( # type: ignore[attr-defined] + request=req, result_type=ListRootsResult + ) + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + elif method == "ping": + req = ServerRequest(PingRequest(method="ping")) + result = await session.send_request( # type: ignore[attr-defined] + request=req, result_type=EmptyResult + ) + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + else: + raise ValueError(f"unsupported method: {method}") + + async def _try_session_request( + session, + method: str, + params: Dict[str, Any], + execution_id: str, + *, + log_prefix: str = "request", + register_session: bool = False, + ) -> Any: + # Try RPC first + result = await _handle_request_via_rpc( + session, method, params, execution_id, log_prefix + ) + if result is not None: + if register_session: + try: + await _register_session( + run_id=execution_id, execution_id=execution_id, session=session + ) + except Exception: + pass + return result + # Fallback to typed mapping + result = await _handle_specific_request(session, method, params, log_prefix) + if register_session: + try: + await _register_session( + run_id=execution_id, execution_id=execution_id, session=session + ) + except Exception: + pass + return result + @mcp_server.custom_route( "/internal/session/by-run/{execution_id}/request", methods=["POST"], @@ -551,6 +659,11 @@ async def _relay_request(request: Request): method = body.get("method") params = body.get("params") or {} + # Optional shared-secret auth + auth_err = _check_gateway_auth(request) + if auth_err: + return auth_err + # Prefer latest upstream session first latest_session = _get_fallback_upstream_session() if latest_session is not None: @@ -558,18 +671,12 @@ async def _relay_request(request: Request): rpc = getattr(latest_session, "rpc", None) if rpc and hasattr(rpc, "request"): result = await rpc.request(method, params) - # logger.debug( - # f"[request] delivered via latest session_id={id(latest_session)} (generic '{method}')" - # ) try: await _register_session( run_id=execution_id, execution_id=execution_id, session=latest_session, ) - logger.debug( - f"[request] rebound mapping to latest session_id={id(latest_session)} for execution_id={execution_id}" - ) except Exception: pass return JSONResponse(result) @@ -745,6 +852,99 @@ async def _relay_request(request: Request): pass return JSONResponse({"error": str(e)}, status_code=500) + @mcp_server.custom_route( + "/internal/session/by-run/{workflow_id}/{execution_id}/async-request", + methods=["POST"], + include_in_schema=False, + ) + async def _relay_async_request(request: Request): + """Start an async RPC to the connected client and signal the workflow with the result. + + Body: { method: str, params: dict, signal_name: str } + Path: workflow_id, execution_id (run_id) + """ + body = await request.json() + execution_id = request.path_params.get("execution_id") + workflow_id = request.path_params.get("workflow_id") + method = body.get("method") + params = body.get("params") or {} + signal_name = body.get("signal_name") + + # Auth + auth_err = _check_gateway_auth(request) + if auth_err: + return auth_err + + if not signal_name: + return JSONResponse({"error": "missing_signal_name"}, status_code=400) + + async def _do_async(): + result: Dict[str, Any] | None = None + error: str | None = None + try: + # Try latest session first + latest_session = _get_fallback_upstream_session() + if latest_session is not None: + try: + result = await _try_session_request( + latest_session, + method, + params, + execution_id, + log_prefix="async-request", + register_session=True, + ) + except Exception as e_latest: + try: + logger.warning( + f"[async-request] latest session failed for execution_id={execution_id} method={method}: {e_latest}" + ) + except Exception: + pass + + # Fallback to mapped session + if result is None: + session = await _get_session(execution_id) + if not session: + error = "session_not_available" + else: + try: + result = await _try_session_request( + session, + method, + params, + execution_id, + log_prefix="async-request", + register_session=False, + ) + except Exception as e_sess: + error = str(e_sess) + except Exception as e: + error = str(e) + + # Signal the workflow with the result or error + try: + app = _get_attached_app(mcp_server) + if app and app.context and getattr(app.context, "executor", None): + executor = app.context.executor + client = getattr(executor, "client", None) + if client and workflow_id and execution_id: + handle = client.get_workflow_handle( + workflow_id=workflow_id, run_id=execution_id + ) + payload = result if error is None else {"error": error} + await handle.signal(signal_name, payload) + except Exception as se: + try: + logger.error(f"[async-request] failed to signal workflow: {se}") + except Exception: + pass + + asyncio.create_task(_do_async()) + return JSONResponse( + {"status": "received", "execution_id": execution_id, "signal_name": signal_name} + ) + @mcp_server.custom_route( "/internal/workflows/log", methods=["POST"], include_in_schema=False )