diff --git a/examples/mcp_agent_server/asyncio/client.py b/examples/mcp_agent_server/asyncio/client.py index 73d6a2453..271509e50 100644 --- a/examples/mcp_agent_server/asyncio/client.py +++ b/examples/mcp_agent_server/asyncio/client.py @@ -41,7 +41,7 @@ async def main(): logger.info("Connecting to workflow server...") # Override the server configuration to point to our local script - run_server_args = ["run", "basic_agent_server.py"] + run_server_args = ["run", "main.py"] if use_custom_fastmcp_settings: logger.info("Using custom FastMCP settings for the server.") run_server_args += ["--custom-fastmcp-settings"] diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 717b42c7d..c72296339 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -49,6 +49,12 @@ async def _register_session(run_id: str, execution_id: str, session: Any) -> Non async with _RUN_SESSION_LOCK: _RUN_SESSION_REGISTRY[execution_id] = session _RUN_EXECUTION_ID_REGISTRY[run_id] = execution_id + try: + logger.debug( + f"Registered upstream session for run_id={run_id}, execution_id={execution_id}, session_id={id(session)}" + ) + except Exception: + pass async def _unregister_session(run_id: str) -> None: @@ -56,11 +62,27 @@ async def _unregister_session(run_id: str) -> None: execution_id = _RUN_EXECUTION_ID_REGISTRY.pop(run_id, None) if execution_id: _RUN_SESSION_REGISTRY.pop(execution_id, None) + try: + logger.debug( + f"Unregistered upstream session mapping for run_id={run_id}, execution_id={execution_id}" + ) + except Exception: + pass async def _get_session(execution_id: str) -> Any | None: async with _RUN_SESSION_LOCK: - return _RUN_SESSION_REGISTRY.get(execution_id) + session = _RUN_SESSION_REGISTRY.get(execution_id) + try: + logger.debug( + ( + f"Lookup session for execution_id={execution_id}: " + + (f"found session_id={id(session)}" if session else "not found") + ) + ) + except Exception: + pass + return session class ServerContext(ContextDependent): @@ -322,6 +344,19 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Helper: install internal HTTP routes (not MCP tools) def _install_internal_routes(mcp_server: FastMCP) -> None: + def _get_fallback_upstream_session() -> Any | None: + """Best-effort fallback to the most recent upstream session captured on the app context. + + This helps when a workflow run's mapping has not been refreshed after a client reconnect. + """ + try: + app_obj: MCPApp | None = _get_attached_app(mcp_server) + if app_obj and getattr(app_obj, "context", None) is not None: + return getattr(app_obj.context, "upstream_session", None) + except Exception: + return None + return None + @mcp_server.custom_route( "/internal/session/by-run/{execution_id}/notify", methods=["POST"], @@ -332,6 +367,12 @@ async def _relay_notify(request: Request): execution_id = request.path_params.get("execution_id") method = body.get("method") params = body.get("params") or {} + try: + logger.info( + f"[notify] incoming execution_id={execution_id} method={method} idempotency_key={params.get('idempotency_key')}" + ) + except Exception: + pass # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") @@ -360,54 +401,133 @@ async def _relay_notify(request: Request): return JSONResponse({"ok": True, "idempotent": True}) seen.add(idempotency_key) - session = await _get_session(execution_id) - if not session: + # Prefer latest upstream session first + latest_session = _get_fallback_upstream_session() + tried_latest = False + if latest_session is not None: + tried_latest = True + try: + if method == "notifications/message": + level = str(params.get("level", "info")) + data = params.get("data") + logger_name = params.get("logger") + related_request_id = params.get("related_request_id") + await latest_session.send_log_message( # type: ignore[attr-defined] + level=level, # type: ignore[arg-type] + data=data, + logger=logger_name, + related_request_id=related_request_id, + ) + logger.debug( + f"[notify] delivered via latest session_id={id(latest_session)} (message)" + ) + elif method == "notifications/progress": + progress_token = params.get("progressToken") + progress = params.get("progress") + total = params.get("total") + message = params.get("message") + await latest_session.send_progress_notification( # type: ignore[attr-defined] + progress_token=progress_token, + progress=progress, + total=total, + message=message, + ) + logger.debug( + f"[notify] delivered via latest session_id={id(latest_session)} (progress)" + ) + else: + rpc = getattr(latest_session, "rpc", None) + if rpc and hasattr(rpc, "notify"): + await rpc.notify(method, params) + logger.debug( + f"[notify] delivered via latest session_id={id(latest_session)} (generic '{method}')" + ) + else: + return JSONResponse( + {"ok": False, "error": f"unsupported method: {method}"}, + status_code=400, + ) + # Successful with latest → bind mapping for consistency + try: + await _register_session( + run_id=execution_id, + execution_id=execution_id, + session=latest_session, + ) + logger.info( + f"[notify] rebound mapping to latest session_id={id(latest_session)} for execution_id={execution_id}" + ) + except Exception: + pass + return JSONResponse({"ok": True}) + except Exception as e_latest: + logger.warning( + f"[notify] latest session delivery failed for execution_id={execution_id}: {e_latest}" + ) + + # Fallback to mapped session + mapped_session = await _get_session(execution_id) + if not mapped_session: + logger.warning( + f"[notify] session_not_available for execution_id={execution_id} (tried_latest={tried_latest})" + ) return JSONResponse( {"ok": False, "error": "session_not_available"}, status_code=503 ) try: - # Special-case the common logging notification helper if method == "notifications/message": level = str(params.get("level", "info")) data = params.get("data") logger_name = params.get("logger") related_request_id = params.get("related_request_id") - await session.send_log_message( # type: ignore[attr-defined] + await mapped_session.send_log_message( # type: ignore[attr-defined] level=level, # type: ignore[arg-type] data=data, logger=logger_name, related_request_id=related_request_id, ) + logger.debug( + f"[notify] delivered via mapped session_id={id(mapped_session)} (message)" + ) elif method == "notifications/progress": - # Minimal support for progress relay progress_token = params.get("progressToken") progress = params.get("progress") total = params.get("total") message = params.get("message") - await session.send_progress_notification( # type: ignore[attr-defined] + await mapped_session.send_progress_notification( # type: ignore[attr-defined] progress_token=progress_token, progress=progress, total=total, message=message, ) + logger.debug( + f"[notify] delivered via mapped session_id={id(mapped_session)} (progress)" + ) else: - # Generic passthrough using low-level RPC if available - rpc = getattr(session, "rpc", None) + rpc = getattr(mapped_session, "rpc", None) if rpc and hasattr(rpc, "notify"): await rpc.notify(method, params) + logger.debug( + f"[notify] delivered via mapped session_id={id(mapped_session)} (generic '{method}')" + ) else: return JSONResponse( {"ok": False, "error": f"unsupported method: {method}"}, status_code=400, ) - return JSONResponse({"ok": True}) - except Exception as e: - # After workflow cleanup, upstream sessions may be closed. Treat notify as best-effort. + except Exception as e_mapped: + # Best-effort for notifications if isinstance(method, str) and method.startswith("notifications/"): + logger.warning( + f"[notify] dropped notification for execution_id={execution_id}: {e_mapped}" + ) return JSONResponse({"ok": True, "dropped": True}) - return JSONResponse({"ok": False, "error": str(e)}, status_code=500) + logger.error( + f"[notify] error forwarding for execution_id={execution_id}: {e_mapped}" + ) + return JSONResponse({"ok": False, "error": str(e_mapped)}, status_code=500) @mcp_server.custom_route( "/internal/session/by-run/{execution_id}/request", @@ -433,16 +553,140 @@ async def _relay_request(request: Request): execution_id = request.path_params.get("execution_id") method = body.get("method") params = body.get("params") or {} + try: + logger.info( + f"[request] incoming execution_id={execution_id} method={method}" + ) + except Exception: + pass + + # Prefer latest upstream session first + latest_session = _get_fallback_upstream_session() + if latest_session is not None: + try: + rpc = getattr(latest_session, "rpc", None) + if rpc and hasattr(rpc, "request"): + result = await rpc.request(method, params) + logger.debug( + f"[request] delivered via latest session_id={id(latest_session)} (generic '{method}')" + ) + try: + await _register_session( + run_id=execution_id, + execution_id=execution_id, + session=latest_session, + ) + logger.info( + f"[request] rebound mapping to latest session_id={id(latest_session)} for execution_id={execution_id}" + ) + except Exception: + pass + return JSONResponse(result) + # If latest_session lacks rpc.request, try a limited mapping path + if method == "sampling/createMessage": + req = ServerRequest( + CreateMessageRequest( + method="sampling/createMessage", + params=CreateMessageRequestParams(**params), + ) + ) + result = await latest_session.send_request( # type: ignore[attr-defined] + request=req, + result_type=CreateMessageResult, + ) + try: + await _register_session( + run_id=execution_id, + execution_id=execution_id, + session=latest_session, + ) + except Exception: + pass + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + elif method == "elicitation/create": + req = ServerRequest( + ElicitRequest( + method="elicitation/create", + params=ElicitRequestParams(**params), + ) + ) + result = await latest_session.send_request( # type: ignore[attr-defined] + request=req, + result_type=ElicitResult, + ) + try: + await _register_session( + run_id=execution_id, + execution_id=execution_id, + session=latest_session, + ) + except Exception: + pass + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + elif method == "roots/list": + req = ServerRequest(ListRootsRequest(method="roots/list")) + result = await latest_session.send_request( # type: ignore[attr-defined] + request=req, + result_type=ListRootsResult, + ) + try: + await _register_session( + run_id=execution_id, + execution_id=execution_id, + session=latest_session, + ) + except Exception: + pass + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + elif method == "ping": + req = ServerRequest(PingRequest(method="ping")) + result = await latest_session.send_request( # type: ignore[attr-defined] + request=req, + result_type=EmptyResult, + ) + try: + await _register_session( + run_id=execution_id, + execution_id=execution_id, + session=latest_session, + ) + except Exception: + pass + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + except Exception as e_latest: + logger.warning( + f"[request] latest session delivery failed for execution_id={execution_id} method={method}: {e_latest}" + ) + # Fallback to mapped session session = await _get_session(execution_id) if not session: - return JSONResponse({"error": "session_not_available"}, status_code=503) + logger.warning( + f"[request] session_not_available for execution_id={execution_id}" + ) + return JSONResponse( + {"error": "session_not_available"}, status_code=503 + ) try: # Prefer generic request passthrough if available rpc = getattr(session, "rpc", None) if rpc and hasattr(rpc, "request"): result = await rpc.request(method, params) + try: + logger.debug( + f"[request] forwarded generic request '{method}' to session_id={id(session)}" + ) + except Exception: + pass return JSONResponse(result) # Fallback: Map a small set of supported server->client requests if method == "sampling/createMessage": @@ -496,6 +740,12 @@ async def _relay_request(request: Request): {"error": f"unsupported method: {method}"}, status_code=400 ) except Exception as e: + try: + logger.error( + f"[request] error forwarding for execution_id={execution_id} method={method}: {e}" + ) + except Exception: + pass return JSONResponse({"error": str(e)}, status_code=500) @mcp_server.custom_route( @@ -508,6 +758,12 @@ async def _internal_workflows_log(request: Request): namespace = body.get("namespace") or "mcp_agent" message = body.get("message") or "" data = body.get("data") or {} + try: + logger.info( + f"[log] incoming execution_id={execution_id} level={level} ns={namespace}" + ) + except Exception: + pass # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") @@ -527,8 +783,45 @@ async def _internal_workflows_log(request: Request): {"ok": False, "error": "unauthorized"}, status_code=401 ) + # Prefer latest upstream session first + latest_session = _get_fallback_upstream_session() + if latest_session is not None: + try: + await latest_session.send_log_message( # type: ignore[attr-defined] + level=level, # type: ignore[arg-type] + data={ + "message": message, + "namespace": namespace, + "data": data, + }, + logger=namespace, + ) + logger.debug( + f"[log] delivered via latest session_id={id(latest_session)} level={level} ns={namespace}" + ) + try: + await _register_session( + run_id=execution_id, + execution_id=execution_id, + session=latest_session, + ) + logger.info( + f"[log] rebound mapping to latest session_id={id(latest_session)} for execution_id={execution_id}" + ) + except Exception: + pass + return JSONResponse({"ok": True}) + except Exception as e_latest: + logger.warning( + f"[log] latest session delivery failed for execution_id={execution_id}: {e_latest}" + ) + + # Fallback to mapped session session = await _get_session(execution_id) if not session: + logger.warning( + f"[log] session_not_available for execution_id={execution_id}" + ) return JSONResponse( {"ok": False, "error": "session_not_available"}, status_code=503 ) @@ -544,6 +837,12 @@ async def _internal_workflows_log(request: Request): }, logger=namespace, ) + try: + logger.debug( + f"[log] forwarded workflow log to session_id={id(session)} level={level} ns={namespace}" + ) + except Exception: + pass return JSONResponse({"ok": True}) except Exception as e: return JSONResponse({"ok": False, "error": str(e)}, status_code=500) @@ -556,6 +855,12 @@ async def _internal_human_prompts(request: Request): execution_id = body.get("execution_id") prompt = body.get("prompt") or {} metadata = body.get("metadata") or {} + try: + logger.info( + f"[human] incoming execution_id={execution_id} signal_name={metadata.get('signal_name','human_input')}" + ) + except Exception: + pass # Optional shared-secret auth gw_token = os.environ.get("MCP_GATEWAY_TOKEN") @@ -573,9 +878,8 @@ async def _internal_human_prompts(request: Request): ): return JSONResponse({"error": "unauthorized"}, status_code=401) - session = await _get_session(execution_id) - if not session: - return JSONResponse({"error": "session_not_available"}, status_code=503) + # Prefer latest upstream session first + latest_session = _get_fallback_upstream_session() import uuid request_id = str(uuid.uuid4()) @@ -594,6 +898,35 @@ async def _internal_human_prompts(request: Request): "signal_name": metadata.get("signal_name", "human_input"), "session_id": metadata.get("session_id"), } + # Try latest first + if latest_session is not None: + try: + await latest_session.send_log_message( # type: ignore[attr-defined] + level="info", # type: ignore[arg-type] + data=payload, + logger="mcp_agent.human", + ) + try: + await _register_session( + run_id=execution_id, + execution_id=execution_id, + session=latest_session, + ) + logger.info( + f"[human] rebound mapping to latest session_id={id(latest_session)} for execution_id={execution_id}" + ) + except Exception: + pass + return JSONResponse({"request_id": request_id}) + except Exception as e_latest: + logger.warning( + f"[human] latest session delivery failed for execution_id={execution_id}: {e_latest}" + ) + + # Fallback to mapped session + session = await _get_session(execution_id) + if not session: + return JSONResponse({"error": "session_not_available"}, status_code=503) await session.send_log_message( level="info", # type: ignore[arg-type] data=payload, @@ -794,6 +1127,16 @@ async def get_workflow_status( _set_upstream_from_request_ctx_if_available(ctx) except Exception: pass + # Opportunistically re-bind the upstream session mapping to this request's session + try: + sess = getattr(ctx, "session", None) + if sess and run_id: + exec_id = _RUN_EXECUTION_ID_REGISTRY.get(run_id, run_id) + await _register_session( + run_id=run_id, execution_id=exec_id, session=sess + ) + except Exception: + pass return await _workflow_status(ctx, run_id=run_id, workflow_id=workflow_id) @mcp.tool(name="workflows-resume") @@ -828,6 +1171,16 @@ async def resume_workflow( _set_upstream_from_request_ctx_if_available(ctx) except Exception: pass + # Re-bind mapping for this run to ensure worker notifies reach the current client session + try: + sess = getattr(ctx, "session", None) + if sess and run_id: + exec_id = _RUN_EXECUTION_ID_REGISTRY.get(run_id, run_id) + await _register_session( + run_id=run_id, execution_id=exec_id, session=sess + ) + except Exception: + pass if run_id is None and workflow_id is None: raise ToolError("Either run_id or workflow_id must be provided.") @@ -883,6 +1236,16 @@ async def cancel_workflow( _set_upstream_from_request_ctx_if_available(ctx) except Exception: pass + # Re-bind mapping for this run to ensure worker notifies reach the current client session + try: + sess = getattr(ctx, "session", None) + if sess and run_id: + exec_id = _RUN_EXECUTION_ID_REGISTRY.get(run_id, run_id) + await _register_session( + run_id=run_id, execution_id=exec_id, session=sess + ) + except Exception: + pass if run_id is None and workflow_id is None: raise ToolError("Either run_id or workflow_id must be provided.")