2 changes: 1 addition & 1 deletion examples/mcp_agent_server/asyncio/client.py
@@ -41,7 +41,7 @@ async def main():
logger.info("Connecting to workflow server...")

# Override the server configuration to point to our local script
run_server_args = ["run", "basic_agent_server.py"]
run_server_args = ["run", "main.py"]
if use_custom_fastmcp_settings:
logger.info("Using custom FastMCP settings for the server.")
run_server_args += ["--custom-fastmcp-settings"]
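Reviewer note: the only change to the example client is which server script the stdio launcher runs (main.py instead of basic_agent_server.py). A minimal, self-contained sketch of the resulting argument assembly, assuming the args are handed to a `uv`-style launcher as in the surrounding example (illustrative, not the example's exact code):

```python
# Hedged sketch: build the launcher args the way the example client does.
use_custom_fastmcp_settings = False  # illustrative flag; set via the example's CLI in practice

run_server_args = ["run", "main.py"]
if use_custom_fastmcp_settings:
    run_server_args += ["--custom-fastmcp-settings"]

# The args are typically prepended with the launcher command, e.g. `uv`,
# yielding a command line like: uv run main.py
print("server command:", " ".join(["uv", *run_server_args]))
```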
211 changes: 203 additions & 8 deletions src/mcp_agent/server/app_server.py
@@ -49,18 +49,40 @@ async def _register_session(run_id: str, execution_id: str, session: Any) -> Non
async with _RUN_SESSION_LOCK:
_RUN_SESSION_REGISTRY[execution_id] = session
_RUN_EXECUTION_ID_REGISTRY[run_id] = execution_id
try:
logger.debug(
f"Registered upstream session for run_id={run_id}, execution_id={execution_id}, session_id={id(session)}"
)
except Exception:
pass


async def _unregister_session(run_id: str) -> None:
async with _RUN_SESSION_LOCK:
execution_id = _RUN_EXECUTION_ID_REGISTRY.pop(run_id, None)
if execution_id:
_RUN_SESSION_REGISTRY.pop(execution_id, None)
try:
logger.debug(
f"Unregistered upstream session mapping for run_id={run_id}, execution_id={execution_id}"
)
except Exception:
pass


async def _get_session(execution_id: str) -> Any | None:
async with _RUN_SESSION_LOCK:
return _RUN_SESSION_REGISTRY.get(execution_id)
session = _RUN_SESSION_REGISTRY.get(execution_id)
try:
logger.debug(
(
f"Lookup session for execution_id={execution_id}: "
+ (f"found session_id={id(session)}" if session else "not found")
)
)
except Exception:
pass
return session


class ServerContext(ContextDependent):
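Reviewer note: the registry above maps each workflow run to its upstream client session in two steps (run_id → execution_id → session) under a single asyncio lock, so notifications addressed by execution_id can find the right session and a reconnect can re-point a run at a new session. A standalone sketch of that pattern with illustrative names (not the module's actual code):

```python
import asyncio

# Two-level mapping guarded by one lock: run_id -> execution_id -> session.
_LOCK = asyncio.Lock()
_EXEC_BY_RUN: dict[str, str] = {}
_SESSION_BY_EXEC: dict[str, object] = {}


async def register(run_id: str, execution_id: str, session: object) -> None:
    async with _LOCK:
        _EXEC_BY_RUN[run_id] = execution_id
        _SESSION_BY_EXEC[execution_id] = session


async def lookup(execution_id: str) -> object | None:
    async with _LOCK:
        return _SESSION_BY_EXEC.get(execution_id)


async def unregister(run_id: str) -> None:
    async with _LOCK:
        execution_id = _EXEC_BY_RUN.pop(run_id, None)
        if execution_id:
            _SESSION_BY_EXEC.pop(execution_id, None)


async def demo() -> None:
    await register("run-1", "exec-1", object())
    assert await lookup("exec-1") is not None
    await unregister("run-1")
    assert await lookup("exec-1") is None


asyncio.run(demo())
```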
@@ -322,6 +344,19 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]:

# Helper: install internal HTTP routes (not MCP tools)
def _install_internal_routes(mcp_server: FastMCP) -> None:
def _get_fallback_upstream_session() -> Any | None:
"""Best-effort fallback to the most recent upstream session captured on the app context.

This helps when a workflow run's mapping has not been refreshed after a client reconnect.
"""
try:
app_obj: MCPApp | None = _get_attached_app(mcp_server)
if app_obj and getattr(app_obj, "context", None) is not None:
return getattr(app_obj.context, "upstream_session", None)
except Exception:
return None
return None

@mcp_server.custom_route(
"/internal/session/by-run/{execution_id}/notify",
methods=["POST"],
@@ -332,6 +367,12 @@ async def _relay_notify(request: Request):
execution_id = request.path_params.get("execution_id")
method = body.get("method")
params = body.get("params") or {}
try:
logger.info(
f"[notify] incoming execution_id={execution_id} method={method} idempotency_key={params.get('idempotency_key')}"
)
except Exception:
pass

# Optional shared-secret auth
gw_token = os.environ.get("MCP_GATEWAY_TOKEN")
@@ -362,9 +403,28 @@ async def _relay_notify(request: Request):

session = await _get_session(execution_id)
if not session:
return JSONResponse(
{"ok": False, "error": "session_not_available"}, status_code=503
)
# Try a fallback upstream session from the app context (best-effort)
fallback = _get_fallback_upstream_session()
if fallback is not None:
try:
await _register_session(
run_id=execution_id,
execution_id=execution_id,
session=fallback,
)
session = fallback
logger.warning(
f"[notify] No mapped session for execution_id={execution_id}; used fallback upstream_session session_id={id(session)}"
)
except Exception:
session = None
if not session:
logger.warning(
f"[notify] session_not_available for execution_id={execution_id}"
)
return JSONResponse(
{"ok": False, "error": "session_not_available"}, status_code=503
)

try:
# Special-case the common logging notification helper
@@ -379,6 +439,12 @@ async def _relay_notify(request: Request):
logger=logger_name,
related_request_id=related_request_id,
)
try:
logger.debug(
f"[notify] forwarded notifications/message to session_id={id(session)}"
)
except Exception:
pass
elif method == "notifications/progress":
# Minimal support for progress relay
progress_token = params.get("progressToken")
@@ -391,11 +457,23 @@ async def _relay_notify(request: Request):
total=total,
message=message,
)
try:
logger.debug(
f"[notify] forwarded notifications/progress to session_id={id(session)}"
)
except Exception:
pass
else:
# Generic passthrough using low-level RPC if available
rpc = getattr(session, "rpc", None)
if rpc and hasattr(rpc, "notify"):
await rpc.notify(method, params)
try:
logger.debug(
f"[notify] forwarded generic notify '{method}' to session_id={id(session)}"
)
except Exception:
pass
else:
return JSONResponse(
{"ok": False, "error": f"unsupported method: {method}"},
@@ -404,9 +482,40 @@ async def _relay_notify(request: Request):

return JSONResponse({"ok": True})
except Exception as e:
# One more best-effort: if we failed once and haven't used fallback yet, try fallback
try:
fallback = _get_fallback_upstream_session()
if fallback is not None and fallback is not session:
try:
await fallback.send_log_message( # type: ignore[attr-defined]
level=str(params.get("level", "info")),
data=params.get("data") or {},
logger=params.get("logger"),
related_request_id=params.get("related_request_id"),
)
logger.warning(
f"[notify] primary session send failed; used fallback upstream_session session_id={id(fallback)}"
)
return JSONResponse({"ok": True, "fallback": True})
except Exception:
pass
except Exception:
pass
# After workflow cleanup, upstream sessions may be closed. Treat notify as best-effort.
if isinstance(method, str) and method.startswith("notifications/"):
try:
logger.warning(
f"[notify] dropped notification for execution_id={execution_id}: {e}"
)
except Exception:
pass
return JSONResponse({"ok": True, "dropped": True})
try:
logger.error(
f"[notify] error forwarding for execution_id={execution_id}: {e}"
)
except Exception:
pass
return JSONResponse({"ok": False, "error": str(e)}, status_code=500)

@mcp_server.custom_route(
@@ -436,13 +545,39 @@ async def _relay_request(request: Request):

session = await _get_session(execution_id)
if not session:
return JSONResponse({"error": "session_not_available"}, status_code=503)
fallback = _get_fallback_upstream_session()
if fallback is not None:
try:
await _register_session(
run_id=execution_id,
execution_id=execution_id,
session=fallback,
)
session = fallback
logger.warning(
f"[request] No mapped session for execution_id={execution_id}; used fallback upstream_session session_id={id(session)}"
)
except Exception:
session = None
if not session:
logger.warning(
f"[request] session_not_available for execution_id={execution_id}"
)
return JSONResponse(
{"error": "session_not_available"}, status_code=503
)

try:
# Prefer generic request passthrough if available
rpc = getattr(session, "rpc", None)
if rpc and hasattr(rpc, "request"):
result = await rpc.request(method, params)
try:
logger.debug(
f"[request] forwarded generic request '{method}' to session_id={id(session)}"
)
except Exception:
pass
return JSONResponse(result)
# Fallback: Map a small set of supported server->client requests
if method == "sampling/createMessage":
@@ -496,6 +631,12 @@ async def _relay_request(request: Request):
{"error": f"unsupported method: {method}"}, status_code=400
)
except Exception as e:
try:
logger.error(
f"[request] error forwarding for execution_id={execution_id} method={method}: {e}"
)
except Exception:
pass
return JSONResponse({"error": str(e)}, status_code=500)

@mcp_server.custom_route(
@@ -529,9 +670,27 @@ async def _internal_workflows_log(request: Request):

session = await _get_session(execution_id)
if not session:
return JSONResponse(
{"ok": False, "error": "session_not_available"}, status_code=503
)
fallback = _get_fallback_upstream_session()
if fallback is not None:
try:
await _register_session(
run_id=execution_id,
execution_id=execution_id,
session=fallback,
)
session = fallback
logger.warning(
f"[log] No mapped session for execution_id={execution_id}; used fallback upstream_session session_id={id(session)}"
)
except Exception:
session = None
if not session:
logger.warning(
f"[log] session_not_available for execution_id={execution_id}"
)
return JSONResponse(
{"ok": False, "error": "session_not_available"}, status_code=503
)
if level not in ("debug", "info", "warning", "error"):
level = "info"
try:
@@ -544,6 +703,12 @@ async def _internal_workflows_log(request: Request):
},
logger=namespace,
)
try:
logger.debug(
f"[log] forwarded workflow log to session_id={id(session)} level={level} ns={namespace}"
)
except Exception:
pass
return JSONResponse({"ok": True})
except Exception as e:
return JSONResponse({"ok": False, "error": str(e)}, status_code=500)
@@ -794,6 +959,16 @@ async def get_workflow_status(
_set_upstream_from_request_ctx_if_available(ctx)
except Exception:
pass
# Opportunistically re-bind the upstream session mapping to this request's session
try:
sess = getattr(ctx, "session", None)
if sess and run_id:
exec_id = _RUN_EXECUTION_ID_REGISTRY.get(run_id, run_id)
await _register_session(
run_id=run_id, execution_id=exec_id, session=sess
)
except Exception:
pass
return await _workflow_status(ctx, run_id=run_id, workflow_id=workflow_id)

@mcp.tool(name="workflows-resume")
@@ -828,6 +1003,16 @@ async def resume_workflow(
_set_upstream_from_request_ctx_if_available(ctx)
except Exception:
pass
# Re-bind mapping for this run to ensure worker notifies reach the current client session
try:
sess = getattr(ctx, "session", None)
if sess and run_id:
exec_id = _RUN_EXECUTION_ID_REGISTRY.get(run_id, run_id)
await _register_session(
run_id=run_id, execution_id=exec_id, session=sess
)
except Exception:
pass

if run_id is None and workflow_id is None:
raise ToolError("Either run_id or workflow_id must be provided.")
@@ -883,6 +1068,16 @@ async def cancel_workflow(
_set_upstream_from_request_ctx_if_available(ctx)
except Exception:
pass
# Re-bind mapping for this run to ensure worker notifies reach the current client session
try:
sess = getattr(ctx, "session", None)
if sess and run_id:
exec_id = _RUN_EXECUTION_ID_REGISTRY.get(run_id, run_id)
await _register_session(
run_id=run_id, execution_id=exec_id, session=sess
)
except Exception:
pass

if run_id is None and workflow_id is None:
raise ToolError("Either run_id or workflow_id must be provided.")