diff --git a/examples/mcp_agent_server/asyncio/README.md b/examples/mcp_agent_server/asyncio/README.md
index 6c780dbeb..0ff69feda 100644
--- a/examples/mcp_agent_server/asyncio/README.md
+++ b/examples/mcp_agent_server/asyncio/README.md
@@ -258,6 +258,40 @@ def make_session(read_stream: MemoryObjectReceiveStream,
 The example client (`client.py`) demonstrates this end-to-end: it registers a logging callback and calls `set_logging_level("info")` so logs from the server appear in the client's console.
 
+## Testing Specific Features
+
+The client supports feature flags to exercise subsets of functionality. Available flags: `workflows`, `tools`, `sampling`, `elicitation`, `notifications`, or `all`.
+
+Examples:
+
+```
+# Default (all features)
+uv run client.py
+
+# Only workflows
+uv run client.py --features workflows
+
+# Only tools
+uv run client.py --features tools
+
+# Sampling + elicitation demos
+uv run client.py --features sampling elicitation
+
+# Only notifications (server logs + other notifications)
+uv run client.py --features notifications
+
+# Increase server logging verbosity
+uv run client.py --server-log-level debug
+
+# Use custom FastMCP settings when launching the server
+uv run client.py --custom-fastmcp-settings
+```
+
+Console output:
+
+- Server logs appear as lines prefixed with `[SERVER LOG] ...`.
+- Other server-originated notifications (e.g., `notifications/progress`, `notifications/resources/list_changed`) appear as `[SERVER NOTIFY] <method>: ...`.
+
 ## MCP Clients
 
 Since the mcp-agent app is exposed as an MCP server, it can be used in any MCP client just
diff --git a/examples/mcp_agent_server/asyncio/client.py b/examples/mcp_agent_server/asyncio/client.py
index 271509e50..6c229098c 100644
--- a/examples/mcp_agent_server/asyncio/client.py
+++ b/examples/mcp_agent_server/asyncio/client.py
@@ -8,12 +8,24 @@
 from mcp.types import CallToolResult, LoggingMessageNotificationParams
 from mcp_agent.app import MCPApp
 from mcp_agent.config import MCPServerSettings
+from mcp_agent.core.context import Context
 from mcp_agent.executor.workflow import WorkflowExecution
 from mcp_agent.mcp.gen_client import gen_client
 from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
+from mcp_agent.human_input.handler import console_input_callback
+from mcp_agent.elicitation.handler import console_elicitation_callback
 from rich import print
 
+try:
+    from exceptiongroup import ExceptionGroup as _ExceptionGroup  # Python 3.10 backport
+except Exception:  # pragma: no cover
+    _ExceptionGroup = None  # type: ignore
+try:
+    from anyio import BrokenResourceError as _BrokenResourceError
+except Exception:  # pragma: no cover
+    _BrokenResourceError = None  # type: ignore
+
 
 async def main():
     parser = argparse.ArgumentParser()
@@ -28,11 +40,32 @@ async def main():
         default=None,
         help="Set initial server logging level (debug, info, notice, warning, error, critical, alert, emergency)",
     )
+    parser.add_argument(
+        "--features",
+        nargs="+",
+        choices=[
+            "workflows",
+            "tools",
+            "sampling",
+            "elicitation",
+            "notifications",
+            "all",
+        ],
+        default=["all"],
+        help="Select which features to test",
+    )
     args = parser.parse_args()
 
     use_custom_fastmcp_settings = args.custom_fastmcp_settings
+    selected = set(args.features)
+    if "all" in selected:
+        selected = {"workflows", "tools", "sampling", "elicitation", "notifications"}
 
     # Create MCPApp to get the server registry
-    app = MCPApp(name="workflow_mcp_client")
+    app = MCPApp(
+        name="workflow_mcp_client",
+        human_input_callback=console_input_callback,
+        elicitation_callback=console_elicitation_callback,
+    )
     async with app.run() as client_app:
         logger = client_app.logger
         context = client_app.context
@@ -54,233 +87,343 @@ async def main():
             args=run_server_args,
         )
 
-        # Connect to the workflow server
         # Define a logging callback to receive server-side log notifications
         async def on_server_log(params: LoggingMessageNotificationParams) -> None:
-            # Pretty-print server logs locally for demonstration
             level = params.level.upper()
             name = params.logger or "server"
-            # params.data can be any JSON-serializable data
             print(f"[SERVER LOG] [{level}] [{name}] {params.data}")
 
         # Provide a client session factory that installs our logging callback
+        # and prints non-logging notifications to the console
+        class ConsolePrintingClientSession(MCPAgentClientSession):
+            async def _received_notification(self, notification):  # type: ignore[override]
+                try:
+                    method = getattr(notification.root, "method", None)
+                except Exception:
+                    method = None
+
+                # Avoid duplicating server log prints (handled by logging_callback)
+                if method and method != "notifications/message":
+                    try:
+                        data = notification.model_dump()
+                    except Exception:
+                        data = str(notification)
+                    print(f"[SERVER NOTIFY] {method}: {data}")
+
+                return await super()._received_notification(notification)
+
         def make_session(
             read_stream: MemoryObjectReceiveStream,
             write_stream: MemoryObjectSendStream,
             read_timeout_seconds: timedelta | None,
+            context: Context | None = None,
         ) -> ClientSession:
-            return MCPAgentClientSession(
+            return ConsolePrintingClientSession(
                 read_stream=read_stream,
                 write_stream=write_stream,
                 read_timeout_seconds=read_timeout_seconds,
                 logging_callback=on_server_log,
+                context=context,
             )
 
-        async with gen_client(
-            "basic_agent_server",
-            context.server_registry,
-            client_session_factory=make_session,
-        ) as server:
-            # Ask server to send logs at the requested level (default info)
-            level = (args.server_log_level or "info").lower()
-            print(f"[client] Setting server logging level to: {level}")
-            try:
-                await server.set_logging_level(level)
-            except Exception:
-                # Older servers may not support logging capability
-                print("[client] Server does not support logging/setLevel")
-
-            # List available tools
-            tools_result = await server.list_tools()
-            logger.info(
-                "Available tools:",
-                data={"tools": [tool.name for tool in tools_result.tools]},
-            )
-
-            # List available workflows
-            logger.info("Fetching available workflows...")
-            workflows_response = await server.call_tool("workflows-list", {})
-            logger.info(
-                "Available workflows:",
-                data=_tool_result_to_json(workflows_response) or workflows_response,
-            )
-
-            # Call the BasicAgentWorkflow (run + status)
-            run_result = await server.call_tool(
-                "workflows-BasicAgentWorkflow-run",
-                arguments={
-                    "run_parameters": {
-                        "input": "Print the first two paragraphs of https://modelcontextprotocol.io/introduction."
-                    }
-                },
-            )
-
-            # Tolerant parsing of run IDs from tool result
-            run_payload = _tool_result_to_json(run_result)
-            if not run_payload:
-                sc = getattr(run_result, "structuredContent", None)
-                if isinstance(sc, dict):
-                    run_payload = sc.get("result") or sc
-            if not run_payload:
-                # Last resort: parse unstructured content if present and non-empty
-                if getattr(run_result, "content", None) and run_result.content[0].text:
-                    run_payload = json.loads(run_result.content[0].text)
-                else:
-                    raise RuntimeError(
-                        "Unable to extract workflow run IDs from tool result"
-                    )
-
-            execution = WorkflowExecution(**run_payload)
-            run_id = execution.run_id
-            logger.info(
-                f"Started BasicAgentWorkflow-run. workflow ID={execution.workflow_id}, run ID={run_id}"
-            )
-
-            # Wait for the workflow to complete
-            while True:
-                get_status_result = await server.call_tool(
-                    "workflows-BasicAgentWorkflow-get_status",
-                    arguments={"run_id": run_id},
+        try:
+            async with gen_client(
+                "basic_agent_server",
+                context.server_registry,
+                client_session_factory=make_session,
+            ) as server:
+                # Ask server to send logs at the requested level (default info)
+                level = (args.server_log_level or "info").lower()
+                print(f"[client] Setting server logging level to: {level}")
+                try:
+                    await server.set_logging_level(level)
+                except Exception:
+                    # Older servers may not support logging capability
+                    print("[client] Server does not support logging/setLevel")
+
+                # List available tools
+                tools_result = await server.list_tools()
+                logger.info(
+                    "Available tools:",
+                    data={"tools": [tool.name for tool in tools_result.tools]},
                 )
 
-                # Tolerant parsing of get_status result
-                workflow_status = _tool_result_to_json(get_status_result)
-                if workflow_status is None:
-                    sc = getattr(get_status_result, "structuredContent", None)
-                    if isinstance(sc, dict):
-                        workflow_status = sc.get("result") or sc
-                if workflow_status is None:
-                    logger.error(
-                        f"Failed to parse workflow status response: {get_status_result}"
+                # List available workflows
+                if "workflows" in selected:
+                    logger.info("Fetching available workflows...")
+                    workflows_response = await server.call_tool("workflows-list", {})
+                    logger.info(
+                        "Available workflows:",
+                        data=_tool_result_to_json(workflows_response)
+                        or workflows_response,
                     )
-                    break
-
-                logger.info(
-                    f"Workflow run {run_id} status:",
-                    data=workflow_status,
-                )
 
-                if not workflow_status.get("status"):
-                    logger.error(
-                        f"Workflow run {run_id} status is empty. get_status_result:",
-                        data=get_status_result,
+                # Call the BasicAgentWorkflow (run + status)
+                if "workflows" in selected:
+                    run_result = await server.call_tool(
+                        "workflows-BasicAgentWorkflow-run",
+                        arguments={
+                            "run_parameters": {
+                                "input": "Print the first two paragraphs of https://modelcontextprotocol.io/introduction."
+                            }
+                        },
                     )
-                    break
 
-                if workflow_status.get("status") == "completed":
+                    # Tolerant parsing of run IDs from tool result
+                    run_payload = _tool_result_to_json(run_result)
+                    if not run_payload:
+                        sc = getattr(run_result, "structuredContent", None)
+                        if isinstance(sc, dict):
+                            run_payload = sc.get("result") or sc
+                    if not run_payload:
+                        # Last resort: parse unstructured content if present and non-empty
+                        if (
+                            getattr(run_result, "content", None)
+                            and run_result.content[0].text
+                        ):
+                            run_payload = json.loads(run_result.content[0].text)
+                        else:
+                            raise RuntimeError(
+                                "Unable to extract workflow run IDs from tool result"
+                            )
+
+                    execution = WorkflowExecution(**run_payload)
+                    run_id = execution.run_id
                     logger.info(
-                        f"Workflow run {run_id} completed successfully! Result:",
-                        data=workflow_status.get("result"),
+                        f"Started BasicAgentWorkflow-run. workflow ID={execution.workflow_id}, run ID={run_id}"
                     )
-                    break
-                elif workflow_status.get("status") == "error":
-                    logger.error(
-                        f"Workflow run {run_id} failed with error:",
-                        data=workflow_status,
+
+                    # Wait for the workflow to complete
+                    while True:
+                        get_status_result = await server.call_tool(
+                            "workflows-BasicAgentWorkflow-get_status",
+                            arguments={"run_id": run_id},
+                        )
+
+                        # Tolerant parsing of get_status result
+                        workflow_status = _tool_result_to_json(get_status_result)
+                        if workflow_status is None:
+                            sc = getattr(get_status_result, "structuredContent", None)
+                            if isinstance(sc, dict):
+                                workflow_status = sc.get("result") or sc
+                        if workflow_status is None:
+                            logger.error(
+                                f"Failed to parse workflow status response: {get_status_result}"
+                            )
+                            break
+
+                        logger.info(
+                            f"Workflow run {run_id} status:",
+                            data=workflow_status,
+                        )
+
+                        if not workflow_status.get("status"):
+                            logger.error(
+                                f"Workflow run {run_id} status is empty. get_status_result:",
+                                data=get_status_result,
+                            )
+                            break
+
+                        if workflow_status.get("status") == "completed":
+                            logger.info(
+                                f"Workflow run {run_id} completed successfully! Result:",
+                                data=workflow_status.get("result"),
+                            )
+                            break
+                        elif workflow_status.get("status") == "error":
+                            logger.error(
+                                f"Workflow run {run_id} failed with error:",
+                                data=workflow_status,
+                            )
+                            break
+                        elif workflow_status.get("status") == "running":
+                            logger.info(
+                                f"Workflow run {run_id} is still running...",
+                            )
+                        elif workflow_status.get("status") == "cancelled":
+                            logger.error(
+                                f"Workflow run {run_id} was cancelled.",
+                                data=workflow_status,
+                            )
+                            break
+                        else:
+                            logger.error(
+                                f"Unknown workflow status: {workflow_status.get('status')}",
+                                data=workflow_status,
+                            )
+                            break
+
+                        await asyncio.sleep(5)
+
+                    # Get the token usage summary
+                    logger.info("Fetching token usage summary...")
+                    token_usage_result = await server.call_tool(
+                        "get_token_usage",
+                        arguments={
+                            "run_id": run_id,
+                            "workflow_id": execution.workflow_id,
+                        },
                     )
-                    break
-                elif workflow_status.get("status") == "running":
+
                     logger.info(
-                        f"Workflow run {run_id} is still running...",
+                        "Token usage summary:",
+                        data=_tool_result_to_json(token_usage_result)
+                        or token_usage_result,
                     )
-                elif workflow_status.get("status") == "cancelled":
-                    logger.error(
-                        f"Workflow run {run_id} was cancelled.",
-                        data=workflow_status,
-                    )
-                    break
-                else:
-                    logger.error(
-                        f"Unknown workflow status: {workflow_status.get('status')}",
-                        data=workflow_status,
-                    )
-                    break
-
-                await asyncio.sleep(5)
-
-            # Get the token usage summary
-            logger.info("Fetching token usage summary...")
-            token_usage_result = await server.call_tool(
-                "get_token_usage",
-                arguments={
-                    "run_id": run_id,
-                    "workflow_id": execution.workflow_id,
-                },
-            )
-            logger.info(
-                "Token usage summary:",
-                data=_tool_result_to_json(token_usage_result) or token_usage_result,
-            )
 
-            # Display the token usage summary
-            print(token_usage_result.structuredContent)
+                    # Display the token usage summary
+                    print(token_usage_result.structuredContent)
 
-            await asyncio.sleep(5)
+                    await asyncio.sleep(1)
 
+                # Call the sync tool 'grade_story' separately (no run/status loop)
+                if "tools" in selected:
+                    try:
+                        grade_result = await server.call_tool(
+                            "grade_story",
+                            arguments={"story": "This is a test story."},
+                        )
+                        grade_payload = _tool_result_to_json(grade_result) or (
+                            (
+                                grade_result.structuredContent.get("result")
+                                if getattr(grade_result, "structuredContent", None)
+                                else None
+                            )
+                            or (
+                                grade_result.content[0].text
+                                if grade_result.content
+                                else None
+                            )
+                        )
+                        logger.info("grade_story result:", data=grade_payload)
+                    except Exception as e:
+                        logger.error("grade_story call failed", data=str(e))
+
+                # Call the async tool 'grade_story_async': start then poll status
+                if "tools" in selected:
+                    try:
+                        async_run_result = await server.call_tool(
+                            "grade_story_async",
+                            arguments={"story": "This is a test story."},
+                        )
+                        async_ids = (
+                            (
+                                getattr(async_run_result, "structuredContent", {}) or {}
+                            ).get("result")
+                            or _tool_result_to_json(async_run_result)
+                            or json.loads(async_run_result.content[0].text)
+                        )
+                        async_run_id = async_ids["run_id"]
+                        logger.info(
+                            f"Started grade_story_async. run ID={async_run_id}",
+                        )
 
-            # Call the sync tool 'grade_story' separately (no run/status loop)
-            try:
-                grade_result = await server.call_tool(
-                    "grade_story",
-                    arguments={"story": "This is a test story."},
-                )
-                grade_payload = _tool_result_to_json(grade_result) or (
-                    (
-                        grade_result.structuredContent.get("result")
-                        if getattr(grade_result, "structuredContent", None)
-                        else None
-                    )
-                    or (grade_result.content[0].text if grade_result.content else None)
-                )
-                logger.info("grade_story result:", data=grade_payload)
-            except Exception as e:
-                logger.error("grade_story call failed", data=str(e))
-
-            # Call the async tool 'grade_story_async': start then poll status
-            try:
-                async_run_result = await server.call_tool(
-                    "grade_story_async",
-                    arguments={"story": "This is a test story."},
-                )
-                async_ids = (
-                    (getattr(async_run_result, "structuredContent", {}) or {}).get(
-                        "result"
-                    )
-                    or _tool_result_to_json(async_run_result)
-                    or json.loads(async_run_result.content[0].text)
-                )
-                async_run_id = async_ids["run_id"]
-                logger.info(
-                    f"Started grade_story_async. run ID={async_run_id}",
-                )
+                        # Poll status until completion
+                        while True:
+                            async_status = await server.call_tool(
+                                "workflows-get_status",
+                                arguments={"run_id": async_run_id},
+                            )
+                            async_status_json = (
+                                getattr(async_status, "structuredContent", {}) or {}
+                            ).get("result") or _tool_result_to_json(async_status)
+                            if async_status_json is None:
+                                logger.error(
+                                    "grade_story_async: failed to parse status",
+                                    data=async_status,
+                                )
+                                break
+                            logger.info(
+                                "grade_story_async status:", data=async_status_json
+                            )
+                            if async_status_json.get("status") in (
+                                "completed",
+                                "error",
+                                "cancelled",
+                            ):
+                                break
+                            await asyncio.sleep(2)
+                    except Exception as e:
+                        logger.error("grade_story_async call failed", data=str(e))
+
+                # Sampling demo via app.tool
+                if "sampling" in selected:
+                    try:
+                        demo = await server.call_tool(
+                            "sampling_demo", arguments={"topic": "flowers"}
+                        )
+                        logger.info(
+                            "sampling_demo result:",
+                            data=_tool_result_to_json(demo) or demo,
+                        )
+                    except Exception as e:
+                        logger.error("sampling_demo failed", data=str(e))
+
+                # Elicitation demo via app.tool
+                if "elicitation" in selected:
+                    try:
+                        el = await server.call_tool(
+                            "elicitation_demo", arguments={"action": "proceed"}
+                        )
+                        logger.info(
+                            "elicitation_demo result:",
+                            data=_tool_result_to_json(el) or el,
+                        )
+                    except Exception as e:
+                        logger.error("elicitation_demo failed", data=str(e))
+
+                # Notifications demo via app.tool
+                if "notifications" in selected:
+                    try:
+                        n1 = await server.call_tool("notify_resources", arguments={})
+                        logger.info(
+                            "notify_resources result:",
+                            data=_tool_result_to_json(n1) or n1,
+                        )
+                        n2 = await server.call_tool(
+                            "notify_progress",
+                            arguments={"progress": 0.5, "message": "Halfway there"},
+                        )
+                        logger.info(
+                            "notify_progress result:",
+                            data=_tool_result_to_json(n2) or n2,
+                        )
+                    except Exception as e:
+                        logger.error("notifications demo failed", data=str(e))
+        except Exception as e:
+            # Tolerate benign shutdown races from stdio client (BrokenResourceError within ExceptionGroup)
+            if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup):
+                subs = getattr(e, "exceptions", []) or []
+                if (
+                    _BrokenResourceError is not None
+                    and subs
+                    and all(isinstance(se, _BrokenResourceError) for se in subs)
+                ):
+                    logger.debug("Ignored BrokenResourceError from stdio shutdown")
+                else:
+                    raise
+            elif _BrokenResourceError is not None and isinstance(
+                e, _BrokenResourceError
+            ):
+                logger.debug("Ignored BrokenResourceError from stdio shutdown")
+            elif "BrokenResourceError" in str(e):
+                logger.debug(
+                    "Ignored BrokenResourceError from stdio shutdown (string match)"
+                )
+            else:
+                raise
+        # Nudge cleanup of subprocess transports before the loop closes to avoid
+        # 'Event loop is closed' from BaseSubprocessTransport.__del__ on GC.
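+        # (A zero-delay sleep yields once so pending transport-close callbacks can
+        # run on the event loop, and an explicit gc.collect() finalizes any
+        # lingering subprocess transports while the loop is still open.)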
+        try:
+            await asyncio.sleep(0)
+        except Exception:
+            pass
+        try:
+            import gc
 
-                # Poll status until completion
-                while True:
-                    async_status = await server.call_tool(
-                        "workflows-get_status",
-                        arguments={"run_id": async_run_id},
-                    )
-                    async_status_json = (
-                        getattr(async_status, "structuredContent", {}) or {}
-                    ).get("result") or _tool_result_to_json(async_status)
-                    if async_status_json is None:
-                        logger.error(
-                            "grade_story_async: failed to parse status",
-                            data=async_status,
-                        )
-                        break
-                    logger.info("grade_story_async status:", data=async_status_json)
-                    if async_status_json.get("status") in (
-                        "completed",
-                        "error",
-                        "cancelled",
-                    ):
-                        break
-                    await asyncio.sleep(2)
-            except Exception as e:
-                logger.error("grade_story_async call failed", data=str(e))
-
-            await asyncio.sleep(5)
+            gc.collect()
+        except Exception:
+            pass
 
 
 def _tool_result_to_json(tool_result: CallToolResult):
diff --git a/examples/mcp_agent_server/asyncio/main.py b/examples/mcp_agent_server/asyncio/main.py
index f49ec6f62..0d5350e12 100644
--- a/examples/mcp_agent_server/asyncio/main.py
+++ b/examples/mcp_agent_server/asyncio/main.py
@@ -25,6 +25,10 @@ from mcp_agent.workflows.parallel.parallel_llm import ParallelLLM
 from mcp_agent.executor.workflow import Workflow, WorkflowResult
 from mcp_agent.tracing.token_counter import TokenNode
 
+from mcp_agent.human_input.handler import console_input_callback
+from mcp_agent.elicitation.handler import console_elicitation_callback
+from mcp_agent.mcp.gen_client import gen_client
+from mcp_agent.config import MCPServerSettings
 
 # Note: This is purely optional:
 # if not provided, a default FastMCP server will be created by MCPApp using create_mcp_server_for_app()
@@ -36,6 +40,8 @@
     name="basic_agent_server",
     description="Basic agent server example",
     mcp=mcp,
+    human_input_callback=console_input_callback,  # enable approval prompts for local sampling
+    elicitation_callback=console_elicitation_callback,  # enable console-driven elicitation
 )
 
 
@@ -109,6 +115,114 @@ async def run(self, input: str) -> WorkflowResult[str]:
         return WorkflowResult(value=result)
 
 
+@app.tool(name="sampling_demo")
+async def sampling_demo(topic: str, app_ctx: Optional[AppContext] = None) -> str:
+    """
+    Demonstrate MCP sampling via a nested MCP server tool.
+
+    - In asyncio (no upstream client), this triggers local sampling with a human approval prompt.
+    - When an MCP client is connected, the sampling request is proxied upstream.
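+    - The nested server (examples/mcp_agent_server/shared/nested_sampling_server.py) is registered on the fly in this app's server settings before the call.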
+ """ + _app = app_ctx.app if app_ctx else app + + # Register a simple nested server that uses sampling in its get_haiku tool + nested_name = "nested_sampling" + nested_path = os.path.abspath( + os.path.join( + os.path.dirname(__file__), "..", "shared", "nested_sampling_server.py" + ) + ) + _app.context.config.mcp.servers[nested_name] = MCPServerSettings( + name=nested_name, + command="uv", + args=["run", nested_path], + description="Nested server providing a haiku generator using sampling", + ) + + # Connect as an MCP client to the nested server and call its sampling tool + async with gen_client( + nested_name, _app.context.server_registry, context=_app.context + ) as client: + result = await client.call_tool("get_haiku", {"topic": topic}) + + # Extract text content from CallToolResult + try: + if result.content and len(result.content) > 0: + return result.content[0].text or "" + except Exception: + pass + return "" + + +@app.tool(name="elicitation_demo") +async def elicitation_demo( + action: str = "proceed", app_ctx: Optional[AppContext] = None +) -> str: + """ + Demonstrate MCP elicitation via a nested MCP server tool. + + - In asyncio (no upstream client), this triggers local elicitation handled by console. + - When an MCP client is connected, the elicitation request is proxied upstream. + """ + _app = app_ctx.app if app_ctx else app + + nested_name = "nested_elicitation" + nested_path = os.path.abspath( + os.path.join( + os.path.dirname(__file__), "..", "shared", "nested_elicitation_server.py" + ) + ) + _app.context.config.mcp.servers[nested_name] = MCPServerSettings( + name=nested_name, + command="uv", + args=["run", nested_path], + description="Nested server demonstrating elicitation", + ) + + async with gen_client( + nested_name, _app.context.server_registry, context=_app.context + ) as client: + result = await client.call_tool("confirm_action", {"action": action}) + try: + if result.content and len(result.content) > 0: + return result.content[0].text or "" + except Exception: + pass + return "" + + +@app.tool(name="notify_resources") +async def notify_resources(app_ctx: Optional[AppContext] = None) -> str: + """Trigger a non-logging resource list changed notification.""" + _app = app_ctx.app if app_ctx else app + upstream = getattr(_app.context, "upstream_session", None) + if upstream is None: + _app.logger.warning("No upstream session to notify") + return "no-upstream" + await upstream.send_resource_list_changed() + _app.logger.info("Sent notifications/resources/list_changed") + return "ok" + + +@app.tool(name="notify_progress") +async def notify_progress( + progress: float = 0.5, + message: str | None = "Asyncio progress demo", + app_ctx: Optional[AppContext] = None, +) -> str: + """Trigger a non-logging progress notification.""" + _app = app_ctx.app if app_ctx else app + upstream = getattr(_app.context, "upstream_session", None) + if upstream is None: + _app.logger.warning("No upstream session to notify") + return "no-upstream" + await upstream.send_progress_notification( + progress_token="asyncio-demo", progress=progress, message=message + ) + _app.logger.info("Sent notifications/progress") + return "ok" + + @app.tool async def grade_story(story: str, app_ctx: Optional[AppContext] = None) -> str: """ diff --git a/examples/mcp_agent_server/shared/nested_elicitation_server.py b/examples/mcp_agent_server/shared/nested_elicitation_server.py new file mode 100644 index 000000000..34d477d83 --- /dev/null +++ b/examples/mcp_agent_server/shared/nested_elicitation_server.py @@ -0,0 
+1,31 @@ +from pydantic import BaseModel +from mcp.server.fastmcp import FastMCP +from mcp.server.elicitation import elicit_with_validation, AcceptedElicitation + +mcp = FastMCP("Nested Elicitation Server") + + +class Confirmation(BaseModel): + confirm: bool + + +@mcp.tool() +async def confirm_action(action: str) -> str: + """Ask the user to confirm an action via elicitation.""" + ctx = mcp.get_context() + res = await elicit_with_validation( + ctx.session, + message=f"Do you want to {action}?", + schema=Confirmation, + ) + if isinstance(res, AcceptedElicitation) and res.data.confirm: + return f"Action '{action}' confirmed by user" + return f"Action '{action}' declined by user" + + +def main(): + mcp.run() + + +if __name__ == "__main__": + main() diff --git a/examples/mcp_agent_server/shared/nested_sampling_server.py b/examples/mcp_agent_server/shared/nested_sampling_server.py new file mode 100644 index 000000000..32953079f --- /dev/null +++ b/examples/mcp_agent_server/shared/nested_sampling_server.py @@ -0,0 +1,40 @@ +from mcp.server.fastmcp import FastMCP +from mcp.types import ModelPreferences, ModelHint, SamplingMessage, TextContent + +mcp = FastMCP("Nested Sampling Server") + + +@mcp.tool() +async def get_haiku(topic: str) -> str: + """Use MCP sampling to generate a haiku about the given topic.""" + result = await mcp.get_context().session.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent( + type="text", text=f"Generate a quirky haiku about {topic}." + ), + ) + ], + system_prompt="You are a poet.", + max_tokens=100, + temperature=0.7, + model_preferences=ModelPreferences( + hints=[ModelHint(name="gpt-4o-mini")], + costPriority=0.1, + speedPriority=0.8, + intelligencePriority=0.1, + ), + ) + + if isinstance(result.content, TextContent): + return result.content.text + return "Haiku generation failed" + + +def main(): + mcp.run() + + +if __name__ == "__main__": + main() diff --git a/examples/mcp_agent_server/temporal/README.md b/examples/mcp_agent_server/temporal/README.md index da76a7270..0ec7a378b 100644 --- a/examples/mcp_agent_server/temporal/README.md +++ b/examples/mcp_agent_server/temporal/README.md @@ -149,6 +149,37 @@ To run this example, you'll need to: uv run client.py ``` +### Testing Specific Features + +The Temporal client supports feature flags to exercise subsets of functionality. Available flags: `workflows`, `tools`, `sampling`, `elicitation`, `notifications`, or `all`. + +Examples: + +```bash +# Default (all features) +uv run client.py + +# Only workflows +uv run client.py --features workflows + +# Only tools +uv run client.py --features tools + +# Sampling + elicitation workflows +uv run client.py --features sampling elicitation + +# Only notifications-related workflow +uv run client.py --features notifications + +# Increase server logging verbosity seen by the client +uv run client.py --server-log-level debug +``` + +Console output: + +- Server logs appear as lines prefixed with `[SERVER LOG] ...`. +- Other server-originated notifications (e.g., `notifications/progress`, `notifications/resources/list_changed`) appear as `[SERVER NOTIFY] : ...`. 
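+
+For example, `uv run client.py --features notifications` prints lines like the following (illustrative output; exact log text and payloads vary):
+
+```
+[SERVER LOG] [INFO] [basic_agent_server] [workflow-mode] send_resource_list_changed
+[SERVER NOTIFY] notifications/resources/list_changed: {...}
+[SERVER NOTIFY] notifications/progress: {...}
+```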
+
 
 ## Advanced Features with Temporal
 
 ### Workflow Signals
diff --git a/examples/mcp_agent_server/temporal/basic_agent_server.py b/examples/mcp_agent_server/temporal/basic_agent_server.py
index f4368bb54..468a01430 100644
--- a/examples/mcp_agent_server/temporal/basic_agent_server.py
+++ b/examples/mcp_agent_server/temporal/basic_agent_server.py
@@ -19,13 +19,23 @@
 from mcp_agent.server.app_server import create_mcp_server_for_app
 from mcp_agent.executor.workflow import Workflow, WorkflowResult
 from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM
+from mcp_agent.human_input.handler import console_input_callback
+from mcp_agent.elicitation.handler import console_elicitation_callback
+from mcp_agent.mcp.gen_client import gen_client
+from mcp_agent.config import MCPServerSettings
+from mcp.types import SamplingMessage, TextContent, ModelPreferences, ModelHint
 
 # Initialize logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # Create a single FastMCPApp instance (which extends MCPApp)
-app = MCPApp(name="basic_agent_server", description="Basic agent server example")
+app = MCPApp(
+    name="basic_agent_server",
+    description="Basic agent server example",
+    human_input_callback=console_input_callback,  # for local sampling approval
+    elicitation_callback=console_elicitation_callback,  # for local elicitation
+)
 
 
 @app.workflow
@@ -62,7 +72,9 @@ async def run(
 
         # Use of the app.logger will forward logs back to the mcp client
         app_logger = app.logger
-        app_logger.info("Starting finder agent")
+        app_logger.info(
+            "[workflow-mode] Starting finder agent in BasicAgentWorkflow.run"
+        )
 
         async with finder_agent:
             finder_llm = await finder_agent.attach_llm(OpenAIAugmentedLLM)
@@ -71,7 +83,9 @@
             )
 
             # forwards the log to the caller
-            app_logger.info(f"Finder agent completed with result {result}")
+            app_logger.info(
+                f"[workflow-mode] Finder agent completed with result {result}"
+            )
             # print to the console (for when running locally)
             print(f"Agent result: {result}")
             return WorkflowResult(value=result)
@@ -97,7 +111,7 @@ async def finder_tool(request: str, app_ctx: Context | None = None) -> str:
         app = app_ctx.app
 
     logger = app.logger
-    logger.info(f"Running finder_tool with input: {request}")
+    logger.info("[workflow-mode] Running finder_tool", data={"input": request})
 
     finder_agent = Agent(
         name="finder",
@@ -114,7 +128,7 @@
         result = await finder_llm.generate_str(
             message=request,
         )
-    logger.info(f"Agent result: {result}")
+    logger.info("[workflow-mode] finder_tool agent result", data={"result": result})
     return result
 
 
@@ -156,6 +170,217 @@ async def run(
 
         return WorkflowResult(value=result)
 
 
+@app.workflow_task(name="call_nested_sampling")
+async def call_nested_sampling(topic: str) -> str:
+    """Activity: call a nested MCP server tool that uses sampling."""
+    app_ctx: Context = app.context
+    app_ctx.app.logger.info(
+        "[activity-mode] call_nested_sampling starting",
+        data={"topic": topic},
+    )
+    nested_name = "nested_sampling"
+    nested_path = os.path.abspath(
+        os.path.join(
+            os.path.dirname(__file__), "..", "shared", "nested_sampling_server.py"
+        )
+    )
+    app_ctx.config.mcp.servers[nested_name] = MCPServerSettings(
+        name=nested_name,
+        command="uv",
+        args=["run", nested_path],
+        description="Nested server providing a haiku generator using sampling",
+    )
+
+    async with gen_client(
+        nested_name, app_ctx.server_registry, context=app_ctx
+    ) as client:
+        app_ctx.app.logger.info(
+ "[activity-mode] call_nested_sampling connected to nested server" + ) + result = await client.call_tool("get_haiku", {"topic": topic}) + app_ctx.app.logger.info( + "[activity-mode] call_nested_sampling received result", + data={"structured": getattr(result, "structuredContent", None)}, + ) + try: + if result.content and len(result.content) > 0: + return result.content[0].text or "" + except Exception: + pass + return "" + + +@app.workflow_task(name="call_nested_elicitation") +async def call_nested_elicitation(action: str) -> str: + """Activity: call a nested MCP server tool that triggers elicitation.""" + app_ctx: Context = app.context + app_ctx.app.logger.info( + "[activity-mode] call_nested_elicitation starting", + data={"action": action}, + ) + nested_name = "nested_elicitation" + nested_path = os.path.abspath( + os.path.join( + os.path.dirname(__file__), "..", "shared", "nested_elicitation_server.py" + ) + ) + app_ctx.config.mcp.servers[nested_name] = MCPServerSettings( + name=nested_name, + command="uv", + args=["run", nested_path], + description="Nested server demonstrating elicitation", + ) + + async with gen_client( + nested_name, app_ctx.server_registry, context=app_ctx + ) as client: + app_ctx.app.logger.info( + "[activity-mode] call_nested_elicitation connected to nested server" + ) + result = await client.call_tool("confirm_action", {"action": action}) + app_ctx.app.logger.info( + "[activity-mode] call_nested_elicitation received result", + data={"structured": getattr(result, "structuredContent", None)}, + ) + try: + if result.content and len(result.content) > 0: + return result.content[0].text or "" + except Exception: + pass + return "" + + +@app.workflow +class SamplingWorkflow(Workflow[str]): + """Temporal workflow that triggers an MCP sampling request via a nested server.""" + + @app.workflow_run + async def run(self, input: str = "space exploration") -> WorkflowResult[str]: + app.logger.info( + "[workflow-mode] SamplingWorkflow starting", + data={"note": "direct sampling via SessionProxy, then activity sampling"}, + ) + # 1) Direct workflow sampling via SessionProxy (will schedule mcp_relay_request activity) + app.logger.info( + "[workflow-mode] SessionProxy.create_message (direct)", + data={"path": "mcp_relay_request activity"}, + ) + direct_text = "" + try: + direct = await app.context.upstream_session.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent( + type="text", text=f"Write a haiku about {input}." 
+                        ),
+                    )
+                ],
+                system_prompt="You are a poet.",
+                max_tokens=80,
+                model_preferences=ModelPreferences(
+                    hints=[ModelHint(name="gpt-4o-mini")],
+                    costPriority=0.1,
+                    speedPriority=0.8,
+                    intelligencePriority=0.1,
+                ),
+            )
+            try:
+                direct_text = (
+                    direct.content.text
+                    if isinstance(direct.content, TextContent)
+                    else ""
+                )
+            except Exception:
+                direct_text = ""
+        except Exception as e:
+            app.logger.warning(
+                "[workflow-mode] Direct sampling failed; continuing with nested",
+                data={"error": str(e)},
+            )
+        app.logger.info(
+            "[workflow-mode] Direct sampling result",
+            data={"text": direct_text},
+        )
+
+        # 2) Nested server sampling executed as an activity
+        app.logger.info(
+            "[activity-mode] Invoking call_nested_sampling via executor.execute",
+            data={"topic": input},
+        )
+        result = await app.context.executor.execute(call_nested_sampling, input)
+        # Log and return
+        app.logger.info(
+            "[activity-mode] Nested sampling result",
+            data={"text": result},
+        )
+        return WorkflowResult(value=f"direct={direct_text}\nnested={result}")
+
+
+@app.workflow
+class ElicitationWorkflow(Workflow[str]):
+    """Temporal workflow that triggers elicitation via direct session and nested server."""
+
+    @app.workflow_run
+    async def run(self, input: str = "proceed") -> WorkflowResult[str]:
+        app.logger.info(
+            "[workflow-mode] ElicitationWorkflow starting",
+            data={"note": "direct elicit via SessionProxy, then activity elicitation"},
+        )
+
+        # 1) Direct elicitation via SessionProxy (schedules mcp_relay_request)
+        schema = {
+            "type": "object",
+            "properties": {"confirm": {"type": "boolean"}},
+            "required": ["confirm"],
+        }
+        app.logger.info(
+            "[workflow-mode] SessionProxy.elicit (direct)",
+            data={"path": "mcp_relay_request activity"},
+        )
+        direct = await app.context.upstream_session.elicit(
+            message=f"Do you want to {input}?",
+            requestedSchema=schema,
+        )
+        direct_text = f"accepted={getattr(direct, 'action', '')}"
+
+        # 2) Nested elicitation via activity
+        app.logger.info(
+            "[activity-mode] Invoking call_nested_elicitation via executor.execute",
+            data={"action": input},
+        )
+        nested = await app.context.executor.execute(call_nested_elicitation, input)
+
+        app.logger.info(
+            "[workflow-mode] Elicitation results",
+            data={"direct": direct_text, "nested": nested},
+        )
+        return WorkflowResult(value=f"direct={direct_text}\nnested={nested}")
+
+
+@app.workflow
+class NotificationsWorkflow(Workflow[str]):
+    """Temporal workflow that triggers non-logging notifications via proxy."""
+
+    @app.workflow_run
+    async def run(self, input: str = "notifications-demo") -> WorkflowResult[str]:
+        app.logger.info(
+            "[workflow-mode] NotificationsWorkflow starting; sending notifications via SessionProxy",
+            data={"path": "mcp_relay_notify activity"},
+        )
+        # These calls occur inside workflow and will use SessionProxy -> mcp_relay_notify activity
+        app.logger.info(
+            "[workflow-mode] send_progress_notification",
+            data={"token": f"{input}-token", "progress": 0.25},
+        )
+        await app.context.upstream_session.send_progress_notification(
+            progress_token=f"{input}-token", progress=0.25, message="Quarter complete"
+        )
+        app.logger.info("[workflow-mode] send_resource_list_changed")
+        await app.context.upstream_session.send_resource_list_changed()
+        return WorkflowResult(value="ok")
+
+
 async def main():
     async with app.run() as agent_app:
         # Log registered workflows and agent configurations
diff --git a/examples/mcp_agent_server/temporal/client.py b/examples/mcp_agent_server/temporal/client.py
index c945e1cd2..b978b36dd 100644
--- a/examples/mcp_agent_server/temporal/client.py
+++ b/examples/mcp_agent_server/temporal/client.py
@@ -3,10 +3,13 @@
 import time
 import argparse
 from mcp_agent.app import MCPApp
+from mcp_agent.config import Settings, LoggerSettings, MCPSettings
+import yaml
+from mcp_agent.elicitation.handler import console_elicitation_callback
 from mcp_agent.config import MCPServerSettings
+from mcp_agent.core.context import Context
 from mcp_agent.executor.workflow import WorkflowExecution
 from mcp_agent.mcp.gen_client import gen_client
-
 from datetime import timedelta
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from mcp import ClientSession
@@ -31,9 +34,81 @@ async def main():
         default=None,
         help="Set server logging level (debug, info, notice, warning, error, critical, alert, emergency)",
     )
+    parser.add_argument(
+        "--features",
+        nargs="+",
+        choices=[
+            "workflows",
+            "tools",
+            "sampling",
+            "elicitation",
+            "notifications",
+            "all",
+        ],
+        default=["all"],
+        help="Select which features to test",
+    )
     args = parser.parse_args()
 
-    # Create MCPApp to get the server registry
-    app = MCPApp(name="workflow_mcp_client")
+    selected = set(args.features)
+    if "all" in selected:
+        selected = {"workflows", "tools", "sampling", "elicitation", "notifications"}
+    # Create MCPApp to get the server registry, with console handlers
+    # IMPORTANT: This client acts as the “upstream MCP client” for the server.
+    # When the server requests sampling (sampling/createMessage), the client-side
+    # MCPApp must be able to service that request locally (approval prompts + LLM call).
+    # Those client-local flows are not running inside a Temporal workflow, so they
+    # must use the asyncio executor. If this were set to "temporal", local sampling
+    # would crash with: "TemporalExecutor.execute must be called from within a workflow".
+    #
+    # We programmatically construct Settings here (mirroring examples/basic/mcp_basic_agent/main.py)
+    # so everything is self-contained in this client:
+    settings = Settings(
+        execution_engine="asyncio",
+        logger=LoggerSettings(level="info"),
+        mcp=MCPSettings(
+            servers={
+                "basic_agent_server": MCPServerSettings(
+                    name="basic_agent_server",
+                    description="Local workflow server running the basic agent example",
+                    transport="sse",
+                    # Use a routable loopback host; 0.0.0.0 is a bind address, not a client URL
+                    url="http://127.0.0.1:8000/sse",
+                )
+            }
+        ),
+    )
+    # Load secrets (API keys, etc.) if a secrets file is available and merge into settings.
+    # We intentionally deep-merge the secrets on top of our base settings so
+    # credentials are applied without overriding our executor or server endpoint.
+    try:
+        secrets_path = Settings.find_secrets()
+        if secrets_path and secrets_path.exists():
+            with open(secrets_path, "r", encoding="utf-8") as f:
+                secrets_dict = yaml.safe_load(f) or {}
+
+            def _deep_merge(base: dict, overlay: dict) -> dict:
+                out = dict(base)
+                for k, v in (overlay or {}).items():
+                    if k in out and isinstance(out[k], dict) and isinstance(v, dict):
+                        out[k] = _deep_merge(out[k], v)
+                    else:
+                        out[k] = v
+                return out
+
+            base_dict = settings.model_dump(mode="json")
+            merged = _deep_merge(base_dict, secrets_dict)
+            settings = Settings(**merged)
+    except Exception:
+        # Best-effort: continue without secrets if parsing fails
+        pass
+    app = MCPApp(
+        name="workflow_mcp_client",
+        # Disable sampling approval prompts entirely to keep flows non-interactive.
+        # Elicitation remains interactive via console_elicitation_callback.
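+        # (With no human_input_callback registered, sampling requests are
+        # serviced without a console approval step.)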
+        human_input_callback=None,
+        elicitation_callback=console_elicitation_callback,
+        settings=settings,
+    )
     async with app.run() as client_app:
         logger = client_app.logger
         context = client_app.context
@@ -42,13 +117,7 @@ async def main():
         try:
             logger.info("Connecting to workflow server...")
 
-            # Override the server configuration to point to our local script
-            context.server_registry.registry["basic_agent_server"] = MCPServerSettings(
-                name="basic_agent_server",
-                description="Local workflow server running the basic agent example",
-                transport="sse",
-                url="http://0.0.0.0:8000/sse",
-            )
+            # Server connection is configured via Settings above (no runtime mutation needed)
 
             # Connect to the workflow server
@@ -60,16 +129,36 @@ async def on_server_log(params: LoggingMessageNotificationParams) -> None:
                 print(f"[SERVER LOG] [{level}] [{name}] {params.data}")
 
             # Provide a client session factory that installs our logging callback
+            # and prints non-logging notifications to the console
+            class ConsolePrintingClientSession(MCPAgentClientSession):
+                async def _received_notification(self, notification):  # type: ignore[override]
+                    try:
+                        method = getattr(notification.root, "method", None)
+                    except Exception:
+                        method = None
+
+                    # Avoid duplicating server log prints (handled by logging_callback)
+                    if method and method != "notifications/message":
+                        try:
+                            data = notification.model_dump()
+                        except Exception:
+                            data = str(notification)
+                        print(f"[SERVER NOTIFY] {method}: {data}")
+
+                    return await super()._received_notification(notification)
+
             def make_session(
                 read_stream: MemoryObjectReceiveStream,
                 write_stream: MemoryObjectSendStream,
                 read_timeout_seconds: timedelta | None,
+                context: Context | None = None,
             ) -> ClientSession:
-                return MCPAgentClientSession(
+                return ConsolePrintingClientSession(
                     read_stream=read_stream,
                     write_stream=write_stream,
                     read_timeout_seconds=read_timeout_seconds,
                     logging_callback=on_server_log,
+                    context=context,
                 )
 
             # Connect to the workflow server
@@ -87,89 +176,83 @@ def make_session(
                     # Older servers may not support logging capability
                     print("[client] Server does not support logging/setLevel")
 
                 # Call the BasicAgentWorkflow
-                run_result = await server.call_tool(
-                    "workflows-BasicAgentWorkflow-run",
-                    arguments={
-                        "run_parameters": {
-                            "input": "Print the first 2 paragraphs of https://modelcontextprotocol.io/introduction"
-                        }
-                    },
-                )
-
-                execution = WorkflowExecution(**json.loads(run_result.content[0].text))
-                run_id = execution.run_id
-                logger.info(
-                    f"Started BasicAgentWorkflow-run. workflow ID={execution.workflow_id}, run ID={run_id}"
-                )
-
-                get_status_result = await server.call_tool(
-                    "workflows-BasicAgentWorkflow-get_status",
-                    arguments={"run_id": run_id},
-                )
-
-                execution = WorkflowExecution(**json.loads(run_result.content[0].text))
-                run_id = execution.run_id
-                logger.info(
-                    f"Started BasicAgentWorkflow-run. workflow ID={execution.workflow_id}, run ID={run_id}"
-                )
-
-                # Wait for the workflow to complete
-                while True:
-                    get_status_result = await server.call_tool(
-                        "workflows-BasicAgentWorkflow-get_status",
-                        arguments={"run_id": run_id},
+                if "workflows" in selected:
+                    run_result = await server.call_tool(
+                        "workflows-BasicAgentWorkflow-run",
+                        arguments={
+                            "run_parameters": {
+                                "input": "Print the first 2 paragraphs of https://modelcontextprotocol.io/introduction"
+                            }
+                        },
                     )
 
-                    workflow_status = _tool_result_to_json(get_status_result)
-                    if workflow_status is None:
-                        logger.error(
-                            f"Failed to parse workflow status response: {get_status_result}"
-                        )
-                        break
-
+                if "workflows" in selected:
+                    execution = WorkflowExecution(
+                        **json.loads(run_result.content[0].text)
+                    )
+                    run_id = execution.run_id
                     logger.info(
-                        f"Workflow run {run_id} status:",
-                        data=workflow_status,
+                        f"Started BasicAgentWorkflow-run. workflow ID={execution.workflow_id}, run ID={run_id}"
                     )
 
-                    if not workflow_status.get("status"):
-                        logger.error(
-                            f"Workflow run {run_id} status is empty. get_status_result:",
-                            data=get_status_result,
+                # Wait for the workflow to complete
+                if "workflows" in selected:
+                    while True:
+                        get_status_result = await server.call_tool(
+                            "workflows-BasicAgentWorkflow-get_status",
+                            arguments={"run_id": run_id},
                         )
-                        break
 
-                    if workflow_status.get("status") == "completed":
-                        logger.info(
-                            f"Workflow run {run_id} completed successfully! Result:",
-                            data=workflow_status.get("result"),
-                        )
+                        workflow_status = _tool_result_to_json(get_status_result)
+                        if workflow_status is None:
+                            logger.error(
+                                f"Failed to parse workflow status response: {get_status_result}"
+                            )
+                            break
 
-                        break
-                    elif workflow_status.get("status") == "error":
-                        logger.error(
-                            f"Workflow run {run_id} failed with error:",
-                            data=workflow_status,
-                        )
-                        break
-                    elif workflow_status.get("status") == "running":
                         logger.info(
-                            f"Workflow run {run_id} is still running...",
-                        )
-                    elif workflow_status.get("status") == "cancelled":
-                        logger.error(
-                            f"Workflow run {run_id} was cancelled.",
+                            f"Workflow run {run_id} status:",
                             data=workflow_status,
                         )
-                        break
-                    else:
-                        logger.error(
-                            f"Unknown workflow status: {workflow_status.get('status')}",
-                            data=workflow_status,
-                        )
-                        break
 
-                    await asyncio.sleep(5)
+                        if not workflow_status.get("status"):
+                            logger.error(
+                                f"Workflow run {run_id} status is empty. get_status_result:",
+                                data=get_status_result,
+                            )
+                            break
+
+                        if workflow_status.get("status") == "completed":
+                            logger.info(
+                                f"Workflow run {run_id} completed successfully! Result:",
+                                data=workflow_status.get("result"),
+                            )
+
+                            break
+                        elif workflow_status.get("status") == "error":
+                            logger.error(
+                                f"Workflow run {run_id} failed with error:",
+                                data=workflow_status,
+                            )
+                            break
+                        elif workflow_status.get("status") == "running":
+                            logger.info(
+                                f"Workflow run {run_id} is still running...",
+                            )
+                        elif workflow_status.get("status") == "cancelled":
+                            logger.error(
+                                f"Workflow run {run_id} was cancelled.",
+                                data=workflow_status,
+                            )
+                            break
+                        else:
+                            logger.error(
+                                f"Unknown workflow status: {workflow_status.get('status')}",
+                                data=workflow_status,
+                            )
+                            break
+
+                        await asyncio.sleep(5)
 
             # TODO: UNCOMMENT ME to try out cancellation:
             # await server.call_tool(
             #     arguments={"workflow_id": "BasicAgentWorkflow", "run_id": run_id},
             # )
 
-            print(run_result)
+            if "workflows" in selected:
+                print(run_result)
 
             # Call the sync tool 'finder_tool' (no run/status loop)
-            try:
-                finder_result = await server.call_tool(
-                    "finder_tool",
-                    arguments={
-                        "request": "Summarize the Model Context Protocol introduction from https://modelcontextprotocol.io/introduction."
-                    },
-                )
-                finder_payload = _tool_result_to_json(finder_result) or (
-                    (
-                        finder_result.structuredContent.get("result")
-                        if getattr(finder_result, "structuredContent", None)
-                        else None
+            if "tools" in selected:
+                try:
+                    finder_result = await server.call_tool(
+                        "finder_tool",
+                        arguments={
+                            "request": "Summarize the Model Context Protocol introduction from https://modelcontextprotocol.io/introduction."
+                        },
                     )
-                    or (
-                        finder_result.content[0].text
-                        if getattr(finder_result, "content", None)
-                        else None
+                    finder_payload = _tool_result_to_json(finder_result) or (
+                        (
+                            finder_result.structuredContent.get("result")
+                            if getattr(finder_result, "structuredContent", None)
+                            else None
+                        )
+                        or (
+                            finder_result.content[0].text
+                            if getattr(finder_result, "content", None)
+                            else None
+                        )
                     )
-                )
-                logger.info("finder_tool result:", data=finder_payload)
-            except Exception as e:
-                logger.error("finder_tool call failed", data=str(e))
+                    logger.info("finder_tool result:", data=finder_payload)
+                except Exception as e:
+                    logger.error("finder_tool call failed", data=str(e))
+
+            # SamplingWorkflow
+            if "sampling" in selected:
+                try:
+                    sw = await server.call_tool(
+                        "workflows-SamplingWorkflow-run",
+                        arguments={"run_parameters": {"input": "flowers"}},
+                    )
+                    sw_ids = json.loads(sw.content[0].text)
+                    sw_run = sw_ids["run_id"]
+                    while True:
+                        st = await server.call_tool(
+                            "workflows-get_status", arguments={"run_id": sw_run}
+                        )
+                        stj = _tool_result_to_json(st)
+                        logger.info("SamplingWorkflow status:", data=stj or st)
+                        if stj and stj.get("status") in (
+                            "completed",
+                            "error",
+                            "cancelled",
+                        ):
+                            break
+                        await asyncio.sleep(2)
+                except Exception as e:
+                    logger.error("SamplingWorkflow failed", data=str(e))
+
+            # ElicitationWorkflow
+            if "elicitation" in selected:
+                try:
+                    ew = await server.call_tool(
+                        "workflows-ElicitationWorkflow-run",
+                        arguments={"run_parameters": {"input": "proceed"}},
+                    )
+                    ew_ids = json.loads(ew.content[0].text)
+                    ew_run = ew_ids["run_id"]
+                    while True:
+                        st = await server.call_tool(
+                            "workflows-get_status", arguments={"run_id": ew_run}
+                        )
+                        stj = _tool_result_to_json(st)
+                        logger.info("ElicitationWorkflow status:", data=stj or st)
+                        if stj and stj.get("status") in (
+                            "completed",
+                            "error",
+                            "cancelled",
+                        ):
+                            break
+                        await asyncio.sleep(2)
+                except Exception as e:
+                    logger.error("ElicitationWorkflow failed", data=str(e))
+
+            # NotificationsWorkflow
+            if "notifications" in selected:
+                try:
+                    nw = await server.call_tool(
+                        "workflows-NotificationsWorkflow-run",
+                        arguments={"run_parameters": {"input": "notif"}},
+                    )
+                    nw_ids = json.loads(nw.content[0].text)
+                    nw_run = nw_ids["run_id"]
+                    # Wait briefly to allow notifications to flush
+                    await asyncio.sleep(2)
+                    st = await server.call_tool(
+                        "workflows-get_status", arguments={"run_id": nw_run}
+                    )
+                    stj = _tool_result_to_json(st)
+                    logger.info("NotificationsWorkflow status:", data=stj or st)
+                except Exception as e:
+                    logger.error("NotificationsWorkflow failed", data=str(e))
         except Exception as e:
             # Tolerate benign shutdown races from SSE client (BrokenResourceError within ExceptionGroup)
             if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup):
diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py
index 3840d4944..01889b464 100644
--- a/src/mcp_agent/executor/temporal/__init__.py
+++ b/src/mcp_agent/executor/temporal/__init__.py
@@ -98,15 +98,31 @@ def wrap_as_activity(
 
         @activity.defn(name=activity_name)
         async def wrapped_activity(*args, **local_kwargs):
+            """
+            Temporal activity wrapper that supports both payload styles:
+            - Single dict payload: wrapped_activity({"k": v, ...}) -> func(**payload)
+            - Varargs/kwargs payload: wrapped_activity(a, b, c, x=1) -> func(a, b, c, x=1)
+            """
             try:
-                if asyncio.iscoroutinefunction(func):
-                    return await func(**args[0])
-                elif asyncio.iscoroutine(func):
-                    return await func
+                # Prefer the legacy single-dict payload convention when applicable
+                if len(args) == 1 and isinstance(args[0], dict) and not local_kwargs:
+                    payload = args[0]
+                    if asyncio.iscoroutinefunction(func):
+                        return await func(**payload)
+                    elif asyncio.iscoroutine(func):
+                        return await func
+                    else:
+                        return func(**payload)
                 else:
-                    return func(**args[0])
+                    # Fall back to passing through varargs/kwargs directly
+                    if asyncio.iscoroutinefunction(func):
+                        return await func(*args, **local_kwargs)
+                    elif asyncio.iscoroutine(func):
+                        return await func
+                    else:
+                        return func(*args, **local_kwargs)
             except Exception as e:
-                # Handle exceptions gracefully
+                # Properly surface activity exceptions
                 raise e
 
         return wrapped_activity
@@ -163,6 +179,7 @@ async def _execute_task(
         activity_registry = self.context.task_registry
         activity_task = activity_registry.get_activity(activity_name)
 
+        # Config timeout takes priority over metadata timeout (per tests).
         schedule_to_close = self.config.timeout_seconds or execution_metadata.get(
             "schedule_to_close_timeout"
         )
@@ -170,15 +187,17 @@
         if schedule_to_close is not None and not isinstance(
             schedule_to_close, timedelta
         ):
-            # Convert to timedelta if it's not already
+            # Convert numeric seconds to timedelta if needed
            schedule_to_close = timedelta(seconds=schedule_to_close)
 
         retry_policy = execution_metadata.get("retry_policy", None)
 
         try:
+            # Temporal's execute_activity accepts at most one positional arg;
+            # pass user args via the keyword-only 'args' to support multiple
             result = await workflow.execute_activity(
                 activity_task,
-                *args,
+                args=list(args) if args else None,
                 task_queue=self.config.task_queue,
                 schedule_to_close_timeout=schedule_to_close,
                 retry_policy=retry_policy,
diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py
index b0fa213a4..ea4a6e809 100644
--- a/src/mcp_agent/executor/temporal/session_proxy.py
+++ b/src/mcp_agent/executor/temporal/session_proxy.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from typing import Any, Dict, List, Type
+import asyncio
 
 import anyio
 import mcp.types as types
@@ -86,14 +87,23 @@ async def notify(self, method: str, params: Dict[str, Any] | None = None) -> boo
         if _in_workflow_runtime():
             try:
                 act = self._context.task_registry.get_activity("mcp_relay_notify")
-                await self._executor.execute(act, exec_id, method, params or {})
+                await self._executor.execute(
+                    act,
+                    exec_id,
+                    method,
+                    params or {},
+                )
                 return True
             except Exception:
                 return False
 
-        # Non-workflow (activity/asyncio)
-        return bool(
-            await self._system_activities.relay_notify(exec_id, method, params or {})
-        )
+        # Non-workflow (activity/asyncio): fire-and-forget best-effort
+        try:
+            asyncio.create_task(
+                self._system_activities.relay_notify(exec_id, method, params or {})
+            )
+        except Exception:
+            pass
+        return True
 
     async def request(
         self, method: str, params: Dict[str, Any] | None = None
@@ -107,7 +117,12 @@
 
         if _in_workflow_runtime():
             act = self._context.task_registry.get_activity("mcp_relay_request")
-            return await self._executor.execute(act, exec_id, method, params or {})
+            return await self._executor.execute(
+                act,
+                exec_id,
+                method,
+                params or {},
+            )
 
         return await self._system_activities.relay_request(
             exec_id, method, params or {}
         )
@@ -268,7 +283,10 @@ async def create_message(
             params["related_request_id"] = related_request_id
 
         result = await self.request("sampling/createMessage", params)
-        return types.CreateMessageResult.model_validate(result)
+        try:
+            return types.CreateMessageResult.model_validate(result)
+        except Exception as e:
+            raise RuntimeError(f"sampling/createMessage returned invalid result: {e}")
 
     async def elicit(
         self,
@@ -283,7 +301,10 @@
         if related_request_id is not None:
             params["related_request_id"] = related_request_id
 
         result = await self.request("elicitation/create", params)
-        return types.ElicitResult.model_validate(result)
+        try:
+            return types.ElicitResult.model_validate(result)
+        except Exception as e:
+            raise RuntimeError(f"elicitation/create returned invalid result: {e}")
 
 
 def _in_workflow_runtime() -> bool:
diff --git a/src/mcp_agent/executor/temporal/system_activities.py b/src/mcp_agent/executor/temporal/system_activities.py
index 024632d79..aff8c7f12 100644
--- a/src/mcp_agent/executor/temporal/system_activities.py
+++ b/src/mcp_agent/executor/temporal/system_activities.py
@@ -1,4 +1,6 @@
 from typing import Any, Dict
+import anyio
+import os
 
 from temporalio import activity
 
@@ -65,14 +67,26 @@ async def relay_notify(
     ) -> bool:
         gateway_url = getattr(self.context, "gateway_url", None)
         gateway_token = getattr(self.context, "gateway_token", None)
+        # Fire-and-forget semantics with a short timeout (best-effort)
+        timeout_str = os.environ.get("MCP_NOTIFY_TIMEOUT", "2.0")
+        try:
+            timeout = float(timeout_str)
+        except Exception:
+            timeout = None
 
-        return await notify_via_proxy(
-            execution_id=execution_id,
-            method=method,
-            params=params or {},
-            gateway_url=gateway_url,
-            gateway_token=gateway_token,
-        )
+        ok = True
+        try:
+            with anyio.move_on_after(timeout):
+                ok = await notify_via_proxy(
+                    execution_id=execution_id,
+                    method=method,
+                    params=params or {},
+                    gateway_url=gateway_url,
+                    gateway_token=gateway_token,
+                )
+        except Exception:
+            ok = False
+        return ok
 
     @activity.defn(name="mcp_relay_request")
     async def relay_request(
diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py
index e289b1059..5f4394e93 100644
--- a/src/mcp_agent/mcp/client_proxy.py
+++ b/src/mcp_agent/mcp/client_proxy.py
@@ -163,8 +163,23 @@ async def request_via_proxy(
     if tok:
         headers["X-MCP-Gateway-Token"] = tok
         headers["Authorization"] = f"Bearer {tok}"
-    timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "20"))
+    # Requests require a response; default to no HTTP timeout.
+    # Configure with MCP_GATEWAY_REQUEST_TIMEOUT (seconds). If unset or <= 0, no timeout is applied.
+    timeout_str = os.environ.get("MCP_GATEWAY_REQUEST_TIMEOUT")
+    timeout_float: float | None
+    if timeout_str is None:
+        timeout_float = None  # no timeout by default; activity timeouts still apply
+    else:
+        try:
+            timeout_float = float(str(timeout_str).strip())
+        except Exception:
+            timeout_float = None
     try:
+        # If timeout is None, pass a Timeout object with no limits
+        if timeout_float is None:
+            timeout = httpx.Timeout(None)
+        else:
+            timeout = timeout_float
         async with httpx.AsyncClient(timeout=timeout) as client:
             r = await client.post(
                 url, json={"method": method, "params": params or {}}, headers=headers
diff --git a/src/mcp_agent/mcp/gen_client.py b/src/mcp_agent/mcp/gen_client.py
index b3e13d6ba..725d8b538 100644
--- a/src/mcp_agent/mcp/gen_client.py
+++ b/src/mcp_agent/mcp/gen_client.py
@@ -1,6 +1,6 @@
 from contextlib import asynccontextmanager
 from datetime import timedelta
-from typing import AsyncGenerator, Callable
+from typing import AsyncGenerator, Callable, Optional
 
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from mcp import ClientSession
@@ -8,6 +8,7 @@
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.mcp_server_registry import ServerRegistry
 from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
+from mcp_agent.core.context import Context
 
 logger = get_logger(__name__)
 
@@ -17,10 +18,16 @@ async def gen_client(
     server_name: str,
     server_registry: ServerRegistry,
     client_session_factory: Callable[
-        [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None],
+        [
+            MemoryObjectReceiveStream,
+            MemoryObjectSendStream,
+            timedelta | None,
+            Optional[Context],
+        ],
         ClientSession,
     ] = MCPAgentClientSession,
     session_id: str | None = None,
+    context: Optional[Context] = None,
 ) -> AsyncGenerator[ClientSession, None]:
     """
     Create a client session to the specified server.
@@ -37,6 +44,7 @@ async def gen_client(
         server_name=server_name,
         client_session_factory=client_session_factory,
         session_id=session_id,
+        context=context,
     ) as session:
         yield session

@@ -45,10 +53,16 @@ async def connect(
     server_name: str,
     server_registry: ServerRegistry,
     client_session_factory: Callable[
-        [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None],
+        [
+            MemoryObjectReceiveStream,
+            MemoryObjectSendStream,
+            timedelta | None,
+            Optional[Context],
+        ],
         ClientSession,
     ] = MCPAgentClientSession,
     session_id: str | None = None,
+    context: Optional[Context] = None,
 ) -> ClientSession:
     """
     Create a persistent client session to the specified server.
diff --git a/src/mcp_agent/mcp/mcp_agent_client_session.py b/src/mcp_agent/mcp/mcp_agent_client_session.py
index 425a0e298..4a4e40c25 100644
--- a/src/mcp_agent/mcp/mcp_agent_client_session.py
+++ b/src/mcp_agent/mcp/mcp_agent_client_session.py
@@ -40,12 +40,12 @@
     Implementation,
     JSONRPCMessage,
     ServerRequest,
-    TextContent,
     ListRootsResult,
     NotificationParams,
     RequestParams,
     Root,
     ElicitRequestParams as MCPElicitRequestParams,
+    ElicitRequest,
     ElicitResult,
     PaginatedRequestParams,
 )
@@ -62,6 +62,7 @@
     MCP_TOOL_NAME,
 )
 from mcp_agent.tracing.telemetry import get_tracer, record_attributes
+from mcp_agent.mcp.sampling_handler import SamplingHandler

 if TYPE_CHECKING:
     from mcp_agent.core.context import Context
@@ -116,6 +117,7 @@ def __init__(
         )

         self.server_config: Optional[MCPServerSettings] = None
+        self._sampling_handler = SamplingHandler(context=self.context)

         # Session ID handling for Streamable HTTP transport
         self._get_session_id_callback: Optional[Callable[[], str | None]] = None
@@ -334,46 +336,9 @@ async def _handle_sampling_callback(
         context: RequestContext["ClientSession", Any],
         params: CreateMessageRequestParams,
     ) -> CreateMessageResult | ErrorData:
-        logger.info("Handling sampling request: %s", params)
-        config = self.context.config
+        logger.debug(f"Handling sampling request: {params}")
         server_session = self.context.upstream_session
-        if server_session is None:
-            # TODO: saqadri - consider whether we should be handling the sampling request here as a client
-            logger.warning(
-                "Error: No upstream client available for sampling requests. Request:",
-                data=params,
-            )
-            try:
-                from anthropic import AsyncAnthropic
-
-                client = AsyncAnthropic(api_key=config.anthropic.api_key)
-
-                response = await client.messages.create(
-                    model="claude-3-sonnet-20240229",
-                    max_tokens=params.maxTokens,
-                    messages=[
-                        {
-                            "role": m.role,
-                            "content": m.content.text
-                            if hasattr(m.content, "text")
-                            else m.content.data,
-                        }
-                        for m in params.messages
-                    ],
-                    system=getattr(params, "systemPrompt", None),
-                    temperature=getattr(params, "temperature", 0.7),
-                    stop_sequences=getattr(params, "stopSequences", None),
-                )
-
-                return CreateMessageResult(
-                    model="claude-3-sonnet-20240229",
-                    role="assistant",
-                    content=TextContent(type="text", text=response.content[0].text),
-                )
-            except Exception as e:
-                logger.error(f"Error handling sampling request: {e}")
-                return ErrorData(code=-32603, message=str(e))
-        else:
+        if server_session is not None:
             try:
                 # If a server_session is available, we'll pass-through the sampling request to the upstream client
                 result = await server_session.send_request(
@@ -384,11 +349,13 @@ async def _handle_sampling_callback(
                     ),
                     result_type=CreateMessageResult,
                 )
-                # Pass the result from the upstream client back to the server. We just act as a pass-through client here.
                 return result
             except Exception as e:
                 return ErrorData(code=-32603, message=str(e))
+        else:
+            # No upstream session: handle locally via SamplingHandler
+            return await self._sampling_handler.handle_sampling(params=params)

     async def _handle_elicitation_callback(
         self,
@@ -399,6 +366,22 @@ async def _handle_elicitation_callback(
         logger.info("Handling elicitation request", data=params.model_dump())

         try:
+            # Prefer upstream pass-through when an upstream session exists
+            server_session = self.context.upstream_session
+            if server_session is not None:
+                try:
+                    result = await server_session.send_request(
+                        request=ServerRequest(
+                            ElicitRequest(method="elicitation/create", params=params)
+                        ),
+                        result_type=ElicitResult,
+                    )
+                    return result
+                except Exception as e:
+                    logger.warning(
+                        f"Upstream elicitation forwarding failed; falling back locally: {e}"
+                    )
+
             if not self.context.elicitation_handler:
                 logger.error(
                     "No elicitation handler configured for elicitation. Rejecting elicitation."
diff --git a/src/mcp_agent/mcp/mcp_server_registry.py b/src/mcp_agent/mcp/mcp_server_registry.py
index 9f832781e..a28735373 100644
--- a/src/mcp_agent/mcp/mcp_server_registry.py
+++ b/src/mcp_agent/mcp/mcp_server_registry.py
@@ -9,7 +9,7 @@

 from contextlib import asynccontextmanager
 from datetime import timedelta
-from typing import Callable, Dict, AsyncGenerator
+from typing import Callable, Dict, AsyncGenerator, Optional, TYPE_CHECKING

 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from mcp import ClientSession
@@ -33,6 +33,9 @@
 from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
 from mcp_agent.mcp.mcp_connection_manager import MCPConnectionManager

+if TYPE_CHECKING:
+    from mcp_agent.core.context import Context
+
 logger = get_logger(__name__)

 InitHookCallable = Callable[[ClientSession | None, MCPServerAuthSettings | None], bool]
@@ -106,10 +109,16 @@ async def start_server(
         self,
         server_name: str,
         client_session_factory: Callable[
-            [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None],
+            [
+                MemoryObjectReceiveStream,
+                MemoryObjectSendStream,
+                timedelta | None,
+                Optional["Context"],
+            ],
             ClientSession,
         ] = ClientSession,
         session_id: str | None = None,
+        context: Optional["Context"] = None,
     ) -> AsyncGenerator[ClientSession, None]:
         """
         Starts the server process based on its configuration. To initialize, call initialize_server
@@ -147,11 +156,20 @@
             )

             async with stdio_client(server_params) as (read_stream, write_stream):
-                session = client_session_factory(
-                    read_stream,
-                    write_stream,
-                    read_timeout_seconds,
-                )
+                # Construct session; tolerate factories that don't accept 'context'
+                try:
+                    session = client_session_factory(
+                        read_stream,
+                        write_stream,
+                        read_timeout_seconds,
+                        context=context,
+                    )
+                except TypeError:
+                    session = client_session_factory(
+                        read_stream,
+                        write_stream,
+                        read_timeout_seconds,
+                    )
                 async with session:
                     logger.info(
                         f"{server_name}: Connected to server using stdio transport."
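The `try/except TypeError` shim above is what keeps pre-existing three-argument factories working alongside the new four-argument shape. A sketch of both shapes the registry now tolerates (factory names are illustrative):

```python
from datetime import timedelta

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession

from mcp_agent.core.context import Context


def legacy_factory(
    read_stream: MemoryObjectReceiveStream,
    write_stream: MemoryObjectSendStream,
    read_timeout_seconds: timedelta | None,
) -> ClientSession:
    # Old shape: no 'context' parameter. The registry's first call raises
    # TypeError, and it retries with the three-argument form.
    return ClientSession(read_stream, write_stream, read_timeout_seconds)


def context_aware_factory(
    read_stream: MemoryObjectReceiveStream,
    write_stream: MemoryObjectSendStream,
    read_timeout_seconds: timedelta | None,
    context: Context | None = None,
) -> ClientSession:
    # New shape: accepts the optional Context keyword (sessions that
    # don't use it can simply ignore it).
    return ClientSession(read_stream, write_stream, read_timeout_seconds)
```

One caveat worth noting: `except TypeError` also swallows a `TypeError` raised inside a four-argument factory's body, which silently retries without the context instead of surfacing the bug.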
@@ -200,11 +218,19 @@ async def start_server(
             async with streamablehttp_client(
                 **kwargs,
             ) as (read_stream, write_stream, session_id_callback):
-                session = client_session_factory(
-                    read_stream,
-                    write_stream,
-                    read_timeout_seconds,
-                )
+                try:
+                    session = client_session_factory(
+                        read_stream,
+                        write_stream,
+                        read_timeout_seconds,
+                        context=context,
+                    )
+                except TypeError:
+                    session = client_session_factory(
+                        read_stream,
+                        write_stream,
+                        read_timeout_seconds,
+                    )
                 if session_id_callback and isinstance(session, MCPAgentClientSession):
                     session.set_session_id_callback(session_id_callback)
@@ -239,11 +265,19 @@ async def start_server(
                 read_stream,
                 write_stream,
             ):
-                session = client_session_factory(
-                    read_stream,
-                    write_stream,
-                    read_timeout_seconds,
-                )
+                try:
+                    session = client_session_factory(
+                        read_stream,
+                        write_stream,
+                        read_timeout_seconds,
+                        context=context,
+                    )
+                except TypeError:
+                    session = client_session_factory(
+                        read_stream,
+                        write_stream,
+                        read_timeout_seconds,
+                    )
                 async with session:
                     logger.info(
                         f"{server_name}: Connected to server using SSE transport."
@@ -263,11 +297,19 @@ async def start_server(
                 read_stream,
                 write_stream,
             ):
-                session = client_session_factory(
-                    read_stream,
-                    write_stream,
-                    read_timeout_seconds,
-                )
+                try:
+                    session = client_session_factory(
+                        read_stream,
+                        write_stream,
+                        read_timeout_seconds,
+                        context=context,
+                    )
+                except TypeError:
+                    session = client_session_factory(
+                        read_stream,
+                        write_stream,
+                        read_timeout_seconds,
+                    )
                 async with session:
                     logger.info(
                         f"{server_name}: Connected to server using websocket transport."
@@ -285,11 +327,17 @@ async def initialize_server(
         self,
         server_name: str,
         client_session_factory: Callable[
-            [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None],
+            [
+                MemoryObjectReceiveStream,
+                MemoryObjectSendStream,
+                timedelta | None,
+                Optional["Context"],
+            ],
             ClientSession,
         ] = ClientSession,
         init_hook: InitHookCallable = None,
         session_id: str | None = None,
+        context: Optional["Context"] = None,
     ) -> AsyncGenerator[ClientSession, None]:
         """
         Initialize a server based on its configuration.
@@ -315,6 +363,7 @@
             server_name,
             client_session_factory=client_session_factory,
             session_id=session_id,
+            context=context,
         ) as session:
             try:
                 logger.info(f"{server_name}: Initializing server...")
diff --git a/src/mcp_agent/mcp/sampling_handler.py b/src/mcp_agent/mcp/sampling_handler.py
new file mode 100644
index 000000000..c90ba2fb8
--- /dev/null
+++ b/src/mcp_agent/mcp/sampling_handler.py
@@ -0,0 +1,198 @@
+"""
+MCP Agent Sampling Handler
+
+Handles sampling requests from MCP servers with human-in-the-loop approval workflow
+and direct LLM provider integration; prefers upstream pass-through when present.
+""" + +from typing import TYPE_CHECKING +from uuid import uuid4 + +from mcp.types import ( + CreateMessageRequest, + CreateMessageRequestParams, + CreateMessageResult, + ErrorData, + TextContent, + ServerRequest, +) + +from mcp.server.fastmcp.exceptions import ToolError + +from mcp_agent.core.context_dependent import ContextDependent +from mcp_agent.logging.logger import get_logger +from mcp_agent.workflows.llm.augmented_llm import RequestParams as LLMRequestParams +from mcp_agent.workflows.llm.llm_selector import ModelSelector + +logger = get_logger(__name__) + +if TYPE_CHECKING: + from mcp_agent.core.context import Context + + +class SamplingHandler(ContextDependent): + """Handles MCP sampling requests with optional human approval and LLM generation.""" + + def __init__(self, context: "Context"): + super().__init__(context=context) + + async def handle_sampling( + self, *, params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: + """Route sampling to upstream session if present, else handle locally.""" + server_session = self.context.upstream_session + if server_session is not None: + try: + return await server_session.send_request( + request=ServerRequest( + CreateMessageRequest( + method="sampling/createMessage", params=params + ) + ), + result_type=CreateMessageResult, + ) + except Exception as e: + return ErrorData(code=-32603, message=str(e)) + + # No upstream session: handle locally with optional human approval + direct LLM call + return await self._handle_sampling_locally(params) + + async def _handle_sampling_locally( + self, params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: + try: + approved_params, reason = await self._human_approve_request(params) + if approved_params is None: + return ErrorData( + code=-32603, message=f"Sampling request rejected by user: {reason}" + ) + + result = await self._generate_with_llm(approved_params) + if result is None: + return ErrorData(code=-32603, message="Failed to generate a response") + + final_result, reason = await self._human_approve_response(result) + if final_result is None: + return ErrorData( + code=-32603, message=f"Response rejected by user: {reason}" + ) + return final_result + except Exception as e: + logger.error(f"Error in local sampling flow: {e}") + return ErrorData(code=-32603, message=str(e)) + + async def _human_approve_request( + self, params: CreateMessageRequestParams + ) -> tuple[CreateMessageRequestParams | None, str]: + if not self.context.human_input_handler: + return params, "" + + from mcp_agent.human_input.types import HumanInputRequest + + req = HumanInputRequest( + prompt=( + "MCP server requests LLM sampling. Respond 'approve' to proceed, " + "anything else to reject (your input will be recorded as reason)." + ), + description="MCP Sampling Request Approval", + request_id=f"sampling_request_{uuid4()}", + metadata={ + "type": "sampling_request_approval", + "original_params": params.model_dump(), + }, + ) + resp = await self.context.human_input_handler(req) + text = (resp.response or "").strip().lower() + return ( + (params, "") if text == "approve" else (None, resp.response or "rejected") + ) + + async def _human_approve_response( + self, result: CreateMessageResult + ) -> tuple[CreateMessageResult | None, str]: + if not self.context.human_input_handler: + return result, "" + + from mcp_agent.human_input.types import HumanInputRequest + + req = HumanInputRequest( + prompt=( + "LLM has generated a response. 
Respond 'approve' to send, " + "anything else to reject (your input will be recorded as reason)." + ), + description="MCP Sampling Response Approval", + request_id=f"sampling_response_{uuid4()}", + metadata={ + "type": "sampling_response_approval", + "original_result": result.model_dump(), + }, + ) + resp = await self.context.human_input_handler(req) + text = (resp.response or "").strip().lower() + return ( + (result, "") if text == "approve" else (None, resp.response or "rejected") + ) + + async def _generate_with_llm( + self, params: CreateMessageRequestParams + ) -> CreateMessageResult | None: + # Require model preferences to avoid recursion/guessing + if params.modelPreferences is None: + raise ToolError("Model preferences must be provided for sampling requests") + + model_selector = self.context.model_selector or ModelSelector() + model_info = model_selector.select_best_model(params.modelPreferences) + + # Lazy import to avoid circulars, and create a clean LLM instance without current context + from mcp_agent.workflows.factory import create_llm + + # Honor the caller's systemPrompt as instruction when constructing the LLM + llm = create_llm( + agent_name="sampling", + server_names=[], + instruction=getattr(params, "systemPrompt", None), + provider=model_info.provider, + model=model_info.name, + request_params=None, + context=None, + ) + + # Flatten MCP SamplingMessage list to raw strings for generate_str + messages: list[str] = [] + for m in params.messages: + if hasattr(m.content, "text") and m.content.text: + messages.append(m.content.text) + elif hasattr(m.content, "data") and m.content.data: + messages.append(str(m.content.data)) + else: + messages.append(str(m.content)) + + # Coerce optional temperature to a sane default if missing + temperature = getattr(params, "temperature", None) + if temperature is None: + temperature = 0.7 + + # Build request params by extending CreateMessageRequestParams so + # everything the user provided is forwarded to the LLM + req_params = LLMRequestParams( + maxTokens=params.maxTokens or 2048, + temperature=temperature, + systemPrompt=getattr(params, "systemPrompt", None), + includeContext=getattr(params, "includeContext", None), + stopSequences=getattr(params, "stopSequences", None), + metadata=getattr(params, "metadata", None), + modelPreferences=params.modelPreferences, + # Keep local generation simple/deterministic + max_iterations=1, + parallel_tool_calls=False, + use_history=False, + messages=None, + ) + + text = await llm.generate_str(message=messages, request_params=req_params) + model_name = await llm.select_model(req_params) or model_info.name + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text=text), + model=model_name, + ) diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index c72296339..1b4e8b712 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -527,7 +527,9 @@ async def _relay_notify(request: Request): logger.error( f"[notify] error forwarding for execution_id={execution_id}: {e_mapped}" ) - return JSONResponse({"ok": False, "error": str(e_mapped)}, status_code=500) + return JSONResponse( + {"ok": False, "error": str(e_mapped)}, status_code=500 + ) @mcp_server.custom_route( "/internal/session/by-run/{execution_id}/request", @@ -603,7 +605,9 @@ async def _relay_request(request: Request): except Exception: pass return JSONResponse( - result.model_dump(by_alias=True, mode="json", exclude_none=True) + result.model_dump( + 
by_alias=True, mode="json", exclude_none=True + ) ) elif method == "elicitation/create": req = ServerRequest( @@ -625,7 +629,9 @@ async def _relay_request(request: Request): except Exception: pass return JSONResponse( - result.model_dump(by_alias=True, mode="json", exclude_none=True) + result.model_dump( + by_alias=True, mode="json", exclude_none=True + ) ) elif method == "roots/list": req = ServerRequest(ListRootsRequest(method="roots/list")) @@ -642,7 +648,9 @@ async def _relay_request(request: Request): except Exception: pass return JSONResponse( - result.model_dump(by_alias=True, mode="json", exclude_none=True) + result.model_dump( + by_alias=True, mode="json", exclude_none=True + ) ) elif method == "ping": req = ServerRequest(PingRequest(method="ping")) @@ -659,7 +667,9 @@ async def _relay_request(request: Request): except Exception: pass return JSONResponse( - result.model_dump(by_alias=True, mode="json", exclude_none=True) + result.model_dump( + by_alias=True, mode="json", exclude_none=True + ) ) except Exception as e_latest: logger.warning( @@ -672,9 +682,7 @@ async def _relay_request(request: Request): logger.warning( f"[request] session_not_available for execution_id={execution_id}" ) - return JSONResponse( - {"error": "session_not_available"}, status_code=503 - ) + return JSONResponse({"error": "session_not_available"}, status_code=503) try: # Prefer generic request passthrough if available @@ -857,7 +865,7 @@ async def _internal_human_prompts(request: Request): metadata = body.get("metadata") or {} try: logger.info( - f"[human] incoming execution_id={execution_id} signal_name={metadata.get('signal_name','human_input')}" + f"[human] incoming execution_id={execution_id} signal_name={metadata.get('signal_name', 'human_input')}" ) except Exception: pass @@ -926,7 +934,9 @@ async def _internal_human_prompts(request: Request): # Fallback to mapped session session = await _get_session(execution_id) if not session: - return JSONResponse({"error": "session_not_available"}, status_code=503) + return JSONResponse( + {"error": "session_not_available"}, status_code=503 + ) await session.send_log_message( level="info", # type: ignore[arg-type] data=payload, @@ -1804,6 +1814,27 @@ async def _workflow_run( except Exception: gateway_url = None + # Normalize gateway URL if it points to a non-routable bind address + def _normalize_gateway_url(url: str | None) -> str | None: + if not url: + return url + try: + from urllib.parse import urlparse, urlunparse + + parsed = urlparse(url) + host = parsed.hostname or "" + # Replace wildcard binds with a loopback address that's actually connectable + if host in ("0.0.0.0", "::", "[::]"): + new_host = "127.0.0.1" if host == "0.0.0.0" else "localhost" + netloc = parsed.netloc.replace(host, new_host) + parsed = parsed._replace(netloc=netloc) + return urlunparse(parsed) + except Exception: + pass + return url + + gateway_url = _normalize_gateway_url(gateway_url) + # Final fallback: environment variables (useful if proxies don't set headers) try: import os as _os diff --git a/uv.lock b/uv.lock index 4fb774a44..96e962b6f 100644 --- a/uv.lock +++ b/uv.lock @@ -2041,7 +2041,7 @@ wheels = [ [[package]] name = "mcp-agent" -version = "0.1.20" +version = "0.1.21" source = { editable = "." } dependencies = [ { name = "aiohttp" },