diff --git a/examples/mcp/mcp_prompts_and_resources/README.md b/examples/mcp/mcp_prompts_and_resources/README.md index 64befc6f4..94d0ba478 100644 --- a/examples/mcp/mcp_prompts_and_resources/README.md +++ b/examples/mcp/mcp_prompts_and_resources/README.md @@ -21,9 +21,7 @@ This example demonstrates **both resources and prompts**. - **Resources:** - `demo://docs/readme`: A sample README file (Markdown) - - `demo://config/settings`: Example configuration settings (JSON) - - `demo://data/users`: Example user data (JSON) - - `demo://status/health`: Dynamic server health/status info (JSON) + - `demo://data/friends`: Example user data (JSON) - **Prompt:** - `echo`: A simple prompt that echoes back the provided message diff --git a/examples/mcp/mcp_prompts_and_resources/demo_server.py b/examples/mcp/mcp_prompts_and_resources/demo_server.py index abef5f4cc..6ff086737 100644 --- a/examples/mcp/mcp_prompts_and_resources/demo_server.py +++ b/examples/mcp/mcp_prompts_and_resources/demo_server.py @@ -1,54 +1,14 @@ from mcp.server.fastmcp import FastMCP -import datetime +from mcp.types import ModelPreferences, ModelHint, SamplingMessage, TextContent import json -# Store server start time -SERVER_START_TIME = datetime.datetime.utcnow() - mcp = FastMCP("Resource Demo MCP Server") -# Define some static resources -STATIC_RESOURCES = { - "demo://docs/readme": { - "name": "README", - "description": "A sample README file.", - "content_type": "text/markdown", - "content": "# Demo Resource Server\n\nThis is a sample README resource provided by the demo MCP server.", - }, - "demo://data/users": { - "name": "User Data", - "description": "Sample user data in JSON format.", - "content_type": "application/json", - "content": json.dumps( - [ - {"id": 1, "name": "Alice"}, - {"id": 2, "name": "Bob"}, - {"id": 3, "name": "Charlie"}, - ], - indent=2, - ), - }, -} - @mcp.resource("demo://docs/readme") def get_readme(): """Provide the README file content.""" - meta = STATIC_RESOURCES["demo://docs/readme"] - return meta["content"] - - -@mcp.resource("demo://data/users") -def get_users(): - """Provide user data.""" - meta = STATIC_RESOURCES["demo://data/users"] - return meta["content"] - - -@mcp.resource("demo://{city}/weather") -def get_weather(city: str) -> str: - """Provide a simple weather report for a given city.""" - return f"It is sunny in {city} today!" + return "# Demo Resource Server\n\nThis is a sample README resource provided by the demo MCP server." @mcp.prompt() @@ -60,6 +20,53 @@ def echo(message: str) -> str: return f"Prompt: {message}" +@mcp.resource("demo://data/friends") +def get_users(): + """Provide my friend list.""" + return ( + json.dumps( + [ + {"id": 1, "friend": "Alice"}, + ] + ) + ) + + +@mcp.prompt() +def get_haiku_prompt(topic: str) -> str: + """Get a haiku prompt about a given topic.""" + return f"I am fascinated about {topic}. Can you generate a haiku combining {topic} + my friend name?" + + +@mcp.tool() +async def get_haiku(topic: str) -> str: + """Get a haiku about a given topic.""" + haiku = await mcp.get_context().session.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent( + type="text", text=f"Generate a haiku about {topic}." 
+ ), + ) + ], + system_prompt="You are a poet.", + max_tokens=100, + temperature=0.7, + model_preferences=ModelPreferences( + hints=[ModelHint(name="gpt-4o-mini")], + costPriority=0.1, + speedPriority=0.8, + intelligencePriority=0.1, + ), + ) + + if isinstance(haiku.content, TextContent): + return haiku.content.text + else: + return "Haiku generation failed, unexpected content type." + + def main(): """Main entry point for the MCP server.""" mcp.run() diff --git a/examples/mcp/mcp_prompts_and_resources/main.py b/examples/mcp/mcp_prompts_and_resources/main.py index 013e1bdf1..5bd067a61 100644 --- a/examples/mcp/mcp_prompts_and_resources/main.py +++ b/examples/mcp/mcp_prompts_and_resources/main.py @@ -11,6 +11,8 @@ ) from mcp_agent.agents.agent import Agent from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM +from mcp_agent.human_input.handler import console_input_callback + settings = Settings( execution_engine="asyncio", @@ -30,7 +32,9 @@ # Settings can either be specified programmatically, # or loaded from mcp_agent.config.yaml/mcp_agent.secrets.yaml -app = MCPApp(name="mcp_basic_agent") # settings=settings) +app = MCPApp( + name="mcp_basic_agent", human_input_callback=console_input_callback +) # settings=settings) async def example_usage(): @@ -71,13 +75,16 @@ async def example_usage(): ) llm = await agent.attach_llm(OpenAIAugmentedLLM) - res = await llm.generate_str( + summary = await llm.generate_str( [ "Summarise what are my prompts and resources?", *combined_messages, ] ) - logger.info(f"Summary: {res}") + logger.info(f"Summary: {summary}") + + haiku = await llm.generate_str("Write me a haiku") + logger.info(f"Haiku: {haiku}") if __name__ == "__main__": diff --git a/examples/mcp/mcp_sampling/README.md b/examples/mcp/mcp_sampling/README.md new file mode 100644 index 000000000..c47c1cb88 --- /dev/null +++ b/examples/mcp/mcp_sampling/README.md @@ -0,0 +1,101 @@ +# MCP Sampling Example + +This example demonstrates how to use **MCP sampling** in an agent application. +It shows how to connect to an MCP server that exposes a tool that uses a sampling request to generate a response. + +--- + +## What is MCP sampling? +Sampling in MCP allows servers to implement agentic behaviors, by enabling LLM calls to occur nested inside other MCP server features. +Following the MCP recommendations, users are prompted to approve sampling requests, as well as the output produced by the LLM for the sampling request. +More details can be found in the [MCP documentation](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling). + +This example demonstrates sampling using [MCP agent servers](https://github.com/lastmile-ai/mcp-agent/blob/main/examples/mcp_agent_server/README.md). +It is also possible to use sampling when explicitly creating an MCP client. The code for that would look like the following: + +```python +settings = ... # MCP agent configuration +registry = ServerRegistry(settings) + +@mcp.tool() +async def my_tool(input: str, ctx: Context) -> str: + async with gen_client("my_server", registry, upstream_session=ctx.session) as my_client: + result = await my_client.call_tool("some_tool", {"input": input}) + ... # etc +``` + +--- + +## Example Overview + +- **nested_server.py** implements a simple MCP server that uses sampling to generate a haiku about a given topic +- **demo_server.py** implements a simple MCP server that implements an agent generating haikus using the tool exposed by `nested_server.py` +- **main.py** shows how to: + 1. 
Connect an agent to the demo MCP server, and then + 2. Invoke the agent implemented by the demo MCP server, thereby triggering a sampling request. + +--- + +## Architecture + +```plaintext +┌────────────────────┐ +│ nested_server │──────┐ +│ MCP Server │ │ +└─────────┬──────────┘ │ + │ │ + ▼ │ +┌────────────────────┐ │ +│ demo_server │ │ +│ MCP Server │ │ +└─────────┬──────────┘ │ + │ sampling, via user approval + ▼ │ +┌────────────────────┐ │ +│ Agent (Python) │ │ +│ + LLM (OpenAI) │◀─────┘ +└─────────┬──────────┘ + │ + ▼ + [User/Developer] +``` + +--- + +## 1. Setup + +Clone the repo and navigate to this example: + +```bash +git clone https://github.com/lastmile-ai/mcp-agent.git +cd mcp-agent/examples/mcp/mcp_sampling +``` + +--- + +## 2. Run the Agent Example + +Run the agent script which should auto install all necessary dependencies: + +```bash +uv run main.py +``` + +You should see logs showing: + +- The agent connecting to the demo server, and calling the tool +- A request to approve the sampling request; type `approve` to approve (anything else will deny the request) +- A request to approve the result of the sampling request +- The final result of the tool call + +--- + +## References + +- [Model Context Protocol (MCP) Introduction](https://modelcontextprotocol.io/introduction) +- [MCP Agent Framework](https://github.com/lastmile-ai/mcp-agent) +- [MCP Server Sampling](https://modelcontextprotocol.io/specification/2025-06-18/client/sampling) + +--- + +This example is a minimal, practical demonstration of how to use **MCP sampling** as first-class context for agent applications. diff --git a/examples/mcp/mcp_sampling/asyncio/demo_server.py b/examples/mcp/mcp_sampling/asyncio/demo_server.py new file mode 100644 index 000000000..b4605d52c --- /dev/null +++ b/examples/mcp/mcp_sampling/asyncio/demo_server.py @@ -0,0 +1,120 @@ +""" +A simple workflow server which generates haikus on request using a tool. 
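+
+How the pieces fit together: main.py launches this script over stdio as the
+"demo_server" entry in mcp_agent.config.yaml. The HaikuWorkflow below runs an
+agent connected to nested_server.py, whose get_haiku tool issues the actual
+MCP sampling request; that request is relayed back to main.py, where the user
+is asked to approve it before an LLM generates the haiku.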
+""" + +import asyncio +import logging + +import yaml +from mcp.server.fastmcp import FastMCP + +from mcp_agent.app import MCPApp +from mcp_agent.config import Settings, LoggerSettings, MCPSettings, MCPServerSettings, LogPathSettings +from mcp_agent.server.app_server import create_mcp_server_for_app +from mcp_agent.agents.agent import Agent +from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM +from mcp_agent.executor.workflow import Workflow, WorkflowResult + +# Initialize logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Note: This is purely optional: +# if not provided, a default FastMCP server will be created by MCPApp using create_mcp_server_for_app() +mcp = FastMCP(name="haiku_generation_server", description="Server to generate haikus") + + +# Create settings explicitly, as we want to use a different configuration from the main app +secrets_file = Settings.find_secrets() +if secrets_file and secrets_file.exists(): + with open(secrets_file, "r", encoding="utf-8") as f: + yaml_secrets = yaml.safe_load(f) or {} + openai_secret = yaml_secrets["openai"] + + +settings = Settings( + execution_engine="asyncio", + logger=LoggerSettings( + type="file", + level="debug", + path_settings=LogPathSettings( + path_pattern="logs/demo_server-{unique_id}.jsonl", + unique_id="timestamp", + timestamp_format="%Y%m%d_%H%M%S"), + ), + mcp=MCPSettings( + servers={ + "haiku_server": MCPServerSettings( + command="uv", + args=["run", "nested_server.py"], + description="nested server providing a haiku generator" + ) + } + ), + openai=openai_secret +) + +# Define the MCPApp instance +app = MCPApp( + name="haiku_server", + description="Haiku server", + mcp=mcp, + settings=settings +) + +@app.workflow +class HaikuWorkflow(Workflow[str]): + """ + A workflow that generates haikus on request. + """ + + @app.workflow_run + async def run(self, input: str) -> WorkflowResult[str]: + """ + Run the haiku agent workflow. + + Args: + input: The topic to create a haiku about + + Returns: + WorkflowResult containing the processed data. 
+ """ + + logger = app.logger + + haiku_agent = Agent( + name="poet", + instruction="""You are an agent with access to a tool that helps you write haikus.""", + server_names=["haiku_server"], + ) + + async with haiku_agent: + llm = await haiku_agent.attach_llm(OpenAIAugmentedLLM) + + result = await llm.generate_str( + message=f"Write a haiku about {input} using the tool at your disposal", + ) + logger.info(f"Input: {input}, Result: {result}") + + return WorkflowResult(value=result) + + +async def main(): + async with app.run() as agent_app: + # Log registered workflows and agent configurations + logger.info(f"Creating MCP server for {agent_app.name}") + + logger.info("Registered workflows:") + for workflow_id in agent_app.workflows: + logger.info(f" - {workflow_id}") + + # Create the MCP server that exposes both workflows and agent configurations + mcp_server = create_mcp_server_for_app(agent_app, **({})) + logger.info(f"MCP Server settings: {mcp_server.settings}") + + # Run the server + await mcp_server.run_stdio_async() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/mcp/mcp_sampling/asyncio/main.py b/examples/mcp/mcp_sampling/asyncio/main.py new file mode 100644 index 000000000..60fbfb014 --- /dev/null +++ b/examples/mcp/mcp_sampling/asyncio/main.py @@ -0,0 +1,45 @@ +import asyncio +import time + +from mcp_agent.app import MCPApp +from mcp_agent.agents.agent import Agent +from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM +from mcp_agent.human_input.handler import console_input_callback + +# Settings can either be specified programmatically, +# or loaded from mcp_agent.config.yaml/mcp_agent.secrets.yaml +app = MCPApp( + name="mcp_basic_agent", human_input_callback=console_input_callback +) + + +async def example_usage(): + async with app.run() as agent_app: + logger = agent_app.logger + + # --- Example: Using the demo_server MCP server --- + agent = Agent( + name="agent", + instruction="Demo agent for MCP sampling", + server_names=["demo_server"], + ) + + async with agent: + llm = await agent.attach_llm(OpenAIAugmentedLLM) + + # using the MCP server with sampling + haiku = await llm.generate_str("Write me a haiku about flowers") + logger.info(f"Generated haiku: {haiku}") + + # not using sampling + definition = await llm.generate_str("What does the acronym MCP stand for in the context of generative AI?") + logger.info(f"{definition}") + + +if __name__ == "__main__": + start = time.time() + asyncio.run(example_usage()) + end = time.time() + t = end - start + + print(f"Total run time: {t:.2f}s") diff --git a/examples/mcp/mcp_sampling/asyncio/mcp_agent.config.yaml b/examples/mcp/mcp_sampling/asyncio/mcp_agent.config.yaml new file mode 100644 index 000000000..7d7278c52 --- /dev/null +++ b/examples/mcp/mcp_sampling/asyncio/mcp_agent.config.yaml @@ -0,0 +1,24 @@ +$schema: ../../../schema/mcp-agent.config.schema.json + +execution_engine: asyncio + +logger: + transports: [console, file] + level: debug + progress_display: true + path_settings: + path_pattern: "logs/mcp-agent-{unique_id}.jsonl" + unique_id: "timestamp" # Options: "timestamp" or "session_id" + timestamp_format: "%Y%m%d_%H%M%S" + +mcp: + servers: + demo_server: + command: "uv" + args: ["run", "demo_server.py"] + description: "Demo MCP server for resources and prompts" + +openai: + # Secrets (API keys, etc.) 
are stored in an mcp_agent.secrets.yaml file which can be gitignored + # default_model: "o3-mini" + default_model: "gpt-4o-mini" diff --git a/examples/mcp/mcp_sampling/asyncio/nested_server.py b/examples/mcp/mcp_sampling/asyncio/nested_server.py new file mode 100644 index 000000000..80cead647 --- /dev/null +++ b/examples/mcp/mcp_sampling/asyncio/nested_server.py @@ -0,0 +1,43 @@ +from mcp.server.fastmcp import FastMCP +from mcp.types import ModelPreferences, ModelHint, SamplingMessage, TextContent + +mcp = FastMCP("Haiku demo server") + + +@mcp.tool() +async def get_haiku(topic: str) -> str: + """Use sampling to generate a haiku about the given topic.""" + + haiku = await mcp.get_context().session.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent( + type="text", text=f"Generate a quirky haiku about {topic}." + ), + ) + ], + system_prompt="You are a poet.", + max_tokens=100, + temperature=0.7, + model_preferences=ModelPreferences( + hints=[ModelHint(name="gpt-4o-mini")], + costPriority=0.1, + speedPriority=0.8, + intelligencePriority=0.1, + ), + ) + + if isinstance(haiku.content, TextContent): + return haiku.content.text + else: + return "Haiku generation failed, unexpected content type." + + +def main(): + """Main entry point for the MCP server.""" + mcp.run() + + +if __name__ == "__main__": + main() diff --git a/examples/mcp/mcp_sampling/requirements.txt b/examples/mcp/mcp_sampling/requirements.txt new file mode 100644 index 000000000..133b9c325 --- /dev/null +++ b/examples/mcp/mcp_sampling/requirements.txt @@ -0,0 +1,7 @@ +# Core framework dependency +mcp-agent @ file://../../../ # Link to the local mcp-agent project root + +# Additional dependencies specific to this example +anthropic +openai +temporalio diff --git a/examples/mcp/mcp_sampling/temporal/client.py b/examples/mcp/mcp_sampling/temporal/client.py new file mode 100644 index 000000000..74b0d6065 --- /dev/null +++ b/examples/mcp/mcp_sampling/temporal/client.py @@ -0,0 +1,135 @@ +import asyncio +import json +import time +from mcp.types import CallToolResult +from mcp_agent.app import MCPApp +from mcp_agent.config import MCPServerSettings +from mcp_agent.executor.workflow import WorkflowExecution +from mcp_agent.mcp.gen_client import gen_client +from mcp_agent.human_input.handler import console_input_callback + + +async def main(): + # Create MCPApp to get the server registry + app = MCPApp(name="workflow_mcp_client", human_input_callback=console_input_callback) + async with app.run() as client_app: + logger = client_app.logger + context = client_app.context + + # Connect to the workflow server + logger.info("Connecting to workflow server...") + + # Override the server configuration to point to our local script + context.server_registry.registry["demo_server"] = MCPServerSettings( + name="demo_server", + description="Local workflow server running the basic agent example", + transport="sse", + url="http://0.0.0.0:8000/sse", + ) + + # Connect to the workflow server + async with gen_client("demo_server", context.server_registry, context=context) as server: + # Call the BasicAgentWorkflow + logger.info(f"{type(server)}") + run_result = await server.call_tool( + "workflows-HaikuWorkflow-run", + arguments={ + "run_parameters": { + "input": "space exploration" + } + }, + ) + + execution = WorkflowExecution(**json.loads(run_result.content[0].text)) + run_id = execution.run_id + logger.info( + f"Started BasicAgentWorkflow-run. 
workflow ID={execution.workflow_id}, run ID={run_id}" + ) + + # Wait for the workflow to complete + while True: + get_status_result = await server.call_tool( + "workflows-HaikuWorkflow-get_status", + arguments={"run_id": run_id}, + ) + + workflow_status = _tool_result_to_json(get_status_result) + if workflow_status is None: + logger.error( + f"Failed to parse workflow status response: {get_status_result}" + ) + break + + logger.info( + f"Workflow run {run_id} status:", + data=workflow_status, + ) + + if not workflow_status.get("status"): + logger.error( + f"Workflow run {run_id} status is empty. get_status_result:", + data=get_status_result, + ) + break + + if workflow_status.get("status") == "completed": + logger.info( + f"Workflow run {run_id} completed successfully! Result:", + data=workflow_status.get("result"), + ) + + break + elif workflow_status.get("status") == "error": + logger.error( + f"Workflow run {run_id} failed with error:", + data=workflow_status, + ) + break + elif workflow_status.get("status") == "running": + logger.info( + f"Workflow run {run_id} is still running...", + ) + elif workflow_status.get("status") == "cancelled": + logger.error( + f"Workflow run {run_id} was cancelled.", + data=workflow_status, + ) + break + else: + logger.error( + f"Unknown workflow status: {workflow_status.get('status')}", + data=workflow_status, + ) + break + + await asyncio.sleep(5) + + # TODO: UNCOMMENT ME to try out cancellation: + # await server.call_tool( + # "workflows-cancel", + # arguments={"workflow_id": "BasicAgentWorkflow", "run_id": run_id}, + # ) + + print(run_result) + + +def _tool_result_to_json(tool_result: CallToolResult): + if tool_result.content and len(tool_result.content) > 0: + text = tool_result.content[0].text + try: + # Try to parse the response as JSON if it's a string + import json + + return json.loads(text) + except (json.JSONDecodeError, TypeError): + # If it's not valid JSON, just use the text + return None + + +if __name__ == "__main__": + start = time.time() + asyncio.run(main()) + end = time.time() + t = end - start + + print(f"Total run time: {t:.2f}s") diff --git a/examples/mcp/mcp_sampling/temporal/demo_server_sse.py b/examples/mcp/mcp_sampling/temporal/demo_server_sse.py new file mode 100644 index 000000000..22c77c195 --- /dev/null +++ b/examples/mcp/mcp_sampling/temporal/demo_server_sse.py @@ -0,0 +1,135 @@ +""" +A simple workflow server which generates haikus on request using a tool. 
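+
+This variant runs the HaikuWorkflow on Temporal and exposes it over SSE.
+Typical usage: start a local Temporal server (the settings below expect
+localhost:7233), run demo_server_sse_worker.py in one terminal, run this
+script in another, then run client.py, which calls the workflow tools at
+http://localhost:8000/sse.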
+""" + +import asyncio +import logging + +import yaml +from mcp.server.fastmcp import FastMCP + +from mcp_agent.app import MCPApp +from mcp_agent.config import Settings, LoggerSettings, MCPSettings, MCPServerSettings, LogPathSettings, TemporalSettings +from mcp_agent.server.app_server import create_mcp_server_for_app +from mcp_agent.agents.agent import Agent +from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM +from mcp_agent.executor.workflow import Workflow, WorkflowResult + +# Initialize logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# from http.client import HTTPConnection +# +# HTTPConnection.debuglevel = 1 +# +# logging.basicConfig() +# logging.getLogger().setLevel(logging.DEBUG) +# requests_log = logging.getLogger("requests.packages.urllib3") +# requests_log.setLevel(logging.DEBUG) +# requests_log.propagate = True + +# Note: This is purely optional: +# if not provided, a default FastMCP server will be created by MCPApp using create_mcp_server_for_app() +mcp = FastMCP(name="haiku_generation_server", description="Server to generate haikus") + +# Create settings explicitly, as we want to use a different configuration from the main app +secrets_file = Settings.find_secrets() +if secrets_file and secrets_file.exists(): + with open(secrets_file, "r", encoding="utf-8") as f: + yaml_secrets = yaml.safe_load(f) or {} + openai_secret = yaml_secrets["openai"] + +settings = Settings( + execution_engine="temporal", + temporal=TemporalSettings( + host="localhost:7233", + namespace="default", + task_queue="mcp-agent", + max_concurrent_activities=10, + ), + logger=LoggerSettings( + type="file", + level="debug", + path_settings=LogPathSettings( + path_pattern="logs/demo_server_sse-{unique_id}.jsonl", + unique_id="timestamp", + timestamp_format="%Y%m%d_%H%M%S"), + ), + mcp=MCPSettings( + servers={ + "haiku_server": MCPServerSettings( + command="uv", + args=["run", "nested_server.py"], + description="nested server providing a haiku generator" + ) + } + ), + openai=openai_secret +) + +# Define the MCPApp instance +app = MCPApp( + name="haiku_server", + description="Haiku server", + mcp=mcp, + settings=settings +) + +@app.workflow +class HaikuWorkflow(Workflow[str]): + """ + A workflow that generates haikus on request. + """ + + @app.workflow_run + async def run(self, input: str) -> WorkflowResult[str]: + """ + Run the haiku agent workflow. + + Args: + input: The topic to create a haiku about + + Returns: + WorkflowResult containing the processed data. 
+ """ + + logger = app.logger + + logger.info("Running HaikuWorkflow") + haiku_agent = Agent( + name="poet", + instruction="""You are an agent with access to a tool that helps you write haikus.""", + server_names=["haiku_server"], + ) + + async with haiku_agent: + llm = await haiku_agent.attach_llm(OpenAIAugmentedLLM) + + result = await llm.generate_str( + message=f"Write a haiku about {input} using the tool at your disposal", + ) + logger.info(f"Input: {input}, Result: {result}") + + return WorkflowResult(value=result) + + +async def main(): + async with app.run() as agent_app: + # Log registered workflows and agent configurations + logger.info(f"Creating MCP server for {agent_app.name}") + + logger.info("Registered workflows:") + for workflow_id in agent_app.workflows: + logger.info(f" - {workflow_id}") + + # Create the MCP server that exposes both workflows and agent configurations + mcp_server = create_mcp_server_for_app(agent_app, **({})) + logger.info(f"MCP Server settings: {mcp_server.settings}") + + # Run the server + await mcp_server.run_sse_async() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/examples/mcp/mcp_sampling/temporal/demo_server_sse_worker.py b/examples/mcp/mcp_sampling/temporal/demo_server_sse_worker.py new file mode 100644 index 000000000..cf0433f5c --- /dev/null +++ b/examples/mcp/mcp_sampling/temporal/demo_server_sse_worker.py @@ -0,0 +1,40 @@ +""" +Worker script for the Temporal workflow example. +This script starts a Temporal worker that can execute workflows and activities. +Run this script in a separate terminal window before running the main.py script. + +This leverages the TemporalExecutor's start_worker method to handle the worker setup. +""" + +import asyncio +import logging + + +from mcp_agent.executor.temporal import create_temporal_worker_for_app + +from demo_server_sse import app + +# Initialize logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# from http.client import HTTPConnection +# +# HTTPConnection.debuglevel = 1 +# +# logging.basicConfig() +# logging.getLogger().setLevel(logging.DEBUG) +# requests_log = logging.getLogger("requests.packages.urllib3") +# requests_log.setLevel(logging.DEBUG) +# requests_log.propagate = True + +async def main(): + """ + Start a Temporal worker for the example workflows using the app's executor. + """ + async with create_temporal_worker_for_app(app) as worker: + await worker.run() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/mcp/mcp_sampling/temporal/mcp_agent.config.yaml b/examples/mcp/mcp_sampling/temporal/mcp_agent.config.yaml new file mode 100644 index 000000000..508119593 --- /dev/null +++ b/examples/mcp/mcp_sampling/temporal/mcp_agent.config.yaml @@ -0,0 +1,23 @@ +$schema: ../../../schema/mcp-agent.config.schema.json + +execution_engine: "asyncio" + +logger: + transports: [console, file] + level: debug + progress_display: true + path_settings: + path_pattern: "logs/mcp-agent-{unique_id}.jsonl" + unique_id: "timestamp" # Options: "timestamp" or "session_id" + timestamp_format: "%Y%m%d_%H%M%S" + +mcp: + servers: + demo_server_sse: + transport: sse + url: http://localhost:8000/sse + +openai: + # Secrets (API keys, etc.) 
are stored in an mcp_agent.secrets.yaml file which can be gitignored + # default_model: "o3-mini" + default_model: "gpt-4o-mini" diff --git a/examples/mcp/mcp_sampling/temporal/nested_server.py b/examples/mcp/mcp_sampling/temporal/nested_server.py new file mode 100644 index 000000000..f864b6792 --- /dev/null +++ b/examples/mcp/mcp_sampling/temporal/nested_server.py @@ -0,0 +1,65 @@ +from mcp.server.fastmcp import FastMCP +from mcp.types import ModelPreferences, ModelHint, SamplingMessage, TextContent +from mcp_agent.app import MCPApp +from mcp_agent.config import Settings, LoggerSettings, LogPathSettings + +mcp = FastMCP("Haiku demo server") + +settings = Settings( + execution_engine="asyncio", + logger=LoggerSettings( + type="file", + level="debug", + path_settings=LogPathSettings( + path_pattern="asyncio/logs/nested_server-{unique_id}.jsonl", + unique_id="timestamp", + timestamp_format="%Y%m%d_%H%M%S"), + ), +) + +app = MCPApp( + name="haiku_agent", + settings=settings, +) + + +@mcp.tool() +async def get_haiku(topic: str) -> str: + """Use sampling to generate a haiku about the given topic.""" + + app.logger.info(f"Generating haiku about topic: {topic} via sampling") + haiku = await mcp.get_context().session.create_message( + messages=[ + SamplingMessage( + role="user", + content=TextContent( + type="text", text=f"Generate a quirky haiku about {topic}." + ), + ) + ], + system_prompt="You are a poet.", + max_tokens=100, + temperature=0.7, + model_preferences=ModelPreferences( + hints=[ModelHint(name="gpt-4o-mini")], + costPriority=0.1, + speedPriority=0.8, + intelligencePriority=0.1, + ), + ) + + if isinstance(haiku.content, TextContent): + app.logger.info(f"Generated haiku: {haiku.content.text}") + return haiku.content.text + else: + app.logger.error(f"Haiku generation failed, unexpected content type: {haiku.content}") + return "Haiku generation failed, unexpected content type." 
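+
+# Note: session.create_message() does not call an LLM directly. It sends a
+# sampling/createMessage request back to the MCP client that invoked this tool;
+# with mcp-agent, MCPAgentClientSession either relays that request upstream or
+# handles it locally via SamplingHandler, prompting the user to approve both
+# the request and the generated text.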
+ + +def main(): + """Main entry point for the MCP server.""" + mcp.run() + + +if __name__ == "__main__": + main() diff --git a/src/mcp_agent/agents/agent.py b/src/mcp_agent/agents/agent.py index e55289d82..153eb7404 100644 --- a/src/mcp_agent/agents/agent.py +++ b/src/mcp_agent/agents/agent.py @@ -190,6 +190,13 @@ async def attach_llm( value = getattr(self.llm, attr, None) if value is not None: span.set_attribute(f"llm.{attr}", value) + + # Ensure a context exists before updating active LLM + if self.context is None: + # Fall back to global context for convenience; callers can also set agent.context explicitly + from mcp_agent.core.context import get_current_context + self.context = get_current_context() + return self.llm async def get_token_node(self, return_all_matches: bool = False): diff --git a/src/mcp_agent/core/context.py b/src/mcp_agent/core/context.py index d449c938a..6767f196c 100644 --- a/src/mcp_agent/core/context.py +++ b/src/mcp_agent/core/context.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict -from mcp import ServerSession from mcp.server.fastmcp import FastMCP +from mcp.server.session import ServerSession from opentelemetry import trace diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index 7e0eed92d..e2b2e264f 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -204,6 +204,7 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": Special kwargs that are extracted and not passed to run(): - __mcp_agent_workflow_id: Optional workflow ID to use (instead of auto-generating) - __mcp_agent_task_queue: Optional task queue to use (instead of default from config) + - __mcp_agent_workflow_memo: the memo passed to the temporal workflow Returns: WorkflowExecution: The execution details including run ID and workflow ID diff --git a/src/mcp_agent/mcp/gen_client.py b/src/mcp_agent/mcp/gen_client.py index b3e13d6ba..c35feb880 100644 --- a/src/mcp_agent/mcp/gen_client.py +++ b/src/mcp_agent/mcp/gen_client.py @@ -5,6 +5,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession +from mcp_agent.core.context import Context from mcp_agent.logging.logger import get_logger from mcp_agent.mcp.mcp_server_registry import ServerRegistry from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession @@ -21,6 +22,7 @@ async def gen_client( ClientSession, ] = MCPAgentClientSession, session_id: str | None = None, + context: Context | None = None, ) -> AsyncGenerator[ClientSession, None]: """ Create a client session to the specified server. 
@@ -37,6 +39,7 @@ async def gen_client( server_name=server_name, client_session_factory=client_session_factory, session_id=session_id, + context=context ) as session: yield session diff --git a/src/mcp_agent/mcp/mcp_agent_client_session.py b/src/mcp_agent/mcp/mcp_agent_client_session.py index 425a0e298..04b55c6a1 100644 --- a/src/mcp_agent/mcp/mcp_agent_client_session.py +++ b/src/mcp_agent/mcp/mcp_agent_client_session.py @@ -40,7 +40,6 @@ Implementation, JSONRPCMessage, ServerRequest, - TextContent, ListRootsResult, NotificationParams, RequestParams, @@ -62,6 +61,7 @@ MCP_TOOL_NAME, ) from mcp_agent.tracing.telemetry import get_tracer, record_attributes +from mcp_agent.mcp.sampling_handler import SamplingHandler if TYPE_CHECKING: from mcp_agent.core.context import Context @@ -116,6 +116,7 @@ def __init__( ) self.server_config: Optional[MCPServerSettings] = None + self._sampling_handler = SamplingHandler(context=self.context) # Session ID handling for Streamable HTTP transport self._get_session_id_callback: Optional[Callable[[], str | None]] = None @@ -334,46 +335,17 @@ async def _handle_sampling_callback( context: RequestContext["ClientSession", Any], params: CreateMessageRequestParams, ) -> CreateMessageResult | ErrorData: - logger.info("Handling sampling request: %s", params) - config = self.context.config + logger.debug(f"Handling sampling request: {params}") server_session = self.context.upstream_session - if server_session is None: - # TODO: saqadri - consider whether we should be handling the sampling request here as a client - logger.warning( - "Error: No upstream client available for sampling requests. Request:", - data=params, - ) - try: - from anthropic import AsyncAnthropic - - client = AsyncAnthropic(api_key=config.anthropic.api_key) - - response = await client.messages.create( - model="claude-3-sonnet-20240229", - max_tokens=params.maxTokens, - messages=[ - { - "role": m.role, - "content": m.content.text - if hasattr(m.content, "text") - else m.content.data, - } - for m in params.messages - ], - system=getattr(params, "systemPrompt", None), - temperature=getattr(params, "temperature", 0.7), - stop_sequences=getattr(params, "stopSequences", None), - ) - return CreateMessageResult( - model="claude-3-sonnet-20240229", - role="assistant", - content=TextContent(type="text", text=response.content[0].text), - ) - except Exception as e: - logger.error(f"Error handling sampling request: {e}") - return ErrorData(code=-32603, message=str(e)) + if server_session is None: + # Enhanced sampling with human approval workflow + logger.debug("No upstream server session, handling sampling locally") + return await self._sampling_handler.handle_sampling( + context=self.context, + params=params) else: + logger.debug("Passing sampling request to upstream server session") try: # If a server_session is available, we'll pass-through the sampling request to the upstream client result = await server_session.send_request( diff --git a/src/mcp_agent/mcp/mcp_aggregator.py b/src/mcp_agent/mcp/mcp_aggregator.py index 80a3e1f12..a48a2eba7 100644 --- a/src/mcp_agent/mcp/mcp_aggregator.py +++ b/src/mcp_agent/mcp/mcp_aggregator.py @@ -1223,6 +1223,7 @@ def getter(item: NamespacedResource): # No match found return None, None + async def _start_server(self, server_name: str): if self.connection_persistence: logger.info( diff --git a/src/mcp_agent/mcp/mcp_server_registry.py b/src/mcp_agent/mcp/mcp_server_registry.py index 9f832781e..7ef0be1d3 100644 --- a/src/mcp_agent/mcp/mcp_server_registry.py +++ 
b/src/mcp_agent/mcp/mcp_server_registry.py @@ -9,7 +9,8 @@ from contextlib import asynccontextmanager from datetime import timedelta -from typing import Callable, Dict, AsyncGenerator +from typing import Callable, Dict, AsyncGenerator, Optional, TYPE_CHECKING + from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp import ClientSession @@ -33,6 +34,9 @@ from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession from mcp_agent.mcp.mcp_connection_manager import MCPConnectionManager +if TYPE_CHECKING: + from mcp_agent.core.context import Context + logger = get_logger(__name__) InitHookCallable = Callable[[ClientSession | None, MCPServerAuthSettings | None], bool] @@ -86,7 +90,7 @@ def __init__(self, config: Settings | None = None, config_path: str | None = Non self.connection_manager = MCPConnectionManager(self) def load_registry_from_file( - self, config_path: str | None = None + self, config_path: str | None = None ) -> Dict[str, MCPServerSettings]: """ Load the YAML configuration file and validate it. @@ -101,15 +105,17 @@ def load_registry_from_file( servers = get_settings(config_path).mcp.servers or {} return servers + @asynccontextmanager async def start_server( - self, - server_name: str, - client_session_factory: Callable[ - [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None], - ClientSession, - ] = ClientSession, - session_id: str | None = None, + self, + server_name: str, + client_session_factory: Callable[ + [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None], + ClientSession, + ] = ClientSession, + session_id: str | None = None, + context: Optional["Context"] = None, ) -> AsyncGenerator[ClientSession, None]: """ Starts the server process based on its configuration. 
To initialize, call initialize_server @@ -151,6 +157,7 @@ async def start_server( read_stream, write_stream, read_timeout_seconds, + context=context ) async with session: logger.info( @@ -198,12 +205,13 @@ async def start_server( # For Streamable HTTP, we get an additional callback for session ID async with streamablehttp_client( - **kwargs, + **kwargs, ) as (read_stream, write_stream, session_id_callback): session = client_session_factory( read_stream, write_stream, read_timeout_seconds, + context=context, ) if session_id_callback and isinstance(session, MCPAgentClientSession): @@ -236,13 +244,14 @@ async def start_server( # Use sse_client to get the read and write streams async with sse_client(**kwargs) as ( - read_stream, - write_stream, + read_stream, + write_stream, ): session = client_session_factory( read_stream, write_stream, read_timeout_seconds, + context=context ) async with session: logger.info( @@ -260,13 +269,14 @@ async def start_server( ) async with websocket_client(url=config.url) as ( # pylint: disable=W0135 - read_stream, - write_stream, + read_stream, + write_stream, ): session = client_session_factory( read_stream, write_stream, read_timeout_seconds, + context=context, ) async with session: logger.info( @@ -282,14 +292,15 @@ async def start_server( @asynccontextmanager async def initialize_server( - self, - server_name: str, - client_session_factory: Callable[ - [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None], - ClientSession, - ] = ClientSession, - init_hook: InitHookCallable = None, - session_id: str | None = None, + self, + server_name: str, + client_session_factory: Callable[ + [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None], + ClientSession, + ] = ClientSession, + init_hook: InitHookCallable = None, + session_id: str | None = None, + context: Optional["Context"] = None, ) -> AsyncGenerator[ClientSession, None]: """ Initialize a server based on its configuration. @@ -312,9 +323,10 @@ async def initialize_server( config = self.registry[server_name] async with self.start_server( - server_name, - client_session_factory=client_session_factory, - session_id=session_id, + server_name, + client_session_factory=client_session_factory, + session_id=session_id, + context=context, ) as session: try: logger.info(f"{server_name}: Initializing server...") diff --git a/src/mcp_agent/mcp/sampling_handler.py b/src/mcp_agent/mcp/sampling_handler.py new file mode 100644 index 000000000..ea3104691 --- /dev/null +++ b/src/mcp_agent/mcp/sampling_handler.py @@ -0,0 +1,366 @@ +""" +MCP Agent Sampling Handler + +Handles sampling requests from MCP servers with human-in-the-loop approval workflow +and direct LLM provider integration. 
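+
+The local handling flow is: (1) present the incoming sampling request to the
+user for approval, (2) generate a completion with an LLM selected from the
+request's model preferences, and (3) present the generated response to the
+user for approval before returning it to the requesting server. If an upstream
+session is available, the request is instead passed through to it unchanged.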
+""" + +from typing import TYPE_CHECKING +from uuid import uuid4 + +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + ErrorData, + TextContent, + ImageContent, + SamplingMessage, CreateMessageRequest, ServerRequest, +) +from mcp.server.fastmcp.exceptions import ToolError + +from mcp_agent.core.context_dependent import ContextDependent +from mcp_agent.logging.logger import get_logger +from mcp_agent.workflows.llm.augmented_llm import ( + AugmentedLLM, + RequestParams as LLMRequestParams, +) +from mcp_agent.workflows.llm.llm_selector import ModelSelector + +if TYPE_CHECKING: + from mcp_agent.core.context import Context + +logger = get_logger(__name__) + + +class SamplingHandler(ContextDependent): + """Handles MCP sampling requests with human approval workflow and LLM generation""" + + def __init__(self, context: "Context"): + super().__init__(context=context) + + async def handle_sampling(self, + context: "Context", + params: CreateMessageRequestParams) -> CreateMessageResult | ErrorData: + logger.debug(f"Handling sampling request: {params}") + server_session = context.upstream_session + + if server_session is None: + # Enhanced sampling with human approval workflow + logger.debug("No upstream server session, handling sampling locally") + return await self._handle_sampling_with_human_approval(params) + else: + logger.debug("Passing sampling request to upstream server session") + try: + # If a server_session is available, we'll pass-through the sampling request to the upstream client + result = await server_session.send_request( + request=ServerRequest( + CreateMessageRequest( + method="sampling/createMessage", params=params + ) + ), + result_type=CreateMessageResult, + ) + + logger.debug(f"Received sampling response from upstream server: {result}") + + # Pass the result from the upstream client back to the server. 
We just act as a pass-through client here + return result + except Exception as e: + return ErrorData(code=-32603, message=str(e)) + + async def _handle_sampling_with_human_approval( + self, params: CreateMessageRequestParams + ) -> CreateMessageResult | ErrorData: + """Handle sampling with human-in-the-loop approval workflow""" + try: + # Stage 1: Human approval/modification of the request + approved_params, rejection_reason = await self._request_human_approval_for_sampling_request( + params + ) + if approved_params is None: + logger.info(f"Sampling request rejected by user: {rejection_reason}") + return ErrorData( + code=-32603, message=f"Sampling request rejected by user: {rejection_reason}" + ) + + # Stage 2: Generate response using available LLM providers + llm_result = await self._generate_with_active_llm_provider(approved_params) + if llm_result is None: + return ErrorData(code=-32603, message="Failed to generate a response") + + # Stage 3: Human approval/modification of the response + final_result, rejection_reason = await self._request_human_approval_for_sampling_response( + llm_result + ) + if final_result is None: + logger.info(f"Sampling response rejected by user: {rejection_reason}") + return ErrorData(code=-32603, message=f"Response rejected by user: {rejection_reason}") + + return final_result + + except Exception as e: + logger.error(f"Error in sampling with human approval: {e}") + return ErrorData(code=-32603, message=str(e)) + + async def _request_human_approval_for_sampling_request( + self, params: CreateMessageRequestParams + ) -> tuple[CreateMessageRequestParams | None, str]: + """Present sampling request to user for approval/modification""" + try: + if not self.context.human_input_handler: + logger.warning( + "No human input handler available, auto-approving request" + ) + return params, "" + + request_summary = self._format_sampling_request_for_human(params) + + from mcp_agent.human_input.types import HumanInputRequest + + request_id = f"sampling_request_{uuid4()}" + + request = HumanInputRequest( + prompt=f"""MCP server is requesting LLM completion. Please review and approve/reject: + +{request_summary} + +Respond with: +- 'approve' to proceed with the request as-is +- Anything else to reject (your input will be used as the rejection reason)""", + description="MCP Sampling Request Approval", + request_id=request_id, + metadata={ + "type": "sampling_request_approval", + "original_params": params.model_dump(), + }, + ) + + response = await self.context.human_input_handler(request) + return self._parse_human_modified_params(response.response, params) + + except Exception as e: + logger.error(f"Error requesting human approval for sampling request: {e}") + return params, "" # Fallback to original params + + async def _request_human_approval_for_sampling_response( + self, result: CreateMessageResult + ) -> tuple[CreateMessageResult | None, str]: + """Present LLM response to user for approval/modification""" + try: + if not self.context.human_input_handler: + logger.warning( + "No human input handler available, auto-approving response" + ) + return result, "" + + response_summary = self._format_sampling_response_for_human(result) + + from mcp_agent.human_input.types import HumanInputRequest + + request_id = f"sampling_response_{uuid4()}" + + request = HumanInputRequest( + prompt=f"""LLM has generated a response. 
Please review and approve/reject: + +{response_summary} + +Respond with: +- 'approve' to send the response as-is +- Anything else to reject (your input will be used as the rejection reason)""", + description="MCP Sampling Response Approval", + request_id=request_id, + metadata={ + "type": "sampling_response_approval", + "original_result": result.model_dump(), + }, + ) + + response = await self.context.human_input_handler(request) + return self._parse_human_modified_result(response.response, result) + + except Exception as e: + logger.error(f"Error requesting human approval for sampling response: {e}") + return result, "" # Fallback to original result + + async def _generate_with_active_llm_provider( + self, params: CreateMessageRequestParams + ) -> CreateMessageResult | None: + """Generate response using the active LLM provider directly to avoid recursion""" + + try: + # use the active model selector from context or fall back to the defaults + model_selector = self.context.model_selector or ModelSelector() + if params.modelPreferences: + model_info = model_selector.select_best_model(params.modelPreferences) + logger.info(f"Selected model based on preferences {model_info}") + + # break circular dependency by importing here + from mcp_agent.workflows.factory import create_llm + llm = create_llm(agent_name="sampling", + server_names=[], + instruction=None, + provider=model_info.provider, + model=model_info.name, + request_params=None, + context=None) # Do not pass current context. We want a clean LLM + else: + # fall back to default + raise ToolError("Model preferences must be provided for sampling requests") + + messages = self._extract_message_content(params.messages) + request_params = self._build_llm_request_params(params) + + result = await llm.generate_str( + message=messages, request_params=request_params + ) + + logger.info("Successfully generated response") + final_request_params = llm.get_request_params( + self._build_llm_request_params(params) + ) + model_name = await llm.select_model(final_request_params) + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text=result), + model=model_name or "unknown", + ) + + except Exception as e: + import traceback + logger.info(traceback.format_exc()) + logger.error(f"Unexpected error calling LLM: {e}") + return None + + def _create_provider_instance( + self, provider_class: type[AugmentedLLM] + ) -> AugmentedLLM | None: + """Create a minimal LLM instance for direct calls""" + try: + return provider_class(context=self.context) + except Exception as e: + logger.error(f"Failed to create provider instance: {e}") + return None + + def _extract_message_content(self, messages: list[SamplingMessage]) -> list[str]: + """Extract text content from MCP messages""" + extracted = [] + for msg in messages: + content = self._get_message_text(msg.content) + extracted.append(content) + return extracted + + def _get_message_text(self, content: TextContent | ImageContent) -> str: + """Extract text from message content with fallback handling""" + if hasattr(content, "text") and content.text: + return content.text + elif hasattr(content, "data") and content.data: + return str(content.data) + else: + return str(content) + + def _build_llm_request_params( + self, params: CreateMessageRequestParams + ) -> LLMRequestParams: + """Build LLM request parameters with safe defaults""" + return LLMRequestParams( + maxTokens=params.maxTokens or 2048, + temperature=getattr(params, "temperature", 0.7), + max_iterations=1, + 
parallel_tool_calls=False, + use_history=False, + messages=None, + modelPreferences=params.modelPreferences, + ) + + def _format_sampling_request_for_human( + self, params: CreateMessageRequestParams + ) -> str: + """Format sampling request for human review""" + messages_text = "" + for i, msg in enumerate(params.messages): + content = ( + msg.content.text if hasattr(msg.content, "text") else str(msg.content) + ) + messages_text += f" Message {i + 1} ({msg.role}): {content[:200]}{'...' if len(content) > 200 else ''}\n" + + system_prompt_display = ( + "None" + if params.systemPrompt is None + else ( + f"{params.systemPrompt[:100]}{'...' if len(params.systemPrompt) > 100 else ''}" + ) + ) + + stop_sequences_display = ( + "None" if params.stopSequences is None else str(params.stopSequences) + ) + + model_preferences_display = "None" + if params.modelPreferences is not None: + prefs = [] + if params.modelPreferences.hints: + hints = [ + hint.name + for hint in params.modelPreferences.hints + if hint.name is not None + ] + prefs.append(f"hints: {hints}") + if params.modelPreferences.costPriority is not None: + prefs.append(f"cost: {params.modelPreferences.costPriority}") + if params.modelPreferences.speedPriority is not None: + prefs.append(f"speed: {params.modelPreferences.speedPriority}") + if params.modelPreferences.intelligencePriority is not None: + prefs.append( + f"intelligence: {params.modelPreferences.intelligencePriority}" + ) + model_preferences_display = ", ".join(prefs) if prefs else "None" + + return f"""REQUEST DETAILS: +- Max Tokens: {params.maxTokens} +- System Prompt: {system_prompt_display} +- Temperature: {params.temperature if params.temperature is not None else 0.7} +- Stop Sequences: {stop_sequences_display} +- Model Preferences: {model_preferences_display} + +MESSAGES: +{messages_text}""" + + def _format_sampling_response_for_human(self, result: CreateMessageResult) -> str: + """Format sampling response for human review""" + content = ( + result.content.text + if hasattr(result.content, "text") + else str(result.content) + ) + return f"""RESPONSE DETAILS: +- Model: {result.model} +- Role: {result.role} + +CONTENT: +{content}""" + + def _parse_human_modified_params( + self, response: str, original_params: CreateMessageRequestParams + ) -> tuple[CreateMessageRequestParams | None, str]: + """Parse human response and return modified params or None if rejected""" + response_stripped = response.strip().lower() + + if response_stripped == "approve": + return original_params, "" + else: + # Anything else is treated as rejection with reasoning + rejection_reason = response.strip() + return None, rejection_reason + + def _parse_human_modified_result( + self, response: str, original_result: CreateMessageResult + ) -> tuple[CreateMessageResult | None, str]: + """Parse human response and return modified result or None if rejected""" + response_stripped = response.strip().lower() + + if response_stripped == "approve": + return original_result, "" + else: + # Anything else is treated as rejection with reasoning + rejection_reason = response.strip() + return None, rejection_reason diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 5e289a5e5..3dc308178 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -2,14 +2,13 @@ MCPAgentServer - Exposes MCPApp as MCP server, and mcp-agent workflows and agents as MCP tools. 
""" - +import asyncio import json from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING import os import secrets -import asyncio from mcp.server.fastmcp import Context as MCPContext, FastMCP from starlette.requests import Request @@ -123,7 +122,6 @@ def workflow_registry(self) -> WorkflowRegistry: """Get the workflow registry for this server context.""" return self.context.workflow_registry - def _get_attached_app(mcp: FastMCP) -> MCPApp | None: """Return the MCPApp instance attached to the FastMCP server, if any.""" return getattr(mcp, "_mcp_agent_app", None) @@ -428,7 +426,6 @@ async def _relay_request(request: Request): EmptyResult, ServerRequest, ) - body = await request.json() execution_id = request.path_params.get("execution_id") method = body.get("method") @@ -603,6 +600,8 @@ async def _internal_human_prompts(request: Request): except Exception as e: return JSONResponse({"error": str(e)}, status_code=500) + + # Create or attach FastMCP server if app.mcp: # Using an externally provided FastMCP instance: attach app and context @@ -917,6 +916,8 @@ async def cancel_workflow( return mcp + + # region per-Workflow Tools @@ -1272,6 +1273,7 @@ def _schema_fn_proxy(*args, **kwargs): run_fn_tool = FastTool.from_function(_schema_fn_proxy) else: run_fn_tool = FastTool.from_function(param_source) + run_fn_tool_params = json.dumps(run_fn_tool.parameters, indent=2) @mcp.tool( diff --git a/tests/test_app.py b/tests/test_app.py index 677163457..b136427db 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -105,7 +105,6 @@ async def test_initialization_minimal(self): assert app.name == "test_app" assert app._human_input_callback is None assert app._signal_notification is None - assert app._upstream_session is None assert app._model_selector is None assert app._workflows == {} assert app._logger is None @@ -129,7 +128,7 @@ async def test_initialization_with_settings_path(self): @pytest.mark.asyncio async def test_initialization_with_callbacks( - self, human_input_callback, signal_notification + self, human_input_callback, signal_notification ): """Test initialization with callbacks.""" app = MCPApp( @@ -141,14 +140,6 @@ async def test_initialization_with_callbacks( assert app._human_input_callback is human_input_callback assert app._signal_notification is signal_notification - @pytest.mark.asyncio - async def test_initialization_with_upstream_session(self): - """Test initialization with upstream session.""" - mock_session = MagicMock() - app = MCPApp(name="test_app", upstream_session=mock_session) - - assert app._upstream_session is mock_session - @pytest.mark.asyncio async def test_initialization_with_model_selector(self): """Test initialization with model selector.""" @@ -202,7 +193,7 @@ async def test_non_windows_event_loop_policy(self, mock_set_policy): async def test_initialize_method(self, basic_app, mock_context): """Test initialize method.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ) as mock_init_context: await basic_app.initialize() @@ -214,7 +205,7 @@ async def test_initialize_method(self, basic_app, mock_context): async def test_initialize_already_initialized(self, basic_app, mock_context): """Test initialize method when already initialized.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + 
"mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ) as mock_init_context: # First initialization await basic_app.initialize() @@ -230,7 +221,7 @@ async def test_initialize_already_initialized(self, basic_app, mock_context): async def test_cleanup_method(self, basic_app, mock_context): """Test cleanup method.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): with patch("mcp_agent.app.cleanup_context", AsyncMock()) as mock_cleanup: await basic_app.initialize() @@ -312,7 +303,7 @@ async def test_run_with_cancelled_cleanup(self, basic_app, mock_context): async def test_context_property_initialized(self, basic_app, mock_context): """Test context property when initialized.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -328,7 +319,7 @@ async def test_context_property_not_initialized(self, basic_app): async def test_config_property(self, basic_app, mock_context): """Test config property.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -338,7 +329,7 @@ async def test_config_property(self, basic_app, mock_context): async def test_server_registry_property(self, basic_app, mock_context): """Test server_registry property.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -348,7 +339,7 @@ async def test_server_registry_property(self, basic_app, mock_context): async def test_executor_property(self, basic_app, mock_context): """Test executor property.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -358,35 +349,12 @@ async def test_executor_property(self, basic_app, mock_context): async def test_engine_property(self, basic_app, mock_context): """Test engine property.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() assert basic_app.engine is mock_context.executor.execution_engine - @pytest.mark.asyncio - async def test_upstream_session_getter(self, basic_app, mock_context): - """Test upstream_session getter.""" - with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) - ): - await basic_app.initialize() - - assert basic_app.upstream_session is mock_context.upstream_session - - @pytest.mark.asyncio - async def test_upstream_session_setter(self, basic_app, mock_context): - """Test upstream_session setter.""" - with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) - ): - await basic_app.initialize() - - new_session = MagicMock() - basic_app.upstream_session = new_session - - assert mock_context.upstream_session is new_session - @pytest.mark.asyncio async def test_workflows_property(self, basic_app): """Test workflows property.""" @@ -396,7 +364,7 @@ async def test_workflows_property(self, basic_app): async def test_tasks_property(self, basic_app, mock_context): 
"""Test tasks property.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): mock_context.task_registry.list_activities.return_value = ["task1", "task2"] await basic_app.initialize() @@ -435,7 +403,7 @@ async def test_logger_property_with_session_id(self, basic_app, mock_context): # Now initialize the context with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -459,7 +427,7 @@ async def test_logger_property_with_session_id(self, basic_app, mock_context): @pytest.mark.asyncio async def test_workflow_decorator_default( - self, basic_app, test_workflow, mock_context + self, basic_app, test_workflow, mock_context ): """Test workflow decorator default behavior.""" # Set the context directly instead of patching the property @@ -487,7 +455,7 @@ async def test_workflow_decorator_default( @pytest.mark.asyncio async def test_workflow_decorator_with_id( - self, basic_app, test_workflow, mock_context + self, basic_app, test_workflow, mock_context ): """Test workflow decorator with custom ID.""" # Set the context directly instead of patching the property @@ -516,11 +484,11 @@ async def test_workflow_decorator_with_id( @pytest.mark.asyncio async def test_workflow_decorator_with_engine( - self, basic_app, test_workflow, mock_context + self, basic_app, test_workflow, mock_context ): """Test workflow decorator with execution engine.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -569,7 +537,7 @@ async def test_fn(): # Calling decorated() returns a coroutine object that we need to await result = await decorated() assert ( - result == "test" + result == "test" ) # Should still return the original function's return value finally: # Reset the app state after the test @@ -580,7 +548,7 @@ async def test_fn(): async def test_workflow_run_decorator_with_engine(self, basic_app, mock_context): """Test workflow_run decorator with execution engine.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -609,7 +577,7 @@ async def test_fn(): async def test_workflow_task_decorator(self, basic_app, test_task, mock_context): """Test workflow_task decorator.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -622,8 +590,8 @@ async def test_workflow_task_decorator(self, basic_app, test_task, mock_context) assert decorated.is_workflow_task is True assert hasattr(decorated, "execution_metadata") assert ( - decorated.execution_metadata["activity_name"] - == f"{test_task.__module__}.{test_task.__qualname__}" + decorated.execution_metadata["activity_name"] + == f"{test_task.__module__}.{test_task.__qualname__}" ) # Verify task registration in the app's _task_registry @@ -635,11 +603,11 @@ async def test_workflow_task_decorator(self, basic_app, test_task, mock_context) @pytest.mark.asyncio async def test_workflow_task_decorator_with_name( - self, basic_app, test_task, mock_context + self, basic_app, test_task, mock_context ): """Test 
workflow_task decorator with custom name.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -658,11 +626,11 @@ async def test_workflow_task_decorator_with_name( @pytest.mark.asyncio async def test_workflow_task_decorator_with_timeout( - self, basic_app, test_task, mock_context + self, basic_app, test_task, mock_context ): """Test workflow_task decorator with custom timeout.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -674,8 +642,8 @@ async def test_workflow_task_decorator_with_timeout( # Verification assert ( - decorated.execution_metadata["schedule_to_close_timeout"] - == custom_timeout + decorated.execution_metadata["schedule_to_close_timeout"] + == custom_timeout ) # Verify task registration in the app's _task_registry @@ -685,17 +653,17 @@ async def test_workflow_task_decorator_with_timeout( registered_task = basic_app._task_registry.get_activity(activity_name) assert registered_task is decorated assert ( - registered_task.execution_metadata["schedule_to_close_timeout"] - == custom_timeout + registered_task.execution_metadata["schedule_to_close_timeout"] + == custom_timeout ) @pytest.mark.asyncio async def test_workflow_task_decorator_with_retry_policy( - self, basic_app, test_task, mock_context + self, basic_app, test_task, mock_context ): """Test workflow_task decorator with custom retry policy.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize() @@ -730,7 +698,7 @@ def non_async_fn(param): async def test_is_workflow_task_method(self, basic_app, test_task, mock_context): """Test is_workflow_task method.""" with patch( - "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) + "mcp_agent.app.initialize_context", AsyncMock(return_value=mock_context) ): await basic_app.initialize()