Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions src/mcp_agent/executor/temporal/system_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ async def forward_log(
message: str,
data: Dict[str, Any] | None = None,
) -> bool:
registry = self.context.server_registry
gateway_url = getattr(self.context, "gateway_url", None)
gateway_token = getattr(self.context, "gateway_token", None)
return await log_via_proxy(
registry,
execution_id=execution_id,
level=level,
namespace=namespace,
Expand All @@ -47,11 +45,9 @@ async def request_user_input(
signal_name: str = "human_input",
) -> Dict[str, Any]:
# Reuse proxy ask API; returns {result} or {error}
registry = self.context.server_registry
gateway_url = getattr(self.context, "gateway_url", None)
gateway_token = getattr(self.context, "gateway_token", None)
return await ask_via_proxy(
registry,
execution_id=execution_id,
prompt=prompt,
metadata={
Expand All @@ -67,11 +63,9 @@ async def request_user_input(
async def relay_notify(
self, execution_id: str, method: str, params: Dict[str, Any] | None = None
) -> bool:
registry = self.context.server_registry
gateway_url = getattr(self.context, "gateway_url", None)
gateway_token = getattr(self.context, "gateway_token", None)
return await notify_via_proxy(
registry,
execution_id=execution_id,
method=method,
params=params or {},
Expand All @@ -83,11 +77,9 @@ async def relay_notify(
async def relay_request(
self, execution_id: str, method: str, params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
registry = self.context.server_registry
gateway_url = getattr(self.context, "gateway_url", None)
gateway_token = getattr(self.context, "gateway_token", None)
return await request_via_proxy(
registry,
execution_id=execution_id,
method=method,
params=params or {},
Expand Down
43 changes: 22 additions & 21 deletions src/mcp_agent/mcp/client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,56 @@
import os
import httpx

from mcp_agent.mcp.mcp_server_registry import ServerRegistry
from urllib.parse import quote


def _resolve_gateway_url(
server_registry: Optional[ServerRegistry] = None,
server_name: Optional[str] = None,
*,
gateway_url: Optional[str] = None,
context_gateway_url: Optional[str] = None,
) -> str:
"""Resolve the base URL for the MCP gateway.

Precedence:
1) Explicit override (gateway_url parameter)
2) Context-provided URL (context_gateway_url)
3) Environment variable MCP_GATEWAY_URL
4) Fallback to http://127.0.0.1:8000 (dev default)
"""
# Highest precedence: explicit override
if gateway_url:
return gateway_url.rstrip("/")

# Next: context-provided URL (e.g., from Temporal workflow memo)
if context_gateway_url:
return context_gateway_url.rstrip("/")

# Next: environment variable
env_url = os.environ.get("MCP_GATEWAY_URL")
if env_url:
return env_url.rstrip("/")

# Next: a registry entry (if provided)
if server_registry and server_name:
cfg = server_registry.get_server_config(server_name)
if cfg and getattr(cfg, "url", None):
return cfg.url.rstrip("/")

# Fallback: default local server
return "http://127.0.0.1:8000"


async def log_via_proxy(
server_registry: Optional[ServerRegistry],
execution_id: str,
level: str,
namespace: str,
message: str,
data: Dict[str, Any] | None = None,
*,
server_name: Optional[str] = None,
gateway_url: Optional[str] = None,
gateway_token: Optional[str] = None,
) -> bool:
base = _resolve_gateway_url(server_registry, server_name, gateway_url)
base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None)
url = f"{base}/internal/workflows/log"
headers: Dict[str, str] = {}
tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN")
if tok:
headers["X-MCP-Gateway-Token"] = tok
headers["Authorization"] = f"Bearer {tok}"
timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10"))
try:
async with httpx.AsyncClient(timeout=timeout) as client:
Expand All @@ -75,21 +79,20 @@ async def log_via_proxy(


async def ask_via_proxy(
server_registry: Optional[ServerRegistry],
execution_id: str,
prompt: str,
metadata: Dict[str, Any] | None = None,
*,
server_name: Optional[str] = None,
gateway_url: Optional[str] = None,
gateway_token: Optional[str] = None,
) -> Dict[str, Any]:
base = _resolve_gateway_url(server_registry, server_name, gateway_url)
base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None)
url = f"{base}/internal/human/prompts"
headers: Dict[str, str] = {}
tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN")
if tok:
headers["X-MCP-Gateway-Token"] = tok
headers["Authorization"] = f"Bearer {tok}"
timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10"))
try:
async with httpx.AsyncClient(timeout=timeout) as client:
Expand All @@ -113,21 +116,20 @@ async def ask_via_proxy(


async def notify_via_proxy(
server_registry: Optional[ServerRegistry],
execution_id: str,
method: str,
params: Dict[str, Any] | None = None,
*,
server_name: Optional[str] = None,
gateway_url: Optional[str] = None,
gateway_token: Optional[str] = None,
) -> bool:
base = _resolve_gateway_url(server_registry, server_name, gateway_url)
base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None)
url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/notify"
headers: Dict[str, str] = {}
tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN")
if tok:
headers["X-MCP-Gateway-Token"] = tok
headers["Authorization"] = f"Bearer {tok}"
timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10"))

try:
Expand All @@ -147,21 +149,20 @@ async def notify_via_proxy(


async def request_via_proxy(
server_registry: Optional[ServerRegistry],
execution_id: str,
method: str,
params: Dict[str, Any] | None = None,
*,
server_name: Optional[str] = None,
gateway_url: Optional[str] = None,
gateway_token: Optional[str] = None,
) -> Dict[str, Any]:
base = _resolve_gateway_url(server_registry, server_name, gateway_url)
base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None)
url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request"
headers: Dict[str, str] = {}
tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN")
if tok:
headers["X-MCP-Gateway-Token"] = tok
headers["Authorization"] = f"Bearer {tok}"
timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "20"))
try:
async with httpx.AsyncClient(timeout=timeout) as client:
Expand Down
131 changes: 89 additions & 42 deletions src/mcp_agent/server/app_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,19 @@ async def _relay_notify(request: Request):

# Optional shared-secret auth
gw_token = os.environ.get("MCP_GATEWAY_TOKEN")
if gw_token and not secrets.compare_digest(
request.headers.get("X-MCP-Gateway-Token", ""), gw_token
):
return JSONResponse(
{"ok": False, "error": "unauthorized"}, status_code=401
if gw_token:
bearer = request.headers.get("Authorization", "")
bearer_token = (
bearer.split(" ", 1)[1] if bearer.lower().startswith("bearer ") else ""
)
header_tok = request.headers.get("X-MCP-Gateway-Token", "")
if not (
secrets.compare_digest(header_tok, gw_token)
or secrets.compare_digest(bearer_token, gw_token)
):
return JSONResponse(
{"ok": False, "error": "unauthorized"}, status_code=401
)

# Optional idempotency handling
idempotency_key = params.get("idempotency_key")
Expand Down Expand Up @@ -395,6 +402,9 @@ async def _relay_notify(request: Request):

return JSONResponse({"ok": True})
except Exception as e:
# After workflow cleanup, upstream sessions may be closed. Treat notify as best-effort.
if isinstance(method, str) and method.startswith("notifications/"):
return JSONResponse({"ok": True, "dropped": True})
return JSONResponse({"ok": False, "error": str(e)}, status_code=500)

@mcp_server.custom_route(
Expand Down Expand Up @@ -499,12 +509,19 @@ async def _internal_workflows_log(request: Request):

# Optional shared-secret auth
gw_token = os.environ.get("MCP_GATEWAY_TOKEN")
if gw_token and not secrets.compare_digest(
request.headers.get("X-MCP-Gateway-Token", ""), gw_token
):
return JSONResponse(
{"ok": False, "error": "unauthorized"}, status_code=401
if gw_token:
bearer = request.headers.get("Authorization", "")
bearer_token = (
bearer.split(" ", 1)[1] if bearer.lower().startswith("bearer ") else ""
)
header_tok = request.headers.get("X-MCP-Gateway-Token", "")
if not (
secrets.compare_digest(header_tok, gw_token)
or secrets.compare_digest(bearer_token, gw_token)
):
return JSONResponse(
{"ok": False, "error": "unauthorized"}, status_code=401
)

session = await _get_session(execution_id)
if not session:
Expand Down Expand Up @@ -538,10 +555,17 @@ async def _internal_human_prompts(request: Request):

# Optional shared-secret auth
gw_token = os.environ.get("MCP_GATEWAY_TOKEN")
if gw_token and not secrets.compare_digest(
request.headers.get("X-MCP-Gateway-Token", ""), gw_token
):
return JSONResponse({"error": "unauthorized"}, status_code=401)
if gw_token:
bearer = request.headers.get("Authorization", "")
bearer_token = (
bearer.split(" ", 1)[1] if bearer.lower().startswith("bearer ") else ""
)
header_tok = request.headers.get("X-MCP-Gateway-Token", "")
if not (
secrets.compare_digest(header_tok, gw_token)
or secrets.compare_digest(bearer_token, gw_token)
):
return JSONResponse({"error": "unauthorized"}, status_code=401)

session = await _get_session(execution_id)
if not session:
Expand Down Expand Up @@ -1367,38 +1391,61 @@ async def _workflow_run(
# Build memo for Temporal runs if gateway info is available
workflow_memo = None
try:
# Prefer explicit kwargs, else infer from request headers/environment
# FastMCP keeps raw request under ctx.request_context.request if available
# Prefer explicit kwargs, else infer from request context/headers
gateway_url = kwargs.get("gateway_url")
gateway_token = kwargs.get("gateway_token")

if gateway_url is None:
try:
req = getattr(ctx.request_context, "request", None)
if req is not None:
# Custom header if present
h = req.headers
gateway_url = (
h.get("X-MCP-Gateway-URL")
or h.get("X-Forwarded-Url")
or h.get("X-Forwarded-Proto")
)
# Best-effort reconstruction if only proto/host provided
if gateway_url is None:
proto = h.get("X-Forwarded-Proto") or "http"
host = h.get("X-Forwarded-Host") or h.get("Host")
if host:
gateway_url = f"{proto}://{host}"
except Exception:
pass
req = getattr(ctx.request_context, "request", None)
if req is not None:
h = req.headers
# Highest precedence: caller-provided full base URL
header_url = h.get("X-MCP-Gateway-URL") or h.get("X-Forwarded-Url")
if gateway_url is None and header_url:
gateway_url = header_url

# Token may be provided by the gateway/proxy
if gateway_token is None:
gateway_token = h.get("X-MCP-Gateway-Token")
if gateway_token is None:
# Support Authorization: Bearer <token>
auth = h.get("Authorization")
if auth and auth.lower().startswith("bearer "):
gateway_token = auth.split(" ", 1)[1]

# Prefer explicit reconstruction from X-Forwarded-* if present
if gateway_url is None and (h.get("X-Forwarded-Host") or h.get("Host")):
proto = h.get("X-Forwarded-Proto") or "http"
host = h.get("X-Forwarded-Host") or h.get("Host")
prefix = h.get("X-Forwarded-Prefix") or ""
if prefix and not prefix.startswith("/"):
prefix = "/" + prefix
if host:
gateway_url = f"{proto}://{host}{prefix}"

# Fallback to request's base_url which already includes scheme/host and any mount prefix
if gateway_url is None:
try:
if getattr(req, "base_url", None):
base_url = str(req.base_url).rstrip("/")
if base_url and base_url.lower() != "none":
gateway_url = base_url
except Exception:
gateway_url = None

if gateway_token is None:
try:
req = getattr(ctx.request_context, "request", None)
if req is not None:
gateway_token = req.headers.get("X-MCP-Gateway-Token")
except Exception:
pass
# Final fallback: environment variables (useful if proxies don't set headers)
try:
import os as _os

if gateway_url is None:
env_url = _os.environ.get("MCP_GATEWAY_URL")
if env_url:
gateway_url = env_url
if gateway_token is None:
env_tok = _os.environ.get("MCP_GATEWAY_TOKEN")
if env_tok:
gateway_token = env_tok
except Exception:
pass
Comment on lines +1394 to +1448
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Do not trust client-supplied gateway_url; gate on valid token and validate URL (SSRF/data exfil risk).

A malicious client can set X-MCP-Gateway-URL/X-Forwarded-* and steer worker callbacks (with potential sensitive logs/prompts) to an attacker-controlled URL. Only accept header-derived URLs when a valid gateway token accompanies the request; otherwise derive from request.base_url. Also normalize/validate scheme/host.

             # Prefer explicit kwargs, else infer from request context/headers
             gateway_url = kwargs.get("gateway_url")
             gateway_token = kwargs.get("gateway_token")

             req = getattr(ctx.request_context, "request", None)
             if req is not None:
                 h = req.headers
                 # Highest precedence: caller-provided full base URL
                 header_url = h.get("X-MCP-Gateway-URL") or h.get("X-Forwarded-Url")
-                if gateway_url is None and header_url:
-                    gateway_url = header_url
+                # Trust header URL only if caller presents a valid gateway token
+                env_tok = os.environ.get("MCP_GATEWAY_TOKEN")
+                trusted_headers = False
+                if gateway_token:
+                    try:
+                        trusted_headers = bool(env_tok and secrets.compare_digest(gateway_token, env_tok))
+                    except Exception:
+                        trusted_headers = False
+                if gateway_url is None and header_url and trusted_headers:
+                    gateway_url = header_url

                 # Token may be provided by the gateway/proxy
                 if gateway_token is None:
                     gateway_token = h.get("X-MCP-Gateway-Token")
                 if gateway_token is None:
                     # Support Authorization: Bearer <token>
                     auth = h.get("Authorization")
                     if auth and auth.lower().startswith("bearer "):
-                        gateway_token = auth.split(" ", 1)[1]
+                        gateway_token = auth.split(" ", 1)[1].strip()

                 # Prefer explicit reconstruction from X-Forwarded-* if present
-                if gateway_url is None and (h.get("X-Forwarded-Host") or h.get("Host")):
+                if gateway_url is None and trusted_headers and (h.get("X-Forwarded-Host") or h.get("Host")):
                     proto = h.get("X-Forwarded-Proto") or "http"
                     host = h.get("X-Forwarded-Host") or h.get("Host")
                     prefix = h.get("X-Forwarded-Prefix") or ""
                     if prefix and not prefix.startswith("/"):
                         prefix = "/" + prefix
                     if host:
                         gateway_url = f"{proto}://{host}{prefix}"

                 # Fallback to request's base_url which already includes scheme/host and any mount prefix
                 if gateway_url is None:
                     try:
                         if getattr(req, "base_url", None):
                             base_url = str(req.base_url).rstrip("/")
                             if base_url and base_url.lower() != "none":
                                 gateway_url = base_url
                     except Exception:
                         gateway_url = None
+
+                # Normalize and validate the URL (only http/https with netloc)
+                if isinstance(gateway_url, str):
+                    u = gateway_url.strip().rstrip("/")
+                    try:
+                        from urllib.parse import urlsplit, urlunsplit
+                        parts = urlsplit(u)
+                        if parts.scheme not in ("http", "https") or not parts.netloc:
+                            gateway_url = None
+                        else:
+                            gateway_url = urlunsplit((parts.scheme, parts.netloc, parts.path or "", "", ""))
+                    except Exception:
+                        gateway_url = None

             # Final fallback: environment variables (useful if proxies don't set headers)
             try:
-                import os as _os
-
                 if gateway_url is None:
-                    env_url = _os.environ.get("MCP_GATEWAY_URL")
+                    env_url = os.environ.get("MCP_GATEWAY_URL")
                     if env_url:
                         gateway_url = env_url
                 if gateway_token is None:
-                    env_tok = _os.environ.get("MCP_GATEWAY_TOKEN")
-                    if env_tok:
-                        gateway_token = env_tok
+                    if env_tok:
+                        gateway_token = env_tok
             except Exception:
                 pass

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/mcp_agent/server/app_server.py around lines 1394 to 1448, header-derived
gateway_url handling allows untrusted client-supplied URLs; restrict acceptance
to only when a valid gateway_token is present and validate/normalize the URL to
mitigate SSRF/data-exfil risks. Change logic so header_url
(X-MCP-Gateway-URL/X-Forwarded-Host/etc.) is used only if gateway_token is
non-empty and has been validated/verified for this request; otherwise ignore
header_url and fall back to req.base_url or environment. When accepting a
header-derived URL, parse and normalize it (use urllib.parse) and enforce:
scheme is http or https, netloc present, no embedded credentials, path
normalized (leading slash), and reject IP-literal hosts or resolve the hostname
and verify it is not a private/loopback address (use ipaddress and socket to
check); if any check fails, treat as invalid and fall back to req.base_url or
return an error. Ensure final gateway_url is rstrip("/") normalized before use.


if gateway_url or gateway_token:
workflow_memo = {
Expand Down
12 changes: 6 additions & 6 deletions tests/executor/temporal/test_execution_id_and_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,19 @@ async def post(self, url, json=None, headers=None):
client_proxy.httpx, "AsyncClient", lambda timeout: Client(rcodes)
)

ok = await client_proxy.log_via_proxy(None, "run", "info", "ns", "msg")
ok = await client_proxy.log_via_proxy("run", "info", "ns", "msg")
assert ok is True
ok = await client_proxy.log_via_proxy(None, "run", "info", "ns", "msg")
ok = await client_proxy.log_via_proxy("run", "info", "ns", "msg")
assert ok is False

# notify ok, then error
ok = await client_proxy.notify_via_proxy(None, "run", "m", {})
ok = await client_proxy.notify_via_proxy("run", "m", {})
assert ok is True
ok = await client_proxy.notify_via_proxy(None, "run", "m", {})
ok = await client_proxy.notify_via_proxy("run", "m", {})
assert ok is False

# request ok, then error
res = await client_proxy.request_via_proxy(None, "run", "m", {})
res = await client_proxy.request_via_proxy("run", "m", {})
assert isinstance(res, dict) and res.get("ok", True) in (True,)
res = await client_proxy.request_via_proxy(None, "run", "m", {})
res = await client_proxy.request_via_proxy("run", "m", {})
assert isinstance(res, dict) and "error" in res
Loading
Loading