diff --git a/README.md b/README.md index 5ea198009..21f6a5448 100644 --- a/README.md +++ b/README.md @@ -568,16 +568,16 @@ orchestrator = Orchestrator( The [Swarm example](examples/workflows/workflow_swarm/main.py) shows this in action. ```python -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback lost_baggage = SwarmAgent( name="Lost baggage traversal", instruction=lambda context_variables: f""" { - FLY_AIR_AGENT_PROMPT.format( - customer_context=context_variables.get("customer_context", "None"), - flight_context=context_variables.get("flight_context", "None"), - ) + FLY_AIR_AGENT_PROMPT.format( + customer_context=context_variables.get("customer_context", "None"), + flight_context=context_variables.get("flight_context", "None"), + ) }\n Lost baggage policy: policies/lost_baggage_policy.md""", functions=[ escalate_to_agent, @@ -586,7 +586,7 @@ lost_baggage = SwarmAgent( case_resolved, ], server_names=["fetch", "filesystem"], - human_input_callback=console_input_callback, # Request input from the console + human_input_callback=console_input_callback, # Request input from the console ) ``` diff --git a/examples/human_input/temporal/README.md b/examples/human_input/temporal/README.md new file mode 100644 index 000000000..b2c7c89c1 --- /dev/null +++ b/examples/human_input/temporal/README.md @@ -0,0 +1,92 @@ +# Human interactions in Temporal + +This example demonstrates how to implement human interactions in an MCP running as a Temporal workflow. +Human input can be used for approvals or data entry. +In this case, we ask a human to provide their name, so we can create a personalised greeting. + +## Set up + +First, clone the repo and navigate to the human_input example: + +```bash +git clone https://github.com/lastmile-ai/mcp-agent.git +cd mcp-agent/examples/human_input/temporal +``` + +Install `uv` (if you don’t have it): + +```bash +pip install uv +``` + +## Set up api keys + +In `mcp_agent.secrets.yaml`, set your OpenAI `api_key`. + +## Setting Up Temporal Server + +Before running this example, you need to have a Temporal server running: + +1. Install the Temporal CLI by following the instructions at: https://docs.temporal.io/cli/ + +2. Start a local Temporal server: + ```bash + temporal server start-dev + ``` + +This will start a Temporal server on `localhost:7233` (the default address configured in `mcp_agent.config.yaml`). + +You can use the Temporal Web UI to monitor your workflows by visiting `http://localhost:8233` in your browser. + +## Run locally + +In three separate terminal windows, run the following: + +```bash +# this runs the mcp app +uv run main.py +``` + +```bash +# this runs the temporal worker that will execute the workflows +uv run worker.py +``` + +```bash +# this runs the client +uv run client.py +``` + +You will be prompted for input after the agent makes the initial tool call. + +## Details + +Notice how in `main.py` the `human_input_callback` is set to `elicitation_input_callback`. +This makes sure that human input is sought via elicitation. +In `client.py`, on the other hand, it is set to `console_elicitation_callback`. +This way, the client will prompt for input in the console whenever an upstream request for human input is made. + +The following diagram shows the components involved and the flow of requests and responses. + +```plaintext +┌──────────┐ +│ LLM │ +│ │ +└──────────┘ + ▲ + │ + 1 + │ + ▼ +┌──────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ +│ Temporal │───2──▶│ MCP App │◀──3──▶│ Client │◀──4──▶│ User │ +│ worker │◀──5───│ │ │ │ │ (via console)│ +└──────────┘ └──────────────┘ └──────────────┘ └──────────────┘ +``` + +In the diagram, +- (1) uses the tool calling mechanism to call a system-provided tool for human input, +- (2) uses a HTTPS request to tell the MCP App that the workflow wants to make a request, +- (3) uses the MCP protocol for sending the request to the client and receiving the response, +- (4) uses a console prompt to get the input from the user, and +- (5) uses a Temporal signal to send the response back to the workflow. diff --git a/examples/human_input/temporal/client.py b/examples/human_input/temporal/client.py new file mode 100644 index 000000000..8b75a9cce --- /dev/null +++ b/examples/human_input/temporal/client.py @@ -0,0 +1,199 @@ +import asyncio +import time +from mcp_agent.app import MCPApp +from mcp_agent.config import Settings, LoggerSettings, MCPSettings +import yaml +from mcp_agent.elicitation.handler import console_elicitation_callback +from mcp_agent.config import MCPServerSettings +from mcp_agent.core.context import Context +from mcp_agent.mcp.gen_client import gen_client +from datetime import timedelta +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import ClientSession +from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession +from mcp.types import CallToolResult, LoggingMessageNotificationParams +from mcp_agent.human_input.console_handler import console_input_callback +try: + from exceptiongroup import ExceptionGroup as _ExceptionGroup # Python 3.10 backport +except Exception: # pragma: no cover + _ExceptionGroup = None # type: ignore +try: + from anyio import BrokenResourceError as _BrokenResourceError +except Exception: # pragma: no cover + _BrokenResourceError = None # type: ignore + + +async def main(): + # Create MCPApp to get the server registry, with console handlers + # IMPORTANT: This client acts as the “upstream MCP client” for the server. + # When the server requests sampling (sampling/createMessage), the client-side + # MCPApp must be able to service that request locally (approval prompts + LLM call). + # Those client-local flows are not running inside a Temporal workflow, so they + # must use the asyncio executor. If this were set to "temporal", local sampling + # would crash with: "TemporalExecutor.execute must be called from within a workflow". + # + # We programmatically construct Settings here (mirroring examples/basic/mcp_basic_agent/main.py) + # so everything is self-contained in this client: + settings = Settings( + execution_engine="asyncio", + logger=LoggerSettings(level="info"), + mcp=MCPSettings( + servers={ + "basic_agent_server": MCPServerSettings( + name="basic_agent_server", + description="Local workflow server running the basic agent example", + transport="sse", + # Use a routable loopback host; 0.0.0.0 is a bind address, not a client URL + url="http://127.0.0.1:8000/sse", + ) + } + ), + ) + # Load secrets (API keys, etc.) if a secrets file is available and merge into settings. + # We intentionally deep-merge the secrets on top of our base settings so + # credentials are applied without overriding our executor or server endpoint. + try: + secrets_path = Settings.find_secrets() + if secrets_path and secrets_path.exists(): + with open(secrets_path, "r", encoding="utf-8") as f: + secrets_dict = yaml.safe_load(f) or {} + + def _deep_merge(base: dict, overlay: dict) -> dict: + out = dict(base) + for k, v in (overlay or {}).items(): + if k in out and isinstance(out[k], dict) and isinstance(v, dict): + out[k] = _deep_merge(out[k], v) + else: + out[k] = v + return out + + base_dict = settings.model_dump(mode="json") + merged = _deep_merge(base_dict, secrets_dict) + settings = Settings(**merged) + except Exception: + # Best-effort: continue without secrets if parsing fails + pass + app = MCPApp( + name="workflow_mcp_client", + # In the client, we want to use `console_input_callback` to enable direct interaction through the console + human_input_callback=console_input_callback, + elicitation_callback=console_elicitation_callback, + settings=settings, + ) + async with app.run() as client_app: + logger = client_app.logger + context = client_app.context + + # Connect to the workflow server + try: + logger.info("Connecting to workflow server...") + + # Server connection is configured via Settings above (no runtime mutation needed) + + # Connect to the workflow server + # Define a logging callback to receive server-side log notifications + async def on_server_log(params: LoggingMessageNotificationParams) -> None: + # Pretty-print server logs locally for demonstration + level = params.level.upper() + name = params.logger or "server" + # params.data can be any JSON-serializable data + print(f"[SERVER LOG] [{level}] [{name}] {params.data}") + + # Provide a client session factory that installs our logging callback + # and prints non-logging notifications to the console + class ConsolePrintingClientSession(MCPAgentClientSession): + async def _received_notification(self, notification): # type: ignore[override] + try: + method = getattr(notification.root, "method", None) + except Exception: + method = None + + # Avoid duplicating server log prints (handled by logging_callback) + if method and method != "notifications/message": + try: + data = notification.model_dump() + except Exception: + data = str(notification) + print(f"[SERVER NOTIFY] {method}: {data}") + + return await super()._received_notification(notification) + + def make_session( + read_stream: MemoryObjectReceiveStream, + write_stream: MemoryObjectSendStream, + read_timeout_seconds: timedelta | None, + context: Context | None = None, + ) -> ClientSession: + return ConsolePrintingClientSession( + read_stream=read_stream, + write_stream=write_stream, + read_timeout_seconds=read_timeout_seconds, + logging_callback=on_server_log, + context=context, + ) + + # Connect to the workflow server + async with gen_client( + "basic_agent_server", + context.server_registry, + client_session_factory=make_session, + ) as server: + # Ask server to send logs at the requested level (default info) + level = "info" + print(f"[client] Setting server logging level to: {level}") + try: + await server.set_logging_level(level) + except Exception: + # Older servers may not support logging capability + print("[client] Server does not support logging/setLevel") + + # Call the `greet` tool defined via `@app.tool` + run_result = await server.call_tool( + "greet", + arguments={} + ) + print(f"[client] Workflow run result: {run_result}") + except Exception as e: + # Tolerate benign shutdown races from SSE client (BrokenResourceError within ExceptionGroup) + if _ExceptionGroup is not None and isinstance(e, _ExceptionGroup): + subs = getattr(e, "exceptions", []) or [] + if ( + _BrokenResourceError is not None + and subs + and all(isinstance(se, _BrokenResourceError) for se in subs) + ): + logger.debug("Ignored BrokenResourceError from SSE shutdown") + else: + raise + elif _BrokenResourceError is not None and isinstance( + e, _BrokenResourceError + ): + logger.debug("Ignored BrokenResourceError from SSE shutdown") + elif "BrokenResourceError" in str(e): + logger.debug( + "Ignored BrokenResourceError from SSE shutdown (string match)" + ) + else: + raise + + +def _tool_result_to_json(tool_result: CallToolResult): + if tool_result.content and len(tool_result.content) > 0: + text = tool_result.content[0].text + try: + # Try to parse the response as JSON if it's a string + import json + + return json.loads(text) + except (json.JSONDecodeError, TypeError): + # If it's not valid JSON, just use the text + return None + + +if __name__ == "__main__": + start = time.time() + asyncio.run(main()) + end = time.time() + t = end - start + + print(f"Total run time: {t:.2f}s") diff --git a/examples/human_input/temporal/main.py b/examples/human_input/temporal/main.py new file mode 100644 index 000000000..e2858b17d --- /dev/null +++ b/examples/human_input/temporal/main.py @@ -0,0 +1,83 @@ +""" +Example demonstrating how to use the elicitation-based human input handler +for Temporal workflows. + +This example shows how the new handler enables LLMs to request user input +when running in Temporal workflows by routing requests through the MCP +elicitation framework instead of direct console I/O. +""" +import asyncio +from mcp_agent.app import MCPApp +from mcp_agent.human_input.elicitation_handler import elicitation_input_callback + +from mcp_agent.agents.agent import Agent +from mcp_agent.core.context import Context +from mcp_agent.server.app_server import create_mcp_server_for_app +from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM + + +# Create a single FastMCPApp instance (which extends MCPApp) +# We don't need to explicitly create a tool for human interaction; providing the human_input_callback will +# automatically create a tool for the agent to use. +app = MCPApp( + name="basic_agent_server", + description="Basic agent server example", + human_input_callback=elicitation_input_callback, # Use elicitation handler for human input in temporal workflows +) + + +@app.tool +async def greet(app_ctx: Context | None = None) -> str: + """ + Run the basic agent workflow using the app.tool decorator to set up the workflow. + The code in this function is run in workflow context. + LLM calls are executed in the activity context. + You can use the app_ctx to access the executor to run activities explicitly. + Functions decorated with @app.workflow_task will be run in activity context. + + Args: + input: none + + Returns: + str: The greeting result from the agent + """ + + app = app_ctx.app + + logger = app.logger + logger.info("[workflow-mode] Running greet_tool") + + greeting_agent = Agent( + name="greeter", + instruction="""You are a friendly assistant.""", + server_names=[], + ) + + async with greeting_agent: + finder_llm = await greeting_agent.attach_llm(OpenAIAugmentedLLM) + + result = await finder_llm.generate_str( + message="Ask the user for their name and greet them.", + ) + logger.info("[workflow-mode] greet_tool agent result", data={"result": result}) + + return result + + +async def main(): + async with app.run() as agent_app: + # Log registered workflows and agent configurations + agent_app.logger.info(f"Creating MCP server for {agent_app.name}") + + agent_app.logger.info("Registered workflows:") + for workflow_id in agent_app.workflows: + agent_app.logger.info(f" - {workflow_id}") + # Create the MCP server that exposes both workflows and agent configurations + mcp_server = create_mcp_server_for_app(agent_app) + + # Run the server + await mcp_server.run_sse_async() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/human_input/temporal/mcp_agent.config.yaml b/examples/human_input/temporal/mcp_agent.config.yaml new file mode 100644 index 000000000..186222535 --- /dev/null +++ b/examples/human_input/temporal/mcp_agent.config.yaml @@ -0,0 +1,22 @@ +$schema: ../../../../schema/mcp-agent.config.schema.json + +execution_engine: temporal + +temporal: + host: "localhost:7233" # Default Temporal server address + namespace: "default" # Default Temporal namespace + task_queue: "mcp-agent" # Task queue for workflows and activities + max_concurrent_activities: 10 # Maximum number of concurrent activities + +logger: + transports: [file] + level: debug + path_settings: + path_pattern: "logs/mcp-agent-{unique_id}.jsonl" + unique_id: "timestamp" # Options: "timestamp" or "session_id" + timestamp_format: "%Y%m%d_%H%M%S" + +openai: + # Secrets (API keys, etc.) are stored in an mcp_agent.secrets.yaml file which can be gitignored + # default_model: "o3-mini" + default_model: "gpt-4o-mini" diff --git a/examples/human_input/temporal/mcp_agent.secrets.yaml.example b/examples/human_input/temporal/mcp_agent.secrets.yaml.example new file mode 100644 index 000000000..930cf3648 --- /dev/null +++ b/examples/human_input/temporal/mcp_agent.secrets.yaml.example @@ -0,0 +1,7 @@ +$schema: ../../../../schema/mcp-agent.config.schema.json + +openai: + api_key: openai_api_key + +anthropic: + api_key: anthropic_api_key diff --git a/examples/human_input/temporal/requirements.txt b/examples/human_input/temporal/requirements.txt new file mode 100644 index 000000000..5f239ce9d --- /dev/null +++ b/examples/human_input/temporal/requirements.txt @@ -0,0 +1,7 @@ +# Core framework dependency +mcp-agent + +# Additional dependencies specific to this example +anthropic +openai +temporalio diff --git a/examples/human_input/temporal/worker.py b/examples/human_input/temporal/worker.py new file mode 100644 index 000000000..39b2a3c67 --- /dev/null +++ b/examples/human_input/temporal/worker.py @@ -0,0 +1,31 @@ +""" +Worker script for the Temporal workflow example. +This script starts a Temporal worker that can execute workflows and activities. +Run this script in a separate terminal window before running the main.py script. + +This leverages the TemporalExecutor's start_worker method to handle the worker setup. +""" + +import asyncio +import logging + + +from mcp_agent.executor.temporal import create_temporal_worker_for_app + +from main import app + +# Initialize logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def main(): + """ + Start a Temporal worker for the example workflows using the app's executor. + """ + async with create_temporal_worker_for_app(app) as worker: + await worker.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/mcp/mcp_elicitation/main.py b/examples/mcp/mcp_elicitation/main.py index 1dcce620b..2d091501f 100644 --- a/examples/mcp/mcp_elicitation/main.py +++ b/examples/mcp/mcp_elicitation/main.py @@ -3,7 +3,7 @@ from mcp_agent.app import MCPApp from mcp_agent.agents.agent import Agent -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback from mcp_agent.elicitation.handler import console_elicitation_callback from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM diff --git a/examples/mcp/mcp_elicitation/temporal/client.py b/examples/mcp/mcp_elicitation/temporal/client.py index b6c4d114c..6a26e6c74 100644 --- a/examples/mcp/mcp_elicitation/temporal/client.py +++ b/examples/mcp/mcp_elicitation/temporal/client.py @@ -14,7 +14,8 @@ from mcp import ClientSession from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession from mcp.types import CallToolResult, LoggingMessageNotificationParams -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback + try: from exceptiongroup import ExceptionGroup as _ExceptionGroup # Python 3.10 backport except Exception: # pragma: no cover diff --git a/examples/mcp_agent_server/asyncio/client.py b/examples/mcp_agent_server/asyncio/client.py index 6c229098c..3bdee2a95 100644 --- a/examples/mcp_agent_server/asyncio/client.py +++ b/examples/mcp_agent_server/asyncio/client.py @@ -12,7 +12,7 @@ from mcp_agent.executor.workflow import WorkflowExecution from mcp_agent.mcp.gen_client import gen_client from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback from mcp_agent.elicitation.handler import console_elicitation_callback from rich import print diff --git a/examples/mcp_agent_server/asyncio/main.py b/examples/mcp_agent_server/asyncio/main.py index 5542851fe..68f7574bc 100644 --- a/examples/mcp_agent_server/asyncio/main.py +++ b/examples/mcp_agent_server/asyncio/main.py @@ -25,7 +25,7 @@ from mcp_agent.workflows.parallel.parallel_llm import ParallelLLM from mcp_agent.executor.workflow import Workflow, WorkflowResult from mcp_agent.tracing.token_counter import TokenNode -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback from mcp_agent.elicitation.handler import console_elicitation_callback from mcp_agent.mcp.gen_client import gen_client from mcp_agent.config import MCPServerSettings diff --git a/examples/mcp_agent_server/temporal/main.py b/examples/mcp_agent_server/temporal/main.py index 538ecfb8f..95ff95d5d 100644 --- a/examples/mcp_agent_server/temporal/main.py +++ b/examples/mcp_agent_server/temporal/main.py @@ -21,7 +21,7 @@ from mcp_agent.core.context import Context from mcp_agent.elicitation.handler import console_elicitation_callback from mcp_agent.executor.workflow import Workflow, WorkflowResult -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback from mcp_agent.mcp.gen_client import gen_client from mcp_agent.server.app_server import create_mcp_server_for_app from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM diff --git a/examples/usecases/mcp_realtor_agent/main.py b/examples/usecases/mcp_realtor_agent/main.py index 744618fb7..d8d09681f 100644 --- a/examples/usecases/mcp_realtor_agent/main.py +++ b/examples/usecases/mcp_realtor_agent/main.py @@ -12,7 +12,7 @@ from datetime import datetime from mcp_agent.app import MCPApp from mcp_agent.agents.agent import Agent -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback from mcp_agent.elicitation.handler import console_elicitation_callback from mcp_agent.workflows.orchestrator.orchestrator import Orchestrator from mcp_agent.workflows.llm.augmented_llm import RequestParams diff --git a/examples/workflows/workflow_swarm/main.py b/examples/workflows/workflow_swarm/main.py index d6c090b31..46813d449 100644 --- a/examples/workflows/workflow_swarm/main.py +++ b/examples/workflows/workflow_swarm/main.py @@ -5,7 +5,7 @@ from mcp_agent.app import MCPApp from mcp_agent.workflows.swarm.swarm import DoneAgent, SwarmAgent from mcp_agent.workflows.swarm.swarm_anthropic import AnthropicSwarm -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback app = MCPApp( name="airline_customer_service", human_input_callback=console_input_callback diff --git a/src/mcp_agent/data/examples/mcp_agent_server/elicitation/client.py b/src/mcp_agent/data/examples/mcp_agent_server/elicitation/client.py index 5bad8e3b2..bbba1167f 100644 --- a/src/mcp_agent/data/examples/mcp_agent_server/elicitation/client.py +++ b/src/mcp_agent/data/examples/mcp_agent_server/elicitation/client.py @@ -16,7 +16,7 @@ from mcp_agent.config import Settings from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession from mcp_agent.mcp.gen_client import gen_client -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback from mcp_agent.elicitation.handler import console_elicitation_callback from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession diff --git a/src/mcp_agent/data/examples/mcp_agent_server/elicitation/server.py b/src/mcp_agent/data/examples/mcp_agent_server/elicitation/server.py index 57fbcf83a..2243764ab 100644 --- a/src/mcp_agent/data/examples/mcp_agent_server/elicitation/server.py +++ b/src/mcp_agent/data/examples/mcp_agent_server/elicitation/server.py @@ -15,7 +15,7 @@ from mcp_agent.app import MCPApp from mcp_agent.core.context import Context as AppContext from mcp_agent.server.app_server import create_mcp_server_for_app -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback from mcp_agent.elicitation.handler import console_elicitation_callback from mcp.types import ElicitRequestedSchema from pydantic import BaseModel, Field diff --git a/src/mcp_agent/data/examples/mcp_agent_server/reference/client.py b/src/mcp_agent/data/examples/mcp_agent_server/reference/client.py index 9ed0747fb..da3f639f2 100644 --- a/src/mcp_agent/data/examples/mcp_agent_server/reference/client.py +++ b/src/mcp_agent/data/examples/mcp_agent_server/reference/client.py @@ -19,7 +19,7 @@ from mcp_agent.core.context import Context from mcp_agent.config import Settings from mcp_agent.mcp.gen_client import gen_client -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback from mcp_agent.elicitation.handler import console_elicitation_callback from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession diff --git a/src/mcp_agent/data/examples/mcp_agent_server/reference/server.py b/src/mcp_agent/data/examples/mcp_agent_server/reference/server.py index 447f8d335..a2aeb8447 100644 --- a/src/mcp_agent/data/examples/mcp_agent_server/reference/server.py +++ b/src/mcp_agent/data/examples/mcp_agent_server/reference/server.py @@ -25,7 +25,7 @@ from mcp_agent.app import MCPApp from mcp_agent.core.context import Context as AppContext from mcp_agent.server.app_server import create_mcp_server_for_app -from mcp_agent.human_input.handler import console_input_callback +from mcp_agent.human_input.console_handler import console_input_callback from mcp_agent.elicitation.handler import console_elicitation_callback from mcp_agent.agents.agent import Agent diff --git a/src/mcp_agent/human_input/handler.py b/src/mcp_agent/human_input/console_handler.py similarity index 100% rename from src/mcp_agent/human_input/handler.py rename to src/mcp_agent/human_input/console_handler.py diff --git a/src/mcp_agent/human_input/elicitation_handler.py b/src/mcp_agent/human_input/elicitation_handler.py new file mode 100644 index 000000000..1aea42360 --- /dev/null +++ b/src/mcp_agent/human_input/elicitation_handler.py @@ -0,0 +1,123 @@ +import asyncio + +import mcp.types as types +from mcp_agent.human_input.types import HumanInputRequest, HumanInputResponse +from mcp_agent.logging.logger import get_logger + +logger = get_logger(__name__) + + +def _create_elicitation_message(request: HumanInputRequest) -> str: + """Convert HumanInputRequest to elicitation message format.""" + message = request.prompt + if request.description: + message = f"{request.description}\n\n{message}" + + return message + + +def _handle_elicitation_response( + result: types.ElicitResult, + request: HumanInputRequest +) -> HumanInputResponse: + """Convert ElicitResult back to HumanInputResponse.""" + request_id = request.request_id or "" + + # Handle different action types + if result.action == "accept": + if result.content and isinstance(result.content, dict): + response_text = result.content.get("response", "") + + # Handle slash commands that might be in the response + response_text = response_text.strip() + if response_text.lower() in ["/decline", "/cancel"]: + return HumanInputResponse(request_id=request_id, response=response_text.lower()) + + return HumanInputResponse(request_id=request_id, response=response_text) + else: + # Fallback if content is not in expected format + return HumanInputResponse(request_id=request_id, response="") + + elif result.action == "decline": + return HumanInputResponse(request_id=request_id, response="decline") + + elif result.action == "cancel": + return HumanInputResponse(request_id=request_id, response="cancel") + + else: + # Unknown action, treat as cancel + logger.warning(f"Unknown elicitation action: {result.action}") + return HumanInputResponse(request_id=request_id, response="cancel") + + +async def elicitation_input_callback(request: HumanInputRequest) -> HumanInputResponse: + """ + Handle human input requests using MCP elicitation. + """ + + # Try to get the context and session proxy + try: + from mcp_agent.core.context import get_current_context + context = get_current_context() + if context is None: + raise RuntimeError("No context available for elicitation") + except Exception: + raise RuntimeError("No context available for elicitation") + + upstream_session = context.upstream_session + + if not upstream_session: + raise RuntimeError("Session required for elicitation") + + try: + message = _create_elicitation_message(request) + + logger.debug( + "Sending elicitation request for human input", + data={ + "request_id": request.request_id, + "description": request.description, + "timeout_seconds": request.timeout_seconds + } + ) + + # Send the elicitation request + result = await upstream_session.elicit( + message=message, + requestedSchema={ + "type": "object", + "properties": { + "response": { + "type": "string", + "description": "The response or input" + } + }, + "required": ["response"] + }, + related_request_id=request.request_id + ) + + # Convert the result back to HumanInputResponse + response = _handle_elicitation_response(result, request) + + logger.debug( + "Received elicitation response for human input", + data={ + "request_id": request.request_id, + "action": result.action, + "response_length": len(response.response) + } + ) + + return response + + except asyncio.TimeoutError: + logger.warning(f"Elicitation timeout for request {request.request_id}") + raise TimeoutError("No response received within timeout period") from None + + except Exception as e: + logger.error( + f"Elicitation failed for human input request {request.request_id}", + data={"error": str(e)} + ) + raise RuntimeError(f"Elicitation failed: {e}") from e diff --git a/tests/human_input/test_elicitation_handler.py b/tests/human_input/test_elicitation_handler.py new file mode 100644 index 000000000..d9686d29b --- /dev/null +++ b/tests/human_input/test_elicitation_handler.py @@ -0,0 +1,159 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock +import mcp.types as types +from mcp_agent.executor.temporal.session_proxy import SessionProxy +from mcp_agent.human_input.types import HumanInputRequest, HumanInputResponse +from mcp_agent.human_input.elicitation_handler import ( + elicitation_input_callback, + _create_elicitation_message, + _handle_elicitation_response +) + + +class TestElicitationHandler: + """Test the elicitation-based human input handler.""" + + def test_create_elicitation_message_basic(self): + """Test basic message creation.""" + request = HumanInputRequest(prompt="Please enter your name") + message = _create_elicitation_message(request) + + assert "Please enter your name" in message + + def test_create_elicitation_message_with_description(self): + """Test message creation with description.""" + request = HumanInputRequest( + prompt="Enter your name", + description="We need your name for the booking" + ) + message = _create_elicitation_message(request) + + assert "We need your name for the booking" in message + assert "Enter your name" in message + + def test_create_elicitation_message_with_timeout(self): + """Test message creation with timeout.""" + request = HumanInputRequest( + prompt="Enter your name", + timeout_seconds=30 + ) + message = _create_elicitation_message(request) + + assert "Enter your name" in message + assert "Timeout" not in message + assert "30" not in message + + def test_handle_elicitation_response_accept(self): + """Test handling accept response.""" + request = HumanInputRequest(prompt="Test", request_id="test-123") + result = types.ElicitResult( + action="accept", + content={"response": "John Doe"} + ) + + response = _handle_elicitation_response(result, request) + + assert isinstance(response, HumanInputResponse) + assert response.request_id == "test-123" + assert response.response == "John Doe" + + def test_handle_elicitation_response_decline(self): + """Test handling decline response.""" + request = HumanInputRequest(prompt="Test", request_id="test-123") + result = types.ElicitResult(action="decline") + + response = _handle_elicitation_response(result, request) + + assert response.request_id == "test-123" + assert response.response == "decline" + + def test_handle_elicitation_response_cancel(self): + """Test handling cancel response.""" + request = HumanInputRequest(prompt="Test", request_id="test-123") + result = types.ElicitResult(action="cancel") + + response = _handle_elicitation_response(result, request) + + assert response.request_id == "test-123" + assert response.response == "cancel" + + + @pytest.mark.asyncio + async def test_elicitation_input_callback_success(self): + """Test successful elicitation callback.""" + # Mock the context and session proxy + mock_context = MagicMock() + mock_session = AsyncMock(spec=SessionProxy) + + # Mock the elicit method to return a successful response + mock_session.elicit.return_value = types.ElicitResult( + action="accept", + content={"response": "Test response"} + ) + + mock_context.upstream_session = mock_session + + # Mock get_current_context() to return our mock context + with pytest.MonkeyPatch.context() as m: + m.setattr("mcp_agent.core.context.get_current_context", lambda: mock_context) + + request = HumanInputRequest( + prompt="Please enter something", + request_id="test-123" + ) + + response = await elicitation_input_callback(request) + + assert isinstance(response, HumanInputResponse) + assert response.request_id == "test-123" + assert response.response == "Test response" + + # Verify the session proxy was called correctly + mock_session.elicit.assert_called_once() + call_args = mock_session.elicit.call_args + assert "Please enter something" in call_args.kwargs["message"] + assert call_args.kwargs["related_request_id"] == "test-123" + + @pytest.mark.asyncio + async def test_elicitation_input_callback_no_context(self): + """Test callback when no context is available.""" + with pytest.MonkeyPatch.context() as m: + m.setattr("mcp_agent.core.context.get_current_context", lambda: None) + + request = HumanInputRequest(prompt="Test") + + with pytest.raises(RuntimeError, match="No context available"): + await elicitation_input_callback(request) + + @pytest.mark.asyncio + async def test_elicitation_input_callback_no_session(self): + """Test callback when SessionProxy is not available.""" + mock_context = MagicMock() + mock_context.upstream_session = None + + with pytest.MonkeyPatch.context() as m: + m.setattr("mcp_agent.core.context.get_current_context", lambda: mock_context) + + request = HumanInputRequest(prompt="Test") + + with pytest.raises(RuntimeError, match="Session required for elicitation"): + await elicitation_input_callback(request) + + @pytest.mark.asyncio + async def test_elicitation_input_callback_elicit_failure(self): + """Test callback when elicitation fails.""" + mock_context = MagicMock() + mock_session = AsyncMock(spec=SessionProxy) + + # Mock the elicit method to raise an exception + mock_session.elicit.side_effect = Exception("Elicitation failed") + + mock_context.upstream_session = mock_session + + with pytest.MonkeyPatch.context() as m: + m.setattr("mcp_agent.core.context.get_current_context", lambda: mock_context) + + request = HumanInputRequest(prompt="Test") + + with pytest.raises(RuntimeError, match="Elicitation failed"): + await elicitation_input_callback(request)