diff --git a/src/mcp_agent/executor/temporal/system_activities.py b/src/mcp_agent/executor/temporal/system_activities.py index b215ece79..c065ea9fa 100644 --- a/src/mcp_agent/executor/temporal/system_activities.py +++ b/src/mcp_agent/executor/temporal/system_activities.py @@ -23,11 +23,9 @@ async def forward_log( message: str, data: Dict[str, Any] | None = None, ) -> bool: - registry = self.context.server_registry gateway_url = getattr(self.context, "gateway_url", None) gateway_token = getattr(self.context, "gateway_token", None) return await log_via_proxy( - registry, execution_id=execution_id, level=level, namespace=namespace, @@ -47,11 +45,9 @@ async def request_user_input( signal_name: str = "human_input", ) -> Dict[str, Any]: # Reuse proxy ask API; returns {result} or {error} - registry = self.context.server_registry gateway_url = getattr(self.context, "gateway_url", None) gateway_token = getattr(self.context, "gateway_token", None) return await ask_via_proxy( - registry, execution_id=execution_id, prompt=prompt, metadata={ @@ -67,11 +63,9 @@ async def request_user_input( async def relay_notify( self, execution_id: str, method: str, params: Dict[str, Any] | None = None ) -> bool: - registry = self.context.server_registry gateway_url = getattr(self.context, "gateway_url", None) gateway_token = getattr(self.context, "gateway_token", None) return await notify_via_proxy( - registry, execution_id=execution_id, method=method, params=params or {}, @@ -83,11 +77,9 @@ async def relay_notify( async def relay_request( self, execution_id: str, method: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: - registry = self.context.server_registry gateway_url = getattr(self.context, "gateway_url", None) gateway_token = getattr(self.context, "gateway_token", None) return await request_via_proxy( - registry, execution_id=execution_id, method=method, params=params or {}, diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index af7d8f34b..e289b1059 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -3,52 +3,56 @@ import os import httpx -from mcp_agent.mcp.mcp_server_registry import ServerRegistry from urllib.parse import quote def _resolve_gateway_url( - server_registry: Optional[ServerRegistry] = None, - server_name: Optional[str] = None, + *, gateway_url: Optional[str] = None, + context_gateway_url: Optional[str] = None, ) -> str: + """Resolve the base URL for the MCP gateway. + + Precedence: + 1) Explicit override (gateway_url parameter) + 2) Context-provided URL (context_gateway_url) + 3) Environment variable MCP_GATEWAY_URL + 4) Fallback to http://127.0.0.1:8000 (dev default) + """ # Highest precedence: explicit override if gateway_url: return gateway_url.rstrip("/") + # Next: context-provided URL (e.g., from Temporal workflow memo) + if context_gateway_url: + return context_gateway_url.rstrip("/") + # Next: environment variable env_url = os.environ.get("MCP_GATEWAY_URL") if env_url: return env_url.rstrip("/") - # Next: a registry entry (if provided) - if server_registry and server_name: - cfg = server_registry.get_server_config(server_name) - if cfg and getattr(cfg, "url", None): - return cfg.url.rstrip("/") - # Fallback: default local server return "http://127.0.0.1:8000" async def log_via_proxy( - server_registry: Optional[ServerRegistry], execution_id: str, level: str, namespace: str, message: str, data: Dict[str, Any] | None = None, *, - server_name: Optional[str] = None, gateway_url: Optional[str] = None, gateway_token: Optional[str] = None, ) -> bool: - base = _resolve_gateway_url(server_registry, server_name, gateway_url) + base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) url = f"{base}/internal/workflows/log" headers: Dict[str, str] = {} tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok + headers["Authorization"] = f"Bearer {tok}" timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) try: async with httpx.AsyncClient(timeout=timeout) as client: @@ -75,21 +79,20 @@ async def log_via_proxy( async def ask_via_proxy( - server_registry: Optional[ServerRegistry], execution_id: str, prompt: str, metadata: Dict[str, Any] | None = None, *, - server_name: Optional[str] = None, gateway_url: Optional[str] = None, gateway_token: Optional[str] = None, ) -> Dict[str, Any]: - base = _resolve_gateway_url(server_registry, server_name, gateway_url) + base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) url = f"{base}/internal/human/prompts" headers: Dict[str, str] = {} tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok + headers["Authorization"] = f"Bearer {tok}" timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) try: async with httpx.AsyncClient(timeout=timeout) as client: @@ -113,21 +116,20 @@ async def ask_via_proxy( async def notify_via_proxy( - server_registry: Optional[ServerRegistry], execution_id: str, method: str, params: Dict[str, Any] | None = None, *, - server_name: Optional[str] = None, gateway_url: Optional[str] = None, gateway_token: Optional[str] = None, ) -> bool: - base = _resolve_gateway_url(server_registry, server_name, gateway_url) + base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/notify" headers: Dict[str, str] = {} tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok + headers["Authorization"] = f"Bearer {tok}" timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) try: @@ -147,21 +149,20 @@ async def notify_via_proxy( async def request_via_proxy( - server_registry: Optional[ServerRegistry], execution_id: str, method: str, params: Dict[str, Any] | None = None, *, - server_name: Optional[str] = None, gateway_url: Optional[str] = None, gateway_token: Optional[str] = None, ) -> Dict[str, Any]: - base = _resolve_gateway_url(server_registry, server_name, gateway_url) + base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request" headers: Dict[str, str] = {} tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") if tok: headers["X-MCP-Gateway-Token"] = tok + headers["Authorization"] = f"Bearer {tok}" timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "20")) try: async with httpx.AsyncClient(timeout=timeout) as client: diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 8b46735de..ec40a88a3 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -335,12 +335,19 @@ async def _relay_notify(request: Request): # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token and not secrets.compare_digest( - request.headers.get("X-MCP-Gateway-Token", ""), gw_token - ): - return JSONResponse( - {"ok": False, "error": "unauthorized"}, status_code=401 + if gw_token: + bearer = request.headers.get("Authorization", "") + bearer_token = ( + bearer.split(" ", 1)[1] if bearer.lower().startswith("bearer ") else "" ) + header_tok = request.headers.get("X-MCP-Gateway-Token", "") + if not ( + secrets.compare_digest(header_tok, gw_token) + or secrets.compare_digest(bearer_token, gw_token) + ): + return JSONResponse( + {"ok": False, "error": "unauthorized"}, status_code=401 + ) # Optional idempotency handling idempotency_key = params.get("idempotency_key") @@ -395,6 +402,9 @@ async def _relay_notify(request: Request): return JSONResponse({"ok": True}) except Exception as e: + # After workflow cleanup, upstream sessions may be closed. Treat notify as best-effort. + if isinstance(method, str) and method.startswith("notifications/"): + return JSONResponse({"ok": True, "dropped": True}) return JSONResponse({"ok": False, "error": str(e)}, status_code=500) @mcp_server.custom_route( @@ -499,12 +509,19 @@ async def _internal_workflows_log(request: Request): # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token and not secrets.compare_digest( - request.headers.get("X-MCP-Gateway-Token", ""), gw_token - ): - return JSONResponse( - {"ok": False, "error": "unauthorized"}, status_code=401 + if gw_token: + bearer = request.headers.get("Authorization", "") + bearer_token = ( + bearer.split(" ", 1)[1] if bearer.lower().startswith("bearer ") else "" ) + header_tok = request.headers.get("X-MCP-Gateway-Token", "") + if not ( + secrets.compare_digest(header_tok, gw_token) + or secrets.compare_digest(bearer_token, gw_token) + ): + return JSONResponse( + {"ok": False, "error": "unauthorized"}, status_code=401 + ) session = await _get_session(execution_id) if not session: @@ -538,10 +555,17 @@ async def _internal_human_prompts(request: Request): # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token and not secrets.compare_digest( - request.headers.get("X-MCP-Gateway-Token", ""), gw_token - ): - return JSONResponse({"error": "unauthorized"}, status_code=401) + if gw_token: + bearer = request.headers.get("Authorization", "") + bearer_token = ( + bearer.split(" ", 1)[1] if bearer.lower().startswith("bearer ") else "" + ) + header_tok = request.headers.get("X-MCP-Gateway-Token", "") + if not ( + secrets.compare_digest(header_tok, gw_token) + or secrets.compare_digest(bearer_token, gw_token) + ): + return JSONResponse({"error": "unauthorized"}, status_code=401) session = await _get_session(execution_id) if not session: @@ -1367,38 +1391,61 @@ async def _workflow_run( # Build memo for Temporal runs if gateway info is available workflow_memo = None try: - # Prefer explicit kwargs, else infer from request headers/environment - # FastMCP keeps raw request under ctx.request_context.request if available + # Prefer explicit kwargs, else infer from request context/headers gateway_url = kwargs.get("gateway_url") gateway_token = kwargs.get("gateway_token") - if gateway_url is None: - try: - req = getattr(ctx.request_context, "request", None) - if req is not None: - # Custom header if present - h = req.headers - gateway_url = ( - h.get("X-MCP-Gateway-URL") - or h.get("X-Forwarded-Url") - or h.get("X-Forwarded-Proto") - ) - # Best-effort reconstruction if only proto/host provided - if gateway_url is None: - proto = h.get("X-Forwarded-Proto") or "http" - host = h.get("X-Forwarded-Host") or h.get("Host") - if host: - gateway_url = f"{proto}://{host}" - except Exception: - pass + req = getattr(ctx.request_context, "request", None) + if req is not None: + h = req.headers + # Highest precedence: caller-provided full base URL + header_url = h.get("X-MCP-Gateway-URL") or h.get("X-Forwarded-Url") + if gateway_url is None and header_url: + gateway_url = header_url + + # Token may be provided by the gateway/proxy + if gateway_token is None: + gateway_token = h.get("X-MCP-Gateway-Token") + if gateway_token is None: + # Support Authorization: Bearer + auth = h.get("Authorization") + if auth and auth.lower().startswith("bearer "): + gateway_token = auth.split(" ", 1)[1] + + # Prefer explicit reconstruction from X-Forwarded-* if present + if gateway_url is None and (h.get("X-Forwarded-Host") or h.get("Host")): + proto = h.get("X-Forwarded-Proto") or "http" + host = h.get("X-Forwarded-Host") or h.get("Host") + prefix = h.get("X-Forwarded-Prefix") or "" + if prefix and not prefix.startswith("/"): + prefix = "/" + prefix + if host: + gateway_url = f"{proto}://{host}{prefix}" + + # Fallback to request's base_url which already includes scheme/host and any mount prefix + if gateway_url is None: + try: + if getattr(req, "base_url", None): + base_url = str(req.base_url).rstrip("/") + if base_url and base_url.lower() != "none": + gateway_url = base_url + except Exception: + gateway_url = None - if gateway_token is None: - try: - req = getattr(ctx.request_context, "request", None) - if req is not None: - gateway_token = req.headers.get("X-MCP-Gateway-Token") - except Exception: - pass + # Final fallback: environment variables (useful if proxies don't set headers) + try: + import os as _os + + if gateway_url is None: + env_url = _os.environ.get("MCP_GATEWAY_URL") + if env_url: + gateway_url = env_url + if gateway_token is None: + env_tok = _os.environ.get("MCP_GATEWAY_TOKEN") + if env_tok: + gateway_token = env_tok + except Exception: + pass if gateway_url or gateway_token: workflow_memo = { diff --git a/tests/executor/temporal/test_execution_id_and_interceptor.py b/tests/executor/temporal/test_execution_id_and_interceptor.py index 7aa5f5cb5..20cab3563 100644 --- a/tests/executor/temporal/test_execution_id_and_interceptor.py +++ b/tests/executor/temporal/test_execution_id_and_interceptor.py @@ -94,19 +94,19 @@ async def post(self, url, json=None, headers=None): client_proxy.httpx, "AsyncClient", lambda timeout: Client(rcodes) ) - ok = await client_proxy.log_via_proxy(None, "run", "info", "ns", "msg") + ok = await client_proxy.log_via_proxy("run", "info", "ns", "msg") assert ok is True - ok = await client_proxy.log_via_proxy(None, "run", "info", "ns", "msg") + ok = await client_proxy.log_via_proxy("run", "info", "ns", "msg") assert ok is False # notify ok, then error - ok = await client_proxy.notify_via_proxy(None, "run", "m", {}) + ok = await client_proxy.notify_via_proxy("run", "m", {}) assert ok is True - ok = await client_proxy.notify_via_proxy(None, "run", "m", {}) + ok = await client_proxy.notify_via_proxy("run", "m", {}) assert ok is False # request ok, then error - res = await client_proxy.request_via_proxy(None, "run", "m", {}) + res = await client_proxy.request_via_proxy("run", "m", {}) assert isinstance(res, dict) and res.get("ok", True) in (True,) - res = await client_proxy.request_via_proxy(None, "run", "m", {}) + res = await client_proxy.request_via_proxy("run", "m", {}) assert isinstance(res, dict) and "error" in res diff --git a/tests/server/test_app_server_memo.py b/tests/server/test_app_server_memo.py new file mode 100644 index 000000000..daffcfeed --- /dev/null +++ b/tests/server/test_app_server_memo.py @@ -0,0 +1,114 @@ +import pytest +from types import SimpleNamespace + + +class FakeWorkflow: + def __init__(self): + self.captured_memo = None + + @classmethod + async def create(cls, name: str, context): + return cls() + + async def run_async(self, *args, **kwargs): + # Capture the internal memo passed by the server layer + self.captured_memo = kwargs.get("__mcp_agent_workflow_memo") + # Return a minimal execution-like object + return SimpleNamespace(workflow_id="wf-1", run_id="run-1") + + +@pytest.mark.anyio +async def test_memo_from_forwarded_headers(monkeypatch): + from mcp_agent.server import app_server + + # Patch workflow resolution to return our FakeWorkflow and a dummy context + monkeypatch.setattr( + app_server, + "_resolve_workflows_and_context", + lambda ctx: ({"TestWorkflow": FakeWorkflow}, SimpleNamespace()), + ) + # Avoid registry side effects + monkeypatch.setattr(app_server, "_register_session", lambda *a, **k: None) + + # Construct a request-like object with only X-Forwarded-* headers + headers = { + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "app.mcpac.dev", + "X-Forwarded-Prefix": "/abc123", + } + req = SimpleNamespace(headers=headers, base_url="https://ignored/base/") + ctx = SimpleNamespace( + request_context=SimpleNamespace(request=req), fastmcp=SimpleNamespace() + ) + + # Run the private helper + result = await app_server._workflow_run(ctx, "TestWorkflow") + assert result["workflow_id"] == "wf-1" + assert result["run_id"] == "run-1" + + # Verify FakeWorkflow captured memo with full URL reconstructed from X-Forwarded-* + # Fetch the workflow instance created within _workflow_run by inspecting patched resolution + # Easiest: call again but capture via a local workflow instance + # Alternatively, patch FakeWorkflow to store last_memo globally; simpler approach below: + + # Build a workflow instance and invoke run_async directly to assert memo composition via same code path + # Instead, patch FakeWorkflow.create to stash instance + captured = {} + + async def create_and_stash(name: str, context): + wf = FakeWorkflow() + captured["wf"] = wf + return wf + + monkeypatch.setattr( + FakeWorkflow, + "create", + classmethod(lambda cls, name, context: create_and_stash(name, context)), + ) + + _ = await app_server._workflow_run(ctx, "TestWorkflow") + memo = captured["wf"].captured_memo + assert memo is not None + assert memo.get("gateway_url") == "https://app.mcpac.dev/abc123" + # No token provided in headers + assert memo.get("gateway_token") in (None, "") + + +@pytest.mark.anyio +async def test_memo_falls_back_to_env(monkeypatch): + from mcp_agent.server import app_server + + monkeypatch.setattr( + app_server, + "_resolve_workflows_and_context", + lambda ctx: ({"TestWorkflow": FakeWorkflow}, SimpleNamespace()), + ) + monkeypatch.setattr(app_server, "_register_session", lambda *a, **k: None) + + # No headers at all; env should be used + req = SimpleNamespace(headers={}, base_url=None) + ctx = SimpleNamespace( + request_context=SimpleNamespace(request=req), fastmcp=SimpleNamespace() + ) + + monkeypatch.setenv("MCP_GATEWAY_URL", "http://example:9000/base") + monkeypatch.setenv("MCP_GATEWAY_TOKEN", "secret-token") + + captured = {} + + async def create_and_stash(name: str, context): + wf = FakeWorkflow() + captured["wf"] = wf + return wf + + monkeypatch.setattr( + FakeWorkflow, + "create", + classmethod(lambda cls, name, context: create_and_stash(name, context)), + ) + + _ = await app_server._workflow_run(ctx, "TestWorkflow") + memo = captured["wf"].captured_memo + assert memo is not None + assert memo.get("gateway_url") == "http://example:9000/base" + assert memo.get("gateway_token") == "secret-token"