diff --git a/examples/mcp/mcp_elicitation/main.py b/examples/mcp/mcp_elicitation/main.py index 78728b880..1dcce620b 100644 --- a/examples/mcp/mcp_elicitation/main.py +++ b/examples/mcp/mcp_elicitation/main.py @@ -15,6 +15,7 @@ ) +@app.tool async def example_usage(): async with app.run() as agent_app: logger = agent_app.logger diff --git a/examples/mcp/mcp_elicitation/temporal/client.py b/examples/mcp/mcp_elicitation/temporal/client.py new file mode 100644 index 000000000..b6c4d114c --- /dev/null +++ b/examples/mcp/mcp_elicitation/temporal/client.py @@ -0,0 +1,288 @@ +import asyncio +import json +import time +from mcp_agent.app import MCPApp +from mcp_agent.config import Settings, LoggerSettings, MCPSettings +import yaml +from mcp_agent.elicitation.handler import console_elicitation_callback +from mcp_agent.config import MCPServerSettings +from mcp_agent.core.context import Context +from mcp_agent.executor.workflow import WorkflowExecution +from mcp_agent.mcp.gen_client import gen_client +from datetime import timedelta +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import ClientSession +from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession +from mcp.types import CallToolResult, LoggingMessageNotificationParams +from mcp_agent.human_input.handler import console_input_callback +try: + from exceptiongroup import ExceptionGroup as _ExceptionGroup # Python 3.10 backport +except Exception: # pragma: no cover + _ExceptionGroup = None # type: ignore +try: + from anyio import BrokenResourceError as _BrokenResourceError +except Exception: # pragma: no cover + _BrokenResourceError = None # type: ignore + + +async def main(): + # Create MCPApp to get the server registry, with console handlers + # IMPORTANT: This client acts as the “upstream MCP client” for the server. + # When the server requests sampling (sampling/createMessage), the client-side + # MCPApp must be able to service that request locally (approval prompts + LLM call). + # Those client-local flows are not running inside a Temporal workflow, so they + # must use the asyncio executor. If this were set to "temporal", local sampling + # would crash with: "TemporalExecutor.execute must be called from within a workflow". + # + # We programmatically construct Settings here (mirroring examples/basic/mcp_basic_agent/main.py) + # so everything is self-contained in this client: + settings = Settings( + execution_engine="asyncio", + logger=LoggerSettings(level="info"), + mcp=MCPSettings( + servers={ + "basic_agent_server": MCPServerSettings( + name="basic_agent_server", + description="Local workflow server running the basic agent example", + transport="sse", + # Use a routable loopback host; 0.0.0.0 is a bind address, not a client URL + url="http://127.0.0.1:8000/sse", + ) + } + ), + ) + # Load secrets (API keys, etc.) if a secrets file is available and merge into settings. + # We intentionally deep-merge the secrets on top of our base settings so + # credentials are applied without overriding our executor or server endpoint. + try: + secrets_path = Settings.find_secrets() + if secrets_path and secrets_path.exists(): + with open(secrets_path, "r", encoding="utf-8") as f: + secrets_dict = yaml.safe_load(f) or {} + + def _deep_merge(base: dict, overlay: dict) -> dict: + out = dict(base) + for k, v in (overlay or {}).items(): + if k in out and isinstance(out[k], dict) and isinstance(v, dict): + out[k] = _deep_merge(out[k], v) + else: + out[k] = v + return out + + base_dict = settings.model_dump(mode="json") + merged = _deep_merge(base_dict, secrets_dict) + settings = Settings(**merged) + except Exception: + # Best-effort: continue without secrets if parsing fails + pass + app = MCPApp( + name="workflow_mcp_client", + # Disable sampling approval prompts entirely to keep flows non-interactive. + # Elicitation remains interactive via console_elicitation_callback. + human_input_callback=console_input_callback, + elicitation_callback=console_elicitation_callback, + settings=settings, + ) + async with app.run() as client_app: + logger = client_app.logger + context = client_app.context + + # Connect to the workflow server + try: + logger.info("Connecting to workflow server...") + + # Server connection is configured via Settings above (no runtime mutation needed) + + # Connect to the workflow server + # Define a logging callback to receive server-side log notifications + async def on_server_log(params: LoggingMessageNotificationParams) -> None: + # Pretty-print server logs locally for demonstration + level = params.level.upper() + name = params.logger or "server" + # params.data can be any JSON-serializable data + print(f"[SERVER LOG] [{level}] [{name}] {params.data}") + + # Provide a client session factory that installs our logging callback + # and prints non-logging notifications to the console + class ConsolePrintingClientSession(MCPAgentClientSession): + async def _received_notification(self, notification): # type: ignore[override] + try: + method = getattr(notification.root, "method", None) + except Exception: + method = None + + # Avoid duplicating server log prints (handled by logging_callback) + if method and method != "notifications/message": + try: + data = notification.model_dump() + except Exception: + data = str(notification) + print(f"[SERVER NOTIFY] {method}: {data}") + + return await super()._received_notification(notification) + + def make_session( + read_stream: MemoryObjectReceiveStream, + write_stream: MemoryObjectSendStream, + read_timeout_seconds: timedelta | None, + context: Context | None = None, + ) -> ClientSession: + return ConsolePrintingClientSession( + read_stream=read_stream, + write_stream=write_stream, + read_timeout_seconds=read_timeout_seconds, + logging_callback=on_server_log, + context=context, + ) + + # Connect to the workflow server + async with gen_client( + "basic_agent_server", + context.server_registry, + client_session_factory=make_session, + ) as server: + # Ask server to send logs at the requested level (default info) + level = "info" + print(f"[client] Setting server logging level to: {level}") + try: + await server.set_logging_level(level) + except Exception: + # Older servers may not support logging capability + print("[client] Server does not support logging/setLevel") + + # Call the `book_table` tool defined via `@app.tool` + run_result = await server.call_tool( + "book_table", + arguments={ + "date": "today", + "party_size": 2, + "topic": "autumn" + }, + ) + print(f"[client] Workflow run result: {run_result}") + + # Run the `TestWorkflow` workflow... + run_result = await server.call_tool( + "workflows-TestWorkflow-run", + arguments={ + "run_parameters":{ + "args":{ + "date": "today", + "party_size": 2, + "topic": "autumn" + } + } + } + ) + + execution = WorkflowExecution( + **json.loads(run_result.content[0].text) + ) + run_id = execution.run_id + workflow_id = execution.workflow_id + + # and wait for execution to complete + while True: + get_status_result = await server.call_tool( + "workflows-get_status", + arguments={ + "run_id": run_id, + "workflow_id": workflow_id + }, + ) + + workflow_status = _tool_result_to_json(get_status_result) + if workflow_status is None: + logger.error( + f"Failed to parse workflow status response: {get_status_result}" + ) + break + + logger.info( + f"Workflow run {run_id} status:", + data=workflow_status, + ) + + if not workflow_status.get("status"): + logger.error( + f"Workflow run {run_id} status is empty. get_status_result:", + data=get_status_result, + ) + break + + if workflow_status.get("status") == "completed": + logger.info( + f"Workflow run {run_id} completed successfully! Result:", + data=workflow_status.get("result"), + ) + + break + elif workflow_status.get("status") == "error": + logger.error( + f"Workflow run {run_id} failed with error:", + data=workflow_status, + ) + break + elif workflow_status.get("status") == "running": + logger.info( + f"Workflow run {run_id} is still running...", + ) + elif workflow_status.get("status") == "cancelled": + logger.error( + f"Workflow run {run_id} was cancelled.", + data=workflow_status, + ) + break + else: + logger.error( + f"Unknown workflow status: {workflow_status.get('status')}", + data=workflow_status, + ) + break + + await asyncio.sleep(5) + + except Exception as e: + # Tolerate benign shutdown races from SSE client (BrokenResourceError within ExceptionGroup) + if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup): + subs = getattr(e, "exceptions", []) or [] + if ( + _BrokenResourceError is not None + and subs + and all(isinstance(se, _BrokenResourceError) for se in subs) + ): + logger.debug("Ignored BrokenResourceError from SSE shutdown") + else: + raise + elif _BrokenResourceError is not None and isinstance( + e, _BrokenResourceError + ): + logger.debug("Ignored BrokenResourceError from SSE shutdown") + elif "BrokenResourceError" in str(e): + logger.debug( + "Ignored BrokenResourceError from SSE shutdown (string match)" + ) + else: + raise + + +def _tool_result_to_json(tool_result: CallToolResult): + if tool_result.content and len(tool_result.content) > 0: + text = tool_result.content[0].text + try: + # Try to parse the response as JSON if it's a string + import json + + return json.loads(text) + except (json.JSONDecodeError, TypeError): + # If it's not valid JSON, just use the text + return None + + +if __name__ == "__main__": + start = time.time() + asyncio.run(main()) + end = time.time() + t = end - start + + print(f"Total run time: {t:.2f}s") diff --git a/examples/mcp/mcp_elicitation/temporal/main.py b/examples/mcp/mcp_elicitation/temporal/main.py new file mode 100644 index 000000000..eff0901dd --- /dev/null +++ b/examples/mcp/mcp_elicitation/temporal/main.py @@ -0,0 +1,125 @@ +import asyncio +import logging +from typing import Dict, Any + +from mcp.server.fastmcp import Context +import mcp.types as types +from pydantic import BaseModel, Field +from mcp_agent.app import MCPApp +from mcp_agent.server.app_server import create_mcp_server_for_app +from mcp_agent.executor.workflow import Workflow, WorkflowResult + +# Initialize logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +app = MCPApp( + name="elicitation_demo", + description="Demo of workflow with elicitation" +) + + +@app.tool() +async def book_table(date: str, party_size: int, topic: str, app_ctx: Context) -> str: + """Book a table with confirmation""" + + app.logger.info(f"Confirming table for {party_size} on {date}") + + class ConfirmBooking(BaseModel): + confirm: bool = Field(description="Confirm booking?") + notes: str = Field(default="", description="Special requests") + + result = await app.context.upstream_session.elicit( + message=f"Confirm booking for {party_size} on {date}?", + requestedSchema=ConfirmBooking.model_json_schema(), + ) + + app.logger.info(f"Result from confirmation: {result}") + + haiku = await app_ctx.upstream_session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent( + type="text", text=f"Write a haiku about {topic}." + ), + ) + ], + system_prompt="You are a poet.", + max_tokens=80, + model_preferences=types.ModelPreferences( + hints=[types.ModelHint(name="gpt-4o-mini")], + costPriority=0.1, + speedPriority=0.8, + intelligencePriority=0.1, + ), + ) + + app.logger.info(f"Haiku: {haiku.content.text}") + return "Done!" + + +@app.workflow +class TestWorkflow(Workflow[str]): + + @app.workflow_run + async def run(self, args: Dict[str, Any]) -> WorkflowResult[str]: + app_ctx = app.context + + date = args.get("date", "today") + party_size = args.get("party_size", 2) + topic = args.get("topic", "autumn") + + app.logger.info(f"Confirming table for {party_size} on {date}") + + class ConfirmBooking(BaseModel): + confirm: bool = Field(description="Confirm booking?") + notes: str = Field(default="", description="Special requests") + + result = await app.context.upstream_session.elicit( + message=f"Confirm booking for {party_size} on {date}?", + requestedSchema=ConfirmBooking.model_json_schema(), + ) + + app.logger.info(f"Result from confirmation: {result}") + + haiku = await app_ctx.upstream_session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent( + type="text", text=f"Write a haiku about {topic}." + ), + ) + ], + system_prompt="You are a poet.", + max_tokens=80, + model_preferences=types.ModelPreferences( + hints=[types.ModelHint(name="gpt-4o-mini")], + costPriority=0.1, + speedPriority=0.8, + intelligencePriority=0.1, + ), + ) + + app.logger.info(f"Haiku: {haiku.content.text}") + return WorkflowResult(value="Done!") + + +async def main(): + async with app.run() as agent_app: + # Log registered workflows and agent configurations + logger.info(f"Creating MCP server for {agent_app.name}") + + logger.info("Registered workflows:") + for workflow_id in agent_app.workflows: + logger.info(f" - {workflow_id}") + # Create the MCP server that exposes both workflows and agent configurations + mcp_server = create_mcp_server_for_app(agent_app) + + # Run the server + await mcp_server.run_sse_async() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/mcp/mcp_elicitation/temporal/mcp_agent.config.yaml b/examples/mcp/mcp_elicitation/temporal/mcp_agent.config.yaml new file mode 100644 index 000000000..186222535 --- /dev/null +++ b/examples/mcp/mcp_elicitation/temporal/mcp_agent.config.yaml @@ -0,0 +1,22 @@ +$schema: ../../../../schema/mcp-agent.config.schema.json + +execution_engine: temporal + +temporal: + host: "localhost:7233" # Default Temporal server address + namespace: "default" # Default Temporal namespace + task_queue: "mcp-agent" # Task queue for workflows and activities + max_concurrent_activities: 10 # Maximum number of concurrent activities + +logger: + transports: [file] + level: debug + path_settings: + path_pattern: "logs/mcp-agent-{unique_id}.jsonl" + unique_id: "timestamp" # Options: "timestamp" or "session_id" + timestamp_format: "%Y%m%d_%H%M%S" + +openai: + # Secrets (API keys, etc.) are stored in an mcp_agent.secrets.yaml file which can be gitignored + # default_model: "o3-mini" + default_model: "gpt-4o-mini" diff --git a/examples/mcp/mcp_elicitation/temporal/mcp_agent.secrets.yaml.example b/examples/mcp/mcp_elicitation/temporal/mcp_agent.secrets.yaml.example new file mode 100644 index 000000000..930cf3648 --- /dev/null +++ b/examples/mcp/mcp_elicitation/temporal/mcp_agent.secrets.yaml.example @@ -0,0 +1,7 @@ +$schema: ../../../../schema/mcp-agent.config.schema.json + +openai: + api_key: openai_api_key + +anthropic: + api_key: anthropic_api_key diff --git a/examples/mcp/mcp_elicitation/temporal/requirements.txt b/examples/mcp/mcp_elicitation/temporal/requirements.txt new file mode 100644 index 000000000..5f239ce9d --- /dev/null +++ b/examples/mcp/mcp_elicitation/temporal/requirements.txt @@ -0,0 +1,7 @@ +# Core framework dependency +mcp-agent + +# Additional dependencies specific to this example +anthropic +openai +temporalio diff --git a/examples/mcp/mcp_elicitation/temporal/worker.py b/examples/mcp/mcp_elicitation/temporal/worker.py new file mode 100644 index 000000000..39b2a3c67 --- /dev/null +++ b/examples/mcp/mcp_elicitation/temporal/worker.py @@ -0,0 +1,31 @@ +""" +Worker script for the Temporal workflow example. +This script starts a Temporal worker that can execute workflows and activities. +Run this script in a separate terminal window before running the main.py script. + +This leverages the TemporalExecutor's start_worker method to handle the worker setup. +""" + +import asyncio +import logging + + +from mcp_agent.executor.temporal import create_temporal_worker_for_app + +from main import app + +# Initialize logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main(): + """ + Start a Temporal worker for the example workflows using the app's executor. + """ + async with create_temporal_worker_for_app(app) as worker: + await worker.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index 7908aac40..f40349a95 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -42,6 +42,7 @@ from mcp_agent.workflows.llm.llm_selector import ModelSelector from mcp_agent.workflows.factory import load_agent_specs_from_dir + if TYPE_CHECKING: from mcp_agent.agents.agent_spec import AgentSpec from mcp_agent.executor.workflow import Workflow @@ -459,6 +460,7 @@ def workflow( decorated_cls = workflow_defn_decorator( cls, sandboxed=False, *args, **kwargs ) + self._workflows[workflow_id] = decorated_cls return decorated_cls else: @@ -705,9 +707,9 @@ async def _run(self, *args, **kwargs): # type: ignore[no-redef] # decorate the run method with the engine-specific run decorator. if engine_type == "temporal": try: - run_decorator = self._decorator_registry.get_workflow_run_decorator( + run_decorator = (self._decorator_registry.get_workflow_run_decorator( engine_type - ) + )) if run_decorator: fn_run = getattr(auto_cls, "run") # Ensure method appears as top-level for Temporal diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py index ea4a6e809..0179bac0d 100644 --- a/src/mcp_agent/executor/temporal/session_proxy.py +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -106,7 +106,7 @@ async def notify(self, method: str, params: Dict[str, Any] | None = None) -> boo return True async def request( - self, method: str, params: Dict[str, Any] | None = None + self, method: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: """Send a server->client request and return the client's response. The result is a plain JSON-serializable dict. @@ -117,14 +117,47 @@ async def request( if _in_workflow_runtime(): act = self._context.task_registry.get_activity("mcp_relay_request") - return await self._executor.execute( + + execution_info = await self._executor.execute( act, + True, # Use the async APIs with signalling for response exec_id, method, params or {}, ) + + if execution_info.get("error"): + return execution_info + + signal_name = execution_info.get("signal_name", "") + + if not signal_name: + return {"error": "no_signal_name_returned_from_activity"} + + # Wait for the response via workflow signal + info = _twf.info() + payload = await self._context.executor.wait_for_signal( # type: ignore[attr-defined] + signal_name, + workflow_id=info.workflow_id, + run_id=info.run_id, + signal_description=f"Waiting for async response to {method}", + # Timeout can be controlled by Temporal workflow/activity timeouts + ) + + pc = _twf.payload_converter() + # Support either a Temporal payload wrapper or a plain dict + if hasattr(payload, "payload"): + return pc.from_payload(payload.payload, dict) + if isinstance(payload, dict): + return payload + return pc.from_payload(payload, dict) + + # Non-workflow (activity/asyncio): direct call and wait for result return await self._system_activities.relay_request( - exec_id, method, params or {} + False, # Do not use the async APIs, but the synchronous ones instead + exec_id, + method, + params or {}, ) async def send_notification( @@ -289,10 +322,10 @@ async def create_message( raise RuntimeError(f"sampling/createMessage returned invalid result: {e}") async def elicit( - self, - message: str, - requestedSchema: types.ElicitRequestedSchema, - related_request_id: types.RequestId | None = None, + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, ) -> types.ElicitResult: params: Dict[str, Any] = { "message": message, @@ -325,6 +358,6 @@ async def notify(self, method: str, params: Dict[str, Any] | None = None) -> Non await self._proxy.notify(method, params or {}) async def request( - self, method: str, params: Dict[str, Any] | None = None + self, method: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: return await self._proxy.request(method, params or {}) diff --git a/src/mcp_agent/executor/temporal/system_activities.py b/src/mcp_agent/executor/temporal/system_activities.py index aff8c7f12..c4bd6cc85 100644 --- a/src/mcp_agent/executor/temporal/system_activities.py +++ b/src/mcp_agent/executor/temporal/system_activities.py @@ -90,11 +90,13 @@ async def relay_notify( @activity.defn(name="mcp_relay_request") async def relay_request( - self, execution_id: str, method: str, params: Dict[str, Any] | None = None + self, make_async_call: bool, execution_id: str, method: str, params: Dict[str, Any] | None = None ) -> Dict[str, Any]: gateway_url = getattr(self.context, "gateway_url", None) gateway_token = getattr(self.context, "gateway_token", None) + return await request_via_proxy( + make_async_call=make_async_call, execution_id=execution_id, method=method, params=params or {}, diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py index 5f4394e93..ba065f1a5 100644 --- a/src/mcp_agent/mcp/client_proxy.py +++ b/src/mcp_agent/mcp/client_proxy.py @@ -2,14 +2,15 @@ import os import httpx +import uuid from urllib.parse import quote def _resolve_gateway_url( - *, - gateway_url: Optional[str] = None, - context_gateway_url: Optional[str] = None, + *, + gateway_url: Optional[str] = None, + context_gateway_url: Optional[str] = None, ) -> str: """Resolve the base URL for the MCP gateway. @@ -37,14 +38,14 @@ def _resolve_gateway_url( async def log_via_proxy( - execution_id: str, - level: str, - namespace: str, - message: str, - data: Dict[str, Any] | None = None, - *, - gateway_url: Optional[str] = None, - gateway_token: Optional[str] = None, + execution_id: str, + level: str, + namespace: str, + message: str, + data: Dict[str, Any] | None = None, + *, + gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, ) -> bool: base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) url = f"{base}/internal/workflows/log" @@ -79,12 +80,12 @@ async def log_via_proxy( async def ask_via_proxy( - execution_id: str, - prompt: str, - metadata: Dict[str, Any] | None = None, - *, - gateway_url: Optional[str] = None, - gateway_token: Optional[str] = None, + execution_id: str, + prompt: str, + metadata: Dict[str, Any] | None = None, + *, + gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, ) -> Dict[str, Any]: base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) url = f"{base}/internal/human/prompts" @@ -116,12 +117,12 @@ async def ask_via_proxy( async def notify_via_proxy( - execution_id: str, - method: str, - params: Dict[str, Any] | None = None, - *, - gateway_url: Optional[str] = None, - gateway_token: Optional[str] = None, + execution_id: str, + method: str, + params: Dict[str, Any] | None = None, + *, + gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, ) -> bool: base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/notify" @@ -149,46 +150,108 @@ async def notify_via_proxy( async def request_via_proxy( - execution_id: str, - method: str, - params: Dict[str, Any] | None = None, - *, - gateway_url: Optional[str] = None, - gateway_token: Optional[str] = None, + make_async_call: bool, + execution_id: str, + method: str, + params: Dict[str, Any] | None = None, + *, + gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, ) -> Dict[str, Any]: - base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) - url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request" - headers: Dict[str, str] = {} - tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") - if tok: - headers["X-MCP-Gateway-Token"] = tok - headers["Authorization"] = f"Bearer {tok}" - # Requests require a response; default to no HTTP timeout. - # Configure with MCP_GATEWAY_REQUEST_TIMEOUT (seconds). If unset or <= 0, no timeout is applied. - timeout_str = os.environ.get("MCP_GATEWAY_REQUEST_TIMEOUT") - timeout_float: float | None - if timeout_str is None: - timeout_float = None # no timeout by default; activity timeouts still apply - else: + if make_async_call: + # Make sure we're running in a Temporal workflow context try: - timeout_float = float(str(timeout_str).strip()) - except Exception: + from temporalio import workflow, activity + in_temporal = workflow.in_workflow() + if in_temporal: + workflow_id = workflow.info().workflow_id + else: + in_temporal = activity.in_activity() + if in_temporal: + workflow_id = activity.info().workflow_id + except ImportError: + in_temporal = False + + if not in_temporal: + return {"error": "not_in_workflow_or_activity"} + + signal_name = f"mcp_rpc_{method}_{uuid.uuid4().hex}" + + # Make the HTTP request (but don't return the response directly) + base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) + url = f"{base}/internal/session/by-run/{quote(workflow_id, safe='')}/{quote(execution_id, safe='')}/async-request" + headers: Dict[str, str] = {} + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + headers["Authorization"] = f"Bearer {tok}" + + timeout_str = os.environ.get("MCP_GATEWAY_REQUEST_TIMEOUT") + timeout_float: float | None + if timeout_str is None: timeout_float = None - try: - # If timeout is None, pass a Timeout object with no limits - if timeout_float is None: - timeout = httpx.Timeout(None) else: - timeout = timeout_float - async with httpx.AsyncClient(timeout=timeout) as client: - r = await client.post( - url, json={"method": method, "params": params or {}}, headers=headers - ) - except httpx.RequestError: - return {"error": "request_failed"} - if r.status_code >= 400: - return {"error": r.text} - try: - return r.json() if r.content else {"error": "invalid_response"} - except ValueError: - return {"error": "invalid_response"} + try: + timeout_float = float(str(timeout_str).strip()) + except Exception: + timeout_float = None + + try: + if timeout_float is None: + timeout = httpx.Timeout(None) + else: + timeout = timeout_float + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, + json={ + "method": method, + "params": params or {}, + "signal_name": signal_name + }, + headers=headers + ) + except httpx.RequestError: + return {"error": "request_failed"} + if r.status_code >= 400: + return {"error": r.text} + return {"error": "", "signal_name": signal_name} + else: + # Use original synchronous approach for non-workflow contexts + base = _resolve_gateway_url(gateway_url=gateway_url, context_gateway_url=None) + url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request" + headers: Dict[str, str] = {} + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + headers["Authorization"] = f"Bearer {tok}" + # Requests require a response; default to no HTTP timeout. + # Configure with MCP_GATEWAY_REQUEST_TIMEOUT (seconds). If unset or <= 0, no timeout is applied. + timeout_str = os.environ.get("MCP_GATEWAY_REQUEST_TIMEOUT") + timeout_float: float | None + if timeout_str is None: + timeout_float = None # no timeout by default; activity timeouts still apply + else: + try: + timeout_float = float(str(timeout_str).strip()) + except Exception: + timeout_float = None + try: + # If timeout is None, pass a Timeout object with no limits + if timeout_float is None: + timeout = httpx.Timeout(None) + else: + timeout = timeout_float + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, json={"method": method, "params": params or {}}, headers=headers + ) + except httpx.RequestError: + return {"error": "request_failed"} + if r.status_code >= 400: + return {"error": r.text} + + try: + return r.json() if r.content else {"error": "invalid_response"} + except ValueError: + return {"error": "invalid_response"} diff --git a/src/mcp_agent/mcp/sampling_handler.py b/src/mcp_agent/mcp/sampling_handler.py index d3f79f43c..09bf77be6 100644 --- a/src/mcp_agent/mcp/sampling_handler.py +++ b/src/mcp_agent/mcp/sampling_handler.py @@ -30,6 +30,73 @@ from mcp_agent.core.context import Context +def _format_sampling_request_for_human( + params: CreateMessageRequestParams +) -> str: + """Format sampling request for human review""" + messages_text = "" + for i, msg in enumerate(params.messages): + content = ( + msg.content.text if hasattr(msg.content, "text") else str(msg.content) + ) + messages_text += f" Message {i + 1} ({msg.role}): {content[:200]}{'...' if len(content) > 200 else ''}\n" + + system_prompt_display = ( + "None" + if params.systemPrompt is None + else ( + f"{params.systemPrompt[:100]}{'...' if len(params.systemPrompt) > 100 else ''}" + ) + ) + + stop_sequences_display = ( + "None" if params.stopSequences is None else str(params.stopSequences) + ) + + model_preferences_display = "None" + if params.modelPreferences is not None: + prefs = [] + if params.modelPreferences.hints: + hints = [ + hint.name + for hint in params.modelPreferences.hints + if hint.name is not None + ] + prefs.append(f"hints: {hints}") + if params.modelPreferences.costPriority is not None: + prefs.append(f"cost: {params.modelPreferences.costPriority}") + if params.modelPreferences.speedPriority is not None: + prefs.append(f"speed: {params.modelPreferences.speedPriority}") + if params.modelPreferences.intelligencePriority is not None: + prefs.append( + f"intelligence: {params.modelPreferences.intelligencePriority}" + ) + model_preferences_display = ", ".join(prefs) if prefs else "None" + + return f"""REQUEST DETAILS: +- Max Tokens: {params.maxTokens} +- System Prompt: {system_prompt_display} +- Temperature: {params.temperature if params.temperature is not None else 0.7} +- Stop Sequences: {stop_sequences_display} +- Model Preferences: {model_preferences_display} +MESSAGES: +{messages_text}""" + + +def _format_sampling_response_for_human(result: CreateMessageResult) -> str: + """Format sampling response for human review""" + content = ( + result.content.text + if hasattr(result.content, "text") + else str(result.content) + ) + return f"""RESPONSE DETAILS: +- Model: {result.model} +- Role: {result.role} +CONTENT: +{content}""" + + class SamplingHandler(ContextDependent): """Handles MCP sampling requests with optional human approval and LLM generation.""" @@ -89,10 +156,13 @@ async def _human_approve_request( from mcp_agent.human_input.types import HumanInputRequest + request_summary = _format_sampling_request_for_human(params) + req = HumanInputRequest( prompt=( "MCP server requests LLM sampling. Respond 'approve' to proceed, " "anything else to reject (your input will be recorded as reason)." + f"\n\n{request_summary}" ), description="MCP Sampling Request Approval", request_id=f"sampling_request_{uuid4()}", @@ -115,10 +185,14 @@ async def _human_approve_response( from mcp_agent.human_input.types import HumanInputRequest + response_summary = _format_sampling_response_for_human(result) + req = HumanInputRequest( prompt=( "LLM has generated a response. Respond 'approve' to send, " "anything else to reject (your input will be recorded as reason)." + f"\n\n{response_summary}" + ), description="MCP Sampling Response Approval", request_id=f"sampling_response_{uuid4()}", diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 01bea8326..d7d83e287 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -77,8 +77,8 @@ async def _get_session(execution_id: str) -> Any | None: try: logger.debug( ( - f"Lookup session for execution_id={execution_id}: " - + (f"found session_id={id(session)}" if session else "not found") + f"Lookup session for execution_id={execution_id}: " + + (f"found session_id={id(session)}" if session else "not found") ) ) except Exception: @@ -191,7 +191,7 @@ def _set_upstream_from_request_ctx_if_available(ctx: MCPContext) -> None: def _resolve_workflows_and_context( - ctx: MCPContext, + ctx: MCPContext, ) -> Tuple[Dict[str, Type["Workflow"]] | None, Optional["Context"]]: """Resolve the workflows mapping and underlying app context regardless of startup mode. @@ -201,9 +201,9 @@ def _resolve_workflows_and_context( # Try lifespan-provided ServerContext first lifespan_ctx = getattr(ctx.request_context, "lifespan_context", None) if ( - lifespan_ctx is not None - and hasattr(lifespan_ctx, "workflows") - and hasattr(lifespan_ctx, "context") + lifespan_ctx is not None + and hasattr(lifespan_ctx, "workflows") + and hasattr(lifespan_ctx, "context") ): # Ensure upstream session once at resolution time try: @@ -369,23 +369,10 @@ async def _relay_notify(request: Request): method = body.get("method") params = body.get("params") or {} - # Optional shared-secret auth - gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token: - bearer = request.headers.get("Authorization", "") - bearer_token = ( - bearer.split(" ", 1)[1] - if bearer.lower().startswith("bearer ") - else "" - ) - header_tok = request.headers.get("X-MCP-Gateway-Token", "") - if not ( - secrets.compare_digest(header_tok, gw_token) - or secrets.compare_digest(bearer_token, gw_token) - ): - return JSONResponse( - {"ok": False, "error": "unauthorized"}, status_code=401 - ) + # Check authentication + auth_error = _check_gateway_auth(request) + if auth_error: + return auth_error # Optional idempotency handling idempotency_key = params.get("idempotency_key") @@ -526,225 +513,259 @@ async def _relay_notify(request: Request): {"ok": False, "error": str(e_mapped)}, status_code=500 ) + # Helper function for shared authentication + def _check_gateway_auth(request: Request) -> JSONResponse | None: + """ + Check optional shared-secret authentication for internal endpoints. + Returns JSONResponse with error if auth fails, None if auth passes. + """ + gw_token = os.environ.get("MCP_GATEWAY_TOKEN") + if not gw_token: + return None # No auth required if no token is set + + bearer = request.headers.get("Authorization", "") + bearer_token = ( + bearer.split(" ", 1)[1] + if bearer.lower().startswith("bearer ") + else "" + ) + header_tok = request.headers.get("X-MCP-Gateway-Token", "") + + if not ( + secrets.compare_digest(header_tok, gw_token) + or secrets.compare_digest(bearer_token, gw_token) + ): + return JSONResponse( + {"ok": False, "error": "unauthorized"}, status_code=401 + ) + + return None # Auth passed + + # Helper functions for request handling + async def _handle_request_via_rpc(session, method: str, params: dict, execution_id: str, + log_prefix: str = "request"): + """Handle request via generic RPC if available.""" + rpc = getattr(session, "rpc", None) + if rpc and hasattr(rpc, "request"): + result = await rpc.request(method, params) + logger.debug(f"[{log_prefix}] delivered via session_id={id(session)} (generic '{method}')") + return result + return None + + async def _handle_specific_request(session, method: str, params: dict, log_prefix: str = "request"): + """Handle specific request types with structured request/response.""" + from mcp.types import ( + CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, + ElicitRequest, ElicitRequestParams, ElicitResult, + ListRootsRequest, ListRootsResult, + PingRequest, EmptyResult, ServerRequest + ) + + if method == "sampling/createMessage": + req = ServerRequest( + CreateMessageRequest(method="sampling/createMessage", params=CreateMessageRequestParams(**params))) + result = await session.send_request(request=req, + result_type=CreateMessageResult) # type: ignore[attr-defined] + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + elif method == "elicitation/create": + req = ServerRequest(ElicitRequest(method="elicitation/create", params=ElicitRequestParams(**params))) + result = await session.send_request(request=req, result_type=ElicitResult) # type: ignore[attr-defined] + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + elif method == "roots/list": + req = ServerRequest(ListRootsRequest(method="roots/list")) + result = await session.send_request(request=req, + result_type=ListRootsResult) # type: ignore[attr-defined] + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + elif method == "ping": + req = ServerRequest(PingRequest(method="ping")) + result = await session.send_request(request=req, result_type=EmptyResult) # type: ignore[attr-defined] + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + else: + raise ValueError(f"unsupported method: {method}") + + async def _try_session_request(session, method: str, params: dict, execution_id: str, + log_prefix: str = "request", register_session: bool = False): + """Try to handle a request via session, with optional registration.""" + try: + # First try generic RPC passthrough + result = await _handle_request_via_rpc(session, method, params, execution_id, log_prefix) + if result is not None: + if register_session: + try: + await _register_session(run_id=execution_id, execution_id=execution_id, session=session) + # logger.debug( + # f"[{log_prefix}] rebound mapping to session_id={id(session)} for execution_id={execution_id}") + except Exception: + pass + return result + + # Fallback to specific structured request handling + result = await _handle_specific_request(session, method, params, log_prefix) + if register_session: + try: + await _register_session(run_id=execution_id, execution_id=execution_id, session=session) + # logger.debug( + # f"[{log_prefix}] rebound mapping to session_id={id(session)} for execution_id={execution_id}") + except Exception: + pass + return result + except Exception as e: + if "unsupported method" in str(e): + raise # Re-raise unsupported method errors + logger.warning( + f"[{log_prefix}] session delivery failed for execution_id={execution_id} method={method}: {e}") + raise + @mcp_server.custom_route( "/internal/session/by-run/{execution_id}/request", methods=["POST"], include_in_schema=False, ) async def _relay_request(request: Request): - from mcp.types import ( - CreateMessageRequest, - CreateMessageRequestParams, - CreateMessageResult, - ElicitRequest, - ElicitRequestParams, - ElicitResult, - ListRootsRequest, - ListRootsResult, - PingRequest, - EmptyResult, - ServerRequest, - ) - body = await request.json() execution_id = request.path_params.get("execution_id") method = body.get("method") params = body.get("params") or {} - # Prefer latest upstream session first + # Check authentication + auth_error = _check_gateway_auth(request) + if auth_error: + return auth_error + + # Try latest upstream session first latest_session = _get_fallback_upstream_session() if latest_session is not None: try: - rpc = getattr(latest_session, "rpc", None) - if rpc and hasattr(rpc, "request"): - result = await rpc.request(method, params) - # logger.debug( - # f"[request] delivered via latest session_id={id(latest_session)} (generic '{method}')" - # ) - try: - await _register_session( - run_id=execution_id, - execution_id=execution_id, - session=latest_session, - ) - logger.debug( - f"[request] rebound mapping to latest session_id={id(latest_session)} for execution_id={execution_id}" - ) - except Exception: - pass - return JSONResponse(result) - # If latest_session lacks rpc.request, try a limited mapping path - if method == "sampling/createMessage": - req = ServerRequest( - CreateMessageRequest( - method="sampling/createMessage", - params=CreateMessageRequestParams(**params), - ) - ) - result = await latest_session.send_request( # type: ignore[attr-defined] - request=req, - result_type=CreateMessageResult, - ) - try: - await _register_session( - run_id=execution_id, - execution_id=execution_id, - session=latest_session, - ) - except Exception: - pass - return JSONResponse( - result.model_dump( - by_alias=True, mode="json", exclude_none=True - ) - ) - elif method == "elicitation/create": - req = ServerRequest( - ElicitRequest( - method="elicitation/create", - params=ElicitRequestParams(**params), - ) - ) - result = await latest_session.send_request( # type: ignore[attr-defined] - request=req, - result_type=ElicitResult, - ) - try: - await _register_session( - run_id=execution_id, - execution_id=execution_id, - session=latest_session, - ) - except Exception: - pass - return JSONResponse( - result.model_dump( - by_alias=True, mode="json", exclude_none=True - ) - ) - elif method == "roots/list": - req = ServerRequest(ListRootsRequest(method="roots/list")) - result = await latest_session.send_request( # type: ignore[attr-defined] - request=req, - result_type=ListRootsResult, - ) - try: - await _register_session( - run_id=execution_id, - execution_id=execution_id, - session=latest_session, - ) - except Exception: - pass - return JSONResponse( - result.model_dump( - by_alias=True, mode="json", exclude_none=True - ) - ) - elif method == "ping": - req = ServerRequest(PingRequest(method="ping")) - result = await latest_session.send_request( # type: ignore[attr-defined] - request=req, - result_type=EmptyResult, - ) - try: - await _register_session( - run_id=execution_id, - execution_id=execution_id, - session=latest_session, - ) - except Exception: - pass - return JSONResponse( - result.model_dump( - by_alias=True, mode="json", exclude_none=True - ) - ) - except Exception as e_latest: - logger.warning( - f"[request] latest session delivery failed for execution_id={execution_id} method={method}: {e_latest}" + result = await _try_session_request( + latest_session, method, params, execution_id, + log_prefix="request", register_session=True ) + return JSONResponse(result) + except Exception as e_latest: + # Only log and continue to fallback if it's not an unsupported method error + if "unsupported method" not in str(e_latest): + logger.warning( + f"[request] latest session delivery failed for execution_id={execution_id} method={method}: {e_latest}") # Fallback to mapped session session = await _get_session(execution_id) if not session: - logger.warning( - f"[request] session_not_available for execution_id={execution_id}" - ) + logger.warning(f"[request] session_not_available for execution_id={execution_id}") return JSONResponse({"error": "session_not_available"}, status_code=503) try: - # Prefer generic request passthrough if available - rpc = getattr(session, "rpc", None) - if rpc and hasattr(rpc, "request"): - result = await rpc.request(method, params) - try: - logger.debug( - f"[request] forwarded generic request '{method}' to session_id={id(session)}" - ) - except Exception: - pass - return JSONResponse(result) - # Fallback: Map a small set of supported server->client requests - if method == "sampling/createMessage": - req = ServerRequest( - CreateMessageRequest( - method="sampling/createMessage", - params=CreateMessageRequestParams(**params), - ) - ) - result = await session.send_request( # type: ignore[attr-defined] - request=req, - result_type=CreateMessageResult, - ) - return JSONResponse( - result.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - elif method == "elicitation/create": - req = ServerRequest( - ElicitRequest( - method="elicitation/create", - params=ElicitRequestParams(**params), - ) - ) - result = await session.send_request( # type: ignore[attr-defined] - request=req, - result_type=ElicitResult, - ) - return JSONResponse( - result.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - elif method == "roots/list": - req = ServerRequest(ListRootsRequest(method="roots/list")) - result = await session.send_request( # type: ignore[attr-defined] - request=req, - result_type=ListRootsResult, - ) - return JSONResponse( - result.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - elif method == "ping": - req = ServerRequest(PingRequest(method="ping")) - result = await session.send_request( # type: ignore[attr-defined] - request=req, - result_type=EmptyResult, - ) - return JSONResponse( - result.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - else: - return JSONResponse( - {"error": f"unsupported method: {method}"}, status_code=400 - ) + result = await _try_session_request( + session, method, params, execution_id, + log_prefix="request", register_session=False + ) + return JSONResponse(result) except Exception as e: + if "unsupported method" in str(e): + return JSONResponse({"error": f"unsupported method: {method}"}, status_code=400) try: - logger.error( - f"[request] error forwarding for execution_id={execution_id} method={method}: {e}" - ) + logger.error(f"[request] error forwarding for execution_id={execution_id} method={method}: {e}") except Exception: pass return JSONResponse({"error": str(e)}, status_code=500) + @mcp_server.custom_route( + "/internal/session/by-run/{workflow_id}/{execution_id}/async-request", + methods=["POST"], + include_in_schema=False, + ) + async def _async_relay_request(request: Request): + body = await request.json() + execution_id = request.path_params.get("execution_id") + workflow_id = request.path_params.get("workflow_id") + method = body.get("method") + params = body.get("params") or {} + signal_name = body.get("signal_name") + + # Check authentication + auth_error = _check_gateway_auth(request) + if auth_error: + return auth_error + + try: + logger.info(f"[async-request] incoming execution_id={execution_id} method={method}") + except Exception: + pass + + if method != "sampling/createMessage" and method != "elicitation/create": + logger.error(f"async not supported for method {method}") + return JSONResponse({"error": f"async not supported for method {method}"}, + status_code=405) + + if not signal_name: + return JSONResponse({"error": "missing_signal_name"}, status_code=400) + + # Create background task to handle the request and signal the workflow + async def _handle_async_request_task(): + try: + result = None + + # Try latest upstream session first + latest_session = _get_fallback_upstream_session() + if latest_session is not None: + try: + result = await _try_session_request( + latest_session, method, params, execution_id, + log_prefix="async-request", register_session=True + ) + except Exception as e_latest: + logger.warning(f"[async-request] latest session delivery failed for execution_id={execution_id} method={method}: {e_latest}") + + # Fallback to mapped session if latest session failed + if result is None: + session = await _get_session(execution_id) + if session: + try: + result = await _try_session_request( + session, method, params, execution_id, + log_prefix="async-request", register_session=False + ) + except Exception as e: + logger.error(f"[async-request] error forwarding for execution_id={execution_id} method={method}: {e}") + result = {"error": str(e)} + else: + logger.warning(f"[async-request] session_not_available for execution_id={execution_id}") + result = {"error": "session_not_available"} + + # Signal the workflow with the result using method-specific signal + try: + # Try to get Temporal client from the app context + app = _get_attached_app(mcp_server) + if app and app.context and hasattr(app.context, 'executor'): + executor = app.context.executor + if hasattr(executor, 'client'): + client = executor.client + # Find the workflow using execution_id as both workflow_id and run_id + try: + workflow_handle = client.get_workflow_handle( + workflow_id=workflow_id, + run_id=execution_id + ) + + await workflow_handle.signal(signal_name, result) + logger.info(f"[async-request] signaled workflow {execution_id} " + f"with {method} result using signal") + except Exception as signal_error: + logger.warning(f"[async-request] failed to signal workflow {execution_id}:" + f" {signal_error}") + except Exception as e: + logger.error(f"[async-request] failed to signal workflow: {e}") + + except Exception as e: + logger.error(f"[async-request] background task error: {e}") + + # Start the background task + asyncio.create_task(_handle_async_request_task()) + + # Return immediately with 200 status to indicate request was received + return JSONResponse( + {"status": "received", "execution_id": execution_id, "method": method, "signal_name": signal_name} + ) + @mcp_server.custom_route( "/internal/workflows/log", methods=["POST"], include_in_schema=False ) @@ -762,23 +783,10 @@ async def _internal_workflows_log(request: Request): except Exception: pass - # Optional shared-secret auth - gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token: - bearer = request.headers.get("Authorization", "") - bearer_token = ( - bearer.split(" ", 1)[1] - if bearer.lower().startswith("bearer ") - else "" - ) - header_tok = request.headers.get("X-MCP-Gateway-Token", "") - if not ( - secrets.compare_digest(header_tok, gw_token) - or secrets.compare_digest(bearer_token, gw_token) - ): - return JSONResponse( - {"ok": False, "error": "unauthorized"}, status_code=401 - ) + # Check authentication + auth_error = _check_gateway_auth(request) + if auth_error: + return auth_error # Prefer latest upstream session first latest_session = _get_fallback_upstream_session() @@ -853,21 +861,10 @@ async def _internal_human_prompts(request: Request): except Exception: pass - # Optional shared-secret auth - gw_token = os.environ.get("MCP_GATEWAY_TOKEN") - if gw_token: - bearer = request.headers.get("Authorization", "") - bearer_token = ( - bearer.split(" ", 1)[1] - if bearer.lower().startswith("bearer ") - else "" - ) - header_tok = request.headers.get("X-MCP-Gateway-Token", "") - if not ( - secrets.compare_digest(header_tok, gw_token) - or secrets.compare_digest(bearer_token, gw_token) - ): - return JSONResponse({"error": "unauthorized"}, status_code=401) + # Check authentication + auth_error = _check_gateway_auth(request) + if auth_error: + return auth_error # Prefer latest upstream session first latest_session = _get_fallback_upstream_session() @@ -979,7 +976,7 @@ async def _internal_human_prompts(request: Request): @lowlevel_server.set_logging_level() async def _set_level( - level: str, + level: str, ) -> None: # mcp.types.LoggingLevel is a Literal[str] try: LoggingConfig.set_min_level(level) @@ -1094,10 +1091,10 @@ async def list_workflow_runs( @mcp.tool(name="workflows-run") async def run_workflow( - ctx: MCPContext, - workflow_name: str, - run_parameters: Dict[str, Any] | None = None, - **kwargs: Any, + ctx: MCPContext, + workflow_name: str, + run_parameters: Dict[str, Any] | None = None, + **kwargs: Any, ) -> Dict[str, str]: """ Run a workflow with the given name. @@ -1121,9 +1118,9 @@ async def run_workflow( @mcp.tool(name="workflows-get_status") async def get_workflow_status( - ctx: MCPContext, - run_id: str | None = None, - workflow_id: str | None = None, + ctx: MCPContext, + run_id: str | None = None, + workflow_id: str | None = None, ) -> Dict[str, Any]: """ Get the status of a running workflow. @@ -1161,11 +1158,11 @@ async def get_workflow_status( @mcp.tool(name="workflows-resume") async def resume_workflow( - ctx: MCPContext, - run_id: str | None = None, - workflow_id: str | None = None, - signal_name: str | None = "resume", - payload: Dict[str, Any] | None = None, + ctx: MCPContext, + run_id: str | None = None, + workflow_id: str | None = None, + signal_name: str | None = "resume", + payload: Dict[str, Any] | None = None, ) -> bool: """ Resume a paused workflow. @@ -1235,7 +1232,7 @@ async def resume_workflow( @mcp.tool(name="workflows-cancel") async def cancel_workflow( - ctx: MCPContext, run_id: str | None = None, workflow_id: str | None = None + ctx: MCPContext, run_id: str | None = None, workflow_id: str | None = None ) -> bool: """ Cancel a running workflow. @@ -1365,14 +1362,14 @@ def create_declared_function_tools(mcp: FastMCP, server_context: ServerContext): import time async def _wait_for_completion( - ctx: MCPContext, - run_id: str, - *, - workflow_id: str | None = None, - timeout: float | None = None, - registration_grace: float = 1.0, - poll_initial: float = 0.05, - poll_max: float = 1.0, + ctx: MCPContext, + run_id: str, + *, + workflow_id: str | None = None, + timeout: float | None = None, + registration_grace: float = 1.0, + poll_initial: float = 0.05, + poll_max: float = 1.0, ): registry = _resolve_workflow_registry(ctx) if not registry: @@ -1465,8 +1462,8 @@ async def _wrapper(**kwargs): return getattr(result, "value", None) # If status payload returned a dict that looks like WorkflowResult, unwrap safely via 'kind' if ( - isinstance(result, dict) - and result.get("kind") == "workflow_result" + isinstance(result, dict) + and result.get("kind") == "workflow_result" ): return result.get("value") return result @@ -1572,9 +1569,9 @@ async def _async_wrapper(**kwargs): if p.name in ("ctx", "context"): continue if ( - _Ctx is not None - and p.annotation is not inspect._empty - and p.annotation is _Ctx + _Ctx is not None + and p.annotation is not inspect._empty + and p.annotation is _Ctx ): continue params.append(p) @@ -1627,7 +1624,7 @@ async def _adapter(**kw): def create_workflow_specific_tools( - mcp: FastMCP, workflow_name: str, workflow_cls: Type["Workflow"] + mcp: FastMCP, workflow_name: str, workflow_cls: Type["Workflow"] ): """Create specific tools for a given workflow.""" param_source = _get_param_source_function_from_workflow(workflow_cls) @@ -1676,8 +1673,8 @@ def _schema_fn_proxy(*args, **kwargs): """, ) async def run( - ctx: MCPContext, - run_parameters: Dict[str, Any] | None = None, + ctx: MCPContext, + run_parameters: Dict[str, Any] | None = None, ) -> Dict[str, str]: _set_upstream_from_request_ctx_if_available(ctx) return await _workflow_run(ctx, workflow_name, run_parameters) @@ -1687,7 +1684,7 @@ async def run( def _get_server_descriptions( - server_registry: ServerRegistry | None, server_names: List[str] + server_registry: ServerRegistry | None, server_names: List[str] ) -> List: servers: List[dict[str, str]] = [] if server_registry: @@ -1709,7 +1706,7 @@ def _get_server_descriptions( def _get_server_descriptions_as_string( - server_registry: ServerRegistry | None, server_names: List[str] + server_registry: ServerRegistry | None, server_names: List[str] ) -> str: servers = _get_server_descriptions(server_registry, server_names) @@ -1729,10 +1726,10 @@ def _get_server_descriptions_as_string( async def _workflow_run( - ctx: MCPContext, - workflow_name: str, - run_parameters: Dict[str, Any] | None = None, - **kwargs: Any, + ctx: MCPContext, + workflow_name: str, + run_parameters: Dict[str, Any] | None = None, + **kwargs: Any, ) -> Dict[str, str]: # Use Temporal run_id as the routing key for gateway callbacks. # We don't have it until after the workflow is started; we'll register mapping post-start. @@ -1902,7 +1899,7 @@ def _normalize_gateway_url(url: str | None) -> str | None: async def _workflow_status( - ctx: MCPContext, run_id: str | None = None, workflow_id: str | None = None + ctx: MCPContext, run_id: str | None = None, workflow_id: str | None = None ) -> Dict[str, Any]: # Ensure upstream session so status-related logs are forwarded try: @@ -1942,5 +1939,4 @@ async def _workflow_status( return status - # endregion diff --git a/src/mcp_agent/workflows/factory.py b/src/mcp_agent/workflows/factory.py index df22718db..3a661adb3 100644 --- a/src/mcp_agent/workflows/factory.py +++ b/src/mcp_agent/workflows/factory.py @@ -72,22 +72,7 @@ def create_llm( model: str | ModelPreferences | None = None, request_params: RequestParams | None = None, context: Context | None = None, -) -> AugmentedLLM: - """ - Create an Augmented LLM from an agent or agent spec. - """ - agent = ( - agent if isinstance(agent, Agent) else agent_from_spec(agent, context=context) - ) - - factory = _llm_factory( - provider=provider, - model=model, - request_params=request_params, - context=context, - ) - - return factory(agent=agent) +) -> AugmentedLLM: ... @overload @@ -99,21 +84,45 @@ def create_llm( model: str | ModelPreferences | None = None, request_params: RequestParams | None = None, context: Context | None = None, +) -> AugmentedLLM: ... + + +def create_llm( + agent: Agent | AgentSpec | None = None, + agent_name: str | None = None, + server_names: List[str] | None = None, + instruction: str | None = None, + provider: str = "openai", + model: str | ModelPreferences | None = None, + request_params: RequestParams | None = None, + context: Context | None = None, ) -> AugmentedLLM: """ - Create an Augmented LLM. + Create an Augmented LLM from an agent, agent spec, or agent name. """ + if isinstance(agent_name, str): + # Handle the case where first argument is agent_name (string) + agent_obj = agent_from_spec( + AgentSpec( + name=agent_name, instruction=instruction, server_names=server_names or [] + ), + context=context, + ) + elif isinstance(agent, AgentSpec): + # Handle AgentSpec case + agent_obj = agent_from_spec(agent, context=context) + else: + # Handle Agent case + agent_obj = agent - agent = agent_from_spec( - AgentSpec( - name=agent_name, instruction=instruction, server_names=server_names or [] - ), - context=context, - ) factory = _llm_factory( - provider=provider, model=model, request_params=request_params, context=context + provider=provider, + model=model, + request_params=request_params, + context=context, ) - return factory(agent=agent) + + return factory(agent=agent_obj) async def create_router_llm(