diff --git a/examples/mcp_agent_server/asyncio/basic_agent_server.py b/examples/mcp_agent_server/asyncio/basic_agent_server.py index ff498fcad..6ad9a6002 100644 --- a/examples/mcp_agent_server/asyncio/basic_agent_server.py +++ b/examples/mcp_agent_server/asyncio/basic_agent_server.py @@ -188,7 +188,6 @@ async def grade_story(story: str, app_ctx: Optional[AppContext] = None) -> str: async def grade_story_async(story: str, app_ctx: Optional[AppContext] = None) -> str: """ Async variant of grade_story that starts a workflow run and returns IDs. - Args: story: The student's short story to grade app_ctx: Optional MCPApp context for accessing app resources and logging diff --git a/examples/mcp_agent_server/temporal/basic_agent_server.py b/examples/mcp_agent_server/temporal/basic_agent_server.py index 5a29ab4f0..f4368bb54 100644 --- a/examples/mcp_agent_server/temporal/basic_agent_server.py +++ b/examples/mcp_agent_server/temporal/basic_agent_server.py @@ -59,12 +59,20 @@ async def run( context = app.context context.config.mcp.servers["filesystem"].args.extend([os.getcwd()]) + # Use of the app.logger will forward logs back to the mcp client + app_logger = app.logger + + app_logger.info("Starting finder agent") async with finder_agent: finder_llm = await finder_agent.attach_llm(OpenAIAugmentedLLM) result = await finder_llm.generate_str( message=input, ) + + # forwards the log to the caller + app_logger.info(f"Finder agent completed with result {result}") + # print to the console (for when running locally) print(f"Agent result: {result}") return WorkflowResult(value=result) diff --git a/examples/mcp_agent_server/temporal/client.py b/examples/mcp_agent_server/temporal/client.py index dd5f9b1ee..c945e1cd2 100644 --- a/examples/mcp_agent_server/temporal/client.py +++ b/examples/mcp_agent_server/temporal/client.py @@ -1,12 +1,18 @@ import asyncio import json import time -from mcp.types import CallToolResult +import argparse from mcp_agent.app import MCPApp from mcp_agent.config import MCPServerSettings from mcp_agent.executor.workflow import WorkflowExecution from mcp_agent.mcp.gen_client import gen_client +from datetime import timedelta +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp import ClientSession +from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession +from mcp.types import CallToolResult, LoggingMessageNotificationParams + try: from exceptiongroup import ExceptionGroup as _ExceptionGroup # Python 3.10 backport except Exception: # pragma: no cover @@ -18,28 +24,68 @@ async def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--server-log-level", + type=str, + default=None, + help="Set server logging level (debug, info, notice, warning, error, critical, alert, emergency)", + ) + args = parser.parse_args() # Create MCPApp to get the server registry app = MCPApp(name="workflow_mcp_client") async with app.run() as client_app: logger = client_app.logger context = client_app.context - # Connect to the workflow server - logger.info("Connecting to workflow server...") - - # Override the server configuration to point to our local script - context.server_registry.registry["basic_agent_server"] = MCPServerSettings( - name="basic_agent_server", - description="Local workflow server running the basic agent example", - transport="sse", - url="http://0.0.0.0:8000/sse", - ) - # Connect to the workflow server try: + logger.info("Connecting to workflow server...") + + # Override the server configuration to point to our local 
script
+            context.server_registry.registry["basic_agent_server"] = MCPServerSettings(
+                name="basic_agent_server",
+                description="Local workflow server running the basic agent example",
+                transport="sse",
+                url="http://0.0.0.0:8000/sse",
+            )
+
+            # Connect to the workflow server
+            # Define a logging callback to receive server-side log notifications
+            async def on_server_log(params: LoggingMessageNotificationParams) -> None:
+                # Pretty-print server logs locally for demonstration
+                level = params.level.upper()
+                name = params.logger or "server"
+                # params.data can be any JSON-serializable data
+                print(f"[SERVER LOG] [{level}] [{name}] {params.data}")
+
+            # Provide a client session factory that installs our logging callback
+            def make_session(
+                read_stream: MemoryObjectReceiveStream,
+                write_stream: MemoryObjectSendStream,
+                read_timeout_seconds: timedelta | None,
+            ) -> ClientSession:
+                return MCPAgentClientSession(
+                    read_stream=read_stream,
+                    write_stream=write_stream,
+                    read_timeout_seconds=read_timeout_seconds,
+                    logging_callback=on_server_log,
+                )
+
+            # Connect to the workflow server
             async with gen_client(
-                "basic_agent_server", context.server_registry
+                "basic_agent_server",
+                context.server_registry,
+                client_session_factory=make_session,
             ) as server:
+                # Ask server to send logs at the requested level (default info)
+                level = (args.server_log_level or "info").lower()
+                print(f"[client] Setting server logging level to: {level}")
+                try:
+                    await server.set_logging_level(level)
+                except Exception:
+                    # Older servers may not support logging capability
+                    print("[client] Server does not support logging/setLevel")
                 # Call the BasicAgentWorkflow
                 run_result = await server.call_tool(
                     "workflows-BasicAgentWorkflow-run",
diff --git a/pyproject.toml b/pyproject.toml
index efa4f5ccc..2465cbe07 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "mcp-agent"
-version = "0.1.16"
+version = "0.1.17"
 description = "Build effective agents with Model Context Protocol (MCP) using simple, composable patterns."
readme = "README.md" license = { file = "LICENSE" } diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index 925c91d51..b28a31f01 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -15,6 +15,7 @@ from mcp_agent.executor.signal_registry import SignalRegistry from mcp_agent.logging.event_progress import ProgressAction from mcp_agent.logging.logger import get_logger +from mcp_agent.logging.logger import set_default_bound_context from mcp_agent.executor.decorator_registry import ( DecoratorRegistry, register_asyncio_decorators, @@ -195,12 +196,14 @@ def logger(self): try: if self._context is not None: self._logger._bound_context = self._context # type: ignore[attr-defined] + except Exception: pass else: # Update the logger's bound context in case upstream_session was set after logger creation if self._context and hasattr(self._logger, "_bound_context"): self._logger._bound_context = self._context + return self._logger async def initialize(self): @@ -231,6 +234,12 @@ async def initialize(self): # Store a reference to this app instance in the context for easier access self._context.app = self + # Provide a safe default bound context for loggers created after init without explicit context + try: + set_default_bound_context(self._context) + except Exception: + pass + # Auto-load subagents if enabled in settings try: subagents = self._config.agents @@ -840,7 +849,15 @@ def decorator(target: Callable[..., R]) -> Callable[..., R]: ) if task_defn: - if isinstance(target, MethodType): + # Prevent re-decoration of an already temporal-decorated function, + # but still register it with the app. + if hasattr(target, "__temporal_activity_definition"): + self.logger.debug( + "Skipping redecorate for already-temporal activity", + data={"activity_name": activity_name}, + ) + task_callable = target + elif isinstance(target, MethodType): self_ref = target.__self__ @functools.wraps(func) @@ -903,7 +920,15 @@ def _register_global_workflow_tasks(self): ) if task_defn: # Engine-specific decorator available - if isinstance(target, MethodType): + # Prevent re-decoration of an already temporal-decorated function, + # but still register it with the app. 
+ if hasattr(target, "__temporal_activity_definition"): + self.logger.debug( + "Skipping redecorate for already-temporal activity", + data={"activity_name": activity_name}, + ) + task_callable = target + elif isinstance(target, MethodType): self_ref = target.__self__ @functools.wraps(func) diff --git a/src/mcp_agent/cli/cloud/commands/auth/whoami/main.py b/src/mcp_agent/cli/cloud/commands/auth/whoami/main.py index 7965fdef8..14e3fa2a6 100644 --- a/src/mcp_agent/cli/cloud/commands/auth/whoami/main.py +++ b/src/mcp_agent/cli/cloud/commands/auth/whoami/main.py @@ -1,6 +1,5 @@ """MCP Agent Cloud whoami command implementation.""" - from rich.console import Console from rich.panel import Panel from rich.table import Table diff --git a/src/mcp_agent/cli/cloud/commands/logger/__init__.py b/src/mcp_agent/cli/cloud/commands/logger/__init__.py index bb0332d7a..66805dc7e 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/__init__.py +++ b/src/mcp_agent/cli/cloud/commands/logger/__init__.py @@ -6,4 +6,4 @@ from .tail.main import tail_logs -__all__ = ["tail_logs"] \ No newline at end of file +__all__ = ["tail_logs"] diff --git a/src/mcp_agent/cli/cloud/commands/logger/configure/__init__.py b/src/mcp_agent/cli/cloud/commands/logger/configure/__init__.py index bcfed9693..988a5e789 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/configure/__init__.py +++ b/src/mcp_agent/cli/cloud/commands/logger/configure/__init__.py @@ -2,4 +2,4 @@ from .main import configure_logger -__all__ = ["configure_logger"] \ No newline at end of file +__all__ = ["configure_logger"] diff --git a/src/mcp_agent/cli/cloud/commands/logger/configure/main.py b/src/mcp_agent/cli/cloud/commands/logger/configure/main.py index 448cc0e63..fff24b2ff 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/configure/main.py +++ b/src/mcp_agent/cli/cloud/commands/logger/configure/main.py @@ -32,10 +32,10 @@ def configure_logger( ), ) -> None: """Configure OTEL endpoint and headers for log collection. - + This command allows you to configure the OpenTelemetry endpoint and headers that will be used for collecting logs from your deployed MCP apps. - + Examples: mcp-agent cloud logger configure https://otel.example.com:4318/v1/logs mcp-agent cloud logger configure https://otel.example.com --headers "Authorization=Bearer token,X-Custom=value" @@ -44,9 +44,9 @@ def configure_logger( if not endpoint and not test: console.print("[red]Error: Must specify endpoint or use --test[/red]") raise typer.Exit(1) - + config_path = _find_config_file() - + if test: if config_path and config_path.exists(): config = _load_config(config_path) @@ -54,7 +54,9 @@ def configure_logger( endpoint = otel_config.get("endpoint") headers_dict = otel_config.get("headers", {}) else: - console.print("[yellow]No configuration file found. Use --endpoint to set up OTEL configuration.[/yellow]") + console.print( + "[yellow]No configuration file found. 
Use --endpoint to set up OTEL configuration.[/yellow]" + ) raise typer.Exit(1) else: headers_dict = {} @@ -64,54 +66,68 @@ def configure_logger( key, value = header_pair.strip().split("=", 1) headers_dict[key.strip()] = value.strip() except ValueError: - console.print("[red]Error: Headers must be in format 'key=value,key2=value2'[/red]") + console.print( + "[red]Error: Headers must be in format 'key=value,key2=value2'[/red]" + ) raise typer.Exit(1) - + if endpoint: console.print(f"[blue]Testing connection to {endpoint}...[/blue]") - + try: with httpx.Client(timeout=10.0) as client: response = client.get( - endpoint.replace("/v1/logs", "/health") if "/v1/logs" in endpoint else f"{endpoint}/health", - headers=headers_dict + endpoint.replace("/v1/logs", "/health") + if "/v1/logs" in endpoint + else f"{endpoint}/health", + headers=headers_dict, ) - - if response.status_code in [200, 404]: # 404 is fine, means endpoint exists + + if response.status_code in [ + 200, + 404, + ]: # 404 is fine, means endpoint exists console.print("[green]✓ Connection successful[/green]") else: - console.print(f"[yellow]⚠ Got status {response.status_code}, but endpoint is reachable[/yellow]") - + console.print( + f"[yellow]⚠ Got status {response.status_code}, but endpoint is reachable[/yellow]" + ) + except httpx.RequestError as e: console.print(f"[red]✗ Connection failed: {e}[/red]") if not test: - console.print("[yellow]Configuration will be saved anyway. Check your endpoint URL and network connection.[/yellow]") - + console.print( + "[yellow]Configuration will be saved anyway. Check your endpoint URL and network connection.[/yellow]" + ) + if not test: if not config_path: config_path = Path.cwd() / "mcp_agent.config.yaml" - + config = _load_config(config_path) if config_path.exists() else {} - + if "otel" not in config: config["otel"] = {} - + config["otel"]["endpoint"] = endpoint config["otel"]["headers"] = headers_dict - + try: config_path.parent.mkdir(parents=True, exist_ok=True) with open(config_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) - - console.print(Panel( - f"[green]✓ OTEL configuration saved to {config_path}[/green]\n\n" - f"Endpoint: {endpoint}\n" - f"Headers: {len(headers_dict)} configured" + (f" ({', '.join(headers_dict.keys())})" if headers_dict else ""), - title="Configuration Saved", - border_style="green" - )) - + + console.print( + Panel( + f"[green]✓ OTEL configuration saved to {config_path}[/green]\n\n" + f"Endpoint: {endpoint}\n" + f"Headers: {len(headers_dict)} configured" + + (f" ({', '.join(headers_dict.keys())})" if headers_dict else ""), + title="Configuration Saved", + border_style="green", + ) + ) + except Exception as e: console.print(f"[red]Error saving configuration: {e}[/red]") raise typer.Exit(1) @@ -134,4 +150,4 @@ def _load_config(config_path: Path) -> dict: with open(config_path, "r") as f: return yaml.safe_load(f) or {} except Exception as e: - raise CLIError(f"Failed to load config from {config_path}: {e}") \ No newline at end of file + raise CLIError(f"Failed to load config from {config_path}: {e}") diff --git a/src/mcp_agent/cli/cloud/commands/logger/tail/__init__.py b/src/mcp_agent/cli/cloud/commands/logger/tail/__init__.py index 36a27364c..81c22db0a 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/tail/__init__.py +++ b/src/mcp_agent/cli/cloud/commands/logger/tail/__init__.py @@ -2,4 +2,4 @@ from .main import tail_logs -__all__ = ["tail_logs"] \ No newline at end of file +__all__ = ["tail_logs"] diff --git 
a/src/mcp_agent/cli/cloud/commands/logger/tail/main.py b/src/mcp_agent/cli/cloud/commands/logger/tail/main.py index b9ecf5368..d99c6df70 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/tail/main.py +++ b/src/mcp_agent/cli/cloud/commands/logger/tail/main.py @@ -18,7 +18,10 @@ from mcp_agent.cli.exceptions import CLIError from mcp_agent.cli.auth import load_credentials, UserCredentials from mcp_agent.cli.core.constants import DEFAULT_API_BASE_URL -from mcp_agent.cli.cloud.commands.logger.utils import parse_app_identifier, resolve_server_url +from mcp_agent.cli.cloud.commands.logger.utils import ( + parse_app_identifier, + resolve_server_url, +) console = Console() @@ -73,89 +76,103 @@ def tail_logs( ), ) -> None: """Tail logs for an MCP app deployment. - + Retrieve and optionally stream logs from deployed MCP apps. Supports filtering by time duration, text patterns, and continuous streaming. - + Examples: # Get last 50 logs from an app mcp-agent cloud logger tail app_abc123 --limit 50 - + # Stream logs continuously mcp-agent cloud logger tail https://app.mcpac.dev/abc123 --follow - + # Show logs from the last hour with error filtering mcp-agent cloud logger tail app_abc123 --since 1h --grep "ERROR|WARN" - + # Follow logs and filter for specific patterns mcp-agent cloud logger tail app_abc123 --follow --grep "authentication.*failed" """ - + credentials = load_credentials() if not credentials: - console.print("[red]Error: Not authenticated. Run 'mcp-agent login' first.[/red]") + console.print( + "[red]Error: Not authenticated. Run 'mcp-agent login' first.[/red]" + ) raise typer.Exit(4) - + # Validate conflicting options if follow and since: - console.print("[red]Error: --since cannot be used with --follow (streaming mode)[/red]") + console.print( + "[red]Error: --since cannot be used with --follow (streaming mode)[/red]" + ) raise typer.Exit(6) - + if follow and limit != DEFAULT_LOG_LIMIT: - console.print("[red]Error: --limit cannot be used with --follow (streaming mode)[/red]") + console.print( + "[red]Error: --limit cannot be used with --follow (streaming mode)[/red]" + ) raise typer.Exit(6) - + if follow and order_by: - console.print("[red]Error: --order-by cannot be used with --follow (streaming mode)[/red]") + console.print( + "[red]Error: --order-by cannot be used with --follow (streaming mode)[/red]" + ) raise typer.Exit(6) - + if follow and (asc or desc): - console.print("[red]Error: --asc/--desc cannot be used with --follow (streaming mode)[/red]") + console.print( + "[red]Error: --asc/--desc cannot be used with --follow (streaming mode)[/red]" + ) raise typer.Exit(6) - + # Validate order_by values if order_by and order_by not in ["timestamp", "severity"]: console.print("[red]Error: --order-by must be 'timestamp' or 'severity'[/red]") raise typer.Exit(6) - + # Validate that both --asc and --desc are not used together if asc and desc: console.print("[red]Error: Cannot use both --asc and --desc together[/red]") raise typer.Exit(6) - + # Validate format values if format and format not in ["text", "json", "yaml"]: console.print("[red]Error: --format must be 'text', 'json', or 'yaml'[/red]") raise typer.Exit(6) - + app_id, config_id, server_url = parse_app_identifier(app_identifier) - + try: if follow: - asyncio.run(_stream_logs( - app_id=app_id, - config_id=config_id, - server_url=server_url, - credentials=credentials, - grep_pattern=grep, - app_identifier=app_identifier, - format=format, - )) + asyncio.run( + _stream_logs( + app_id=app_id, + config_id=config_id, + 
server_url=server_url, + credentials=credentials, + grep_pattern=grep, + app_identifier=app_identifier, + format=format, + ) + ) else: - asyncio.run(_fetch_logs( - app_id=app_id, - config_id=config_id, - server_url=server_url, - credentials=credentials, - since=since, - grep_pattern=grep, - limit=limit, - order_by=order_by, - asc=asc, - desc=desc, - format=format, - )) - + asyncio.run( + _fetch_logs( + app_id=app_id, + config_id=config_id, + server_url=server_url, + credentials=credentials, + since=since, + grep_pattern=grep, + limit=limit, + order_by=order_by, + asc=asc, + desc=desc, + format=format, + ) + ) + except KeyboardInterrupt: console.print("\n[yellow]Interrupted by user[/yellow]") sys.exit(0) @@ -166,7 +183,7 @@ def tail_logs( async def _fetch_logs( app_id: Optional[str], - config_id: Optional[str], + config_id: Optional[str], server_url: Optional[str], credentials: UserCredentials, since: Optional[str], @@ -178,38 +195,40 @@ async def _fetch_logs( format: str, ) -> None: """Fetch logs one-time via HTTP API.""" - + api_base = DEFAULT_API_BASE_URL headers = { "Authorization": f"Bearer {credentials.api_key}", "Content-Type": "application/json", } - + payload = {} - + if app_id: payload["app_id"] = app_id elif config_id: payload["app_configuration_id"] = config_id else: - raise CLIError("Unable to determine app or configuration ID from provided identifier") - + raise CLIError( + "Unable to determine app or configuration ID from provided identifier" + ) + if since: payload["since"] = since if limit: payload["limit"] = limit - + if order_by: if order_by == "timestamp": payload["orderBy"] = "LOG_ORDER_BY_TIMESTAMP" elif order_by == "severity": payload["orderBy"] = "LOG_ORDER_BY_LEVEL" - + if asc: payload["order"] = "LOG_ORDER_ASC" elif desc: payload["order"] = "LOG_ORDER_DESC" - + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -217,7 +236,7 @@ async def _fetch_logs( transient=True, ) as progress: progress.add_task("Fetching logs...", total=None) - + try: async with httpx.AsyncClient(timeout=30.0) as client: response = await client.post( @@ -225,128 +244,148 @@ async def _fetch_logs( json=payload, headers=headers, ) - + if response.status_code == 401: - raise CLIError("Authentication failed. Try running 'mcp-agent login'") + raise CLIError( + "Authentication failed. 
Try running 'mcp-agent login'" + ) elif response.status_code == 404: raise CLIError("App or configuration not found") elif response.status_code != 200: - raise CLIError(f"API request failed: {response.status_code} {response.text}") - + raise CLIError( + f"API request failed: {response.status_code} {response.text}" + ) + data = response.json() log_entries = data.get("logEntries", []) - + except httpx.RequestError as e: raise CLIError(f"Failed to connect to API: {e}") - - filtered_logs = _filter_logs(log_entries, grep_pattern) if grep_pattern else log_entries - + + filtered_logs = ( + _filter_logs(log_entries, grep_pattern) if grep_pattern else log_entries + ) + if not filtered_logs: console.print("[yellow]No logs found matching the criteria[/yellow]") return - - _display_logs(filtered_logs, title=f"Logs for {app_id or config_id}", format=format) + _display_logs(filtered_logs, title=f"Logs for {app_id or config_id}", format=format) async def _stream_logs( app_id: Optional[str], config_id: Optional[str], - server_url: Optional[str], + server_url: Optional[str], credentials: UserCredentials, grep_pattern: Optional[str], app_identifier: str, format: str, ) -> None: """Stream logs continuously via SSE.""" - + if not server_url: server_url = await resolve_server_url(app_id, config_id, credentials) - + parsed = urlparse(server_url) stream_url = f"{parsed.scheme}://{parsed.netloc}/logs" hostname = parsed.hostname or "" - deployment_id = hostname.split('.')[0] if '.' in hostname else hostname - + deployment_id = hostname.split(".")[0] if "." in hostname else hostname + headers = { "Accept": "text/event-stream", "Cache-Control": "no-cache", "X-Routing-Key": deployment_id, } - + if credentials.api_key: headers["Authorization"] = f"Bearer {credentials.api_key}" - - console.print(f"[blue]Streaming logs from {app_identifier} (Press Ctrl+C to stop)[/blue]") - + + console.print( + f"[blue]Streaming logs from {app_identifier} (Press Ctrl+C to stop)[/blue]" + ) + # Setup signal handler for graceful shutdown def signal_handler(signum, frame): console.print("\n[yellow]Stopping log stream...[/yellow]") sys.exit(0) - + signal.signal(signal.SIGINT, signal_handler) - + try: async with httpx.AsyncClient(timeout=None) as client: async with client.stream("GET", stream_url, headers=headers) as response: - if response.status_code == 401: - raise CLIError("Authentication failed. Try running 'mcp-agent login'") + raise CLIError( + "Authentication failed. 
Try running 'mcp-agent login'" + ) elif response.status_code == 404: raise CLIError("Log stream not found for the specified app") elif response.status_code != 200: - raise CLIError(f"Failed to connect to log stream: {response.status_code}") - + raise CLIError( + f"Failed to connect to log stream: {response.status_code}" + ) + console.print("[green]✓ Connected to log stream[/green]\n") - + buffer = "" async for chunk in response.aiter_text(): buffer += chunk - lines = buffer.split('\n') - + lines = buffer.split("\n") + for line in lines[:-1]: - if line.startswith('data:'): - data_content = line.removeprefix('data:') - + if line.startswith("data:"): + data_content = line.removeprefix("data:") + try: log_data = json.loads(data_content) - - if 'message' in log_data: - timestamp = log_data.get('time') + + if "message" in log_data: + timestamp = log_data.get("time") if timestamp: - formatted_timestamp = _convert_timestamp_to_local(timestamp) + formatted_timestamp = ( + _convert_timestamp_to_local(timestamp) + ) else: formatted_timestamp = datetime.now().isoformat() - + log_entry = { - 'timestamp': formatted_timestamp, - 'message': log_data['message'], - 'level': log_data.get('level', 'INFO') + "timestamp": formatted_timestamp, + "message": log_data["message"], + "level": log_data.get("level", "INFO"), } - - if not grep_pattern or _matches_pattern(log_entry['message'], grep_pattern): + + if not grep_pattern or _matches_pattern( + log_entry["message"], grep_pattern + ): _display_log_entry(log_entry, format=format) - + except json.JSONDecodeError: # Skip malformed JSON continue - + except httpx.RequestError as e: raise CLIError(f"Failed to connect to log stream: {e}") - - -def _filter_logs(log_entries: List[Dict[str, Any]], pattern: str) -> List[Dict[str, Any]]: +def _filter_logs( + log_entries: List[Dict[str, Any]], pattern: str +) -> List[Dict[str, Any]]: """Filter log entries by pattern.""" if not pattern: return log_entries - + try: regex = re.compile(pattern, re.IGNORECASE) - return [entry for entry in log_entries if regex.search(entry.get('message', ''))] + return [ + entry for entry in log_entries if regex.search(entry.get("message", "")) + ] except re.error: - return [entry for entry in log_entries if pattern.lower() in entry.get('message', '').lower()] + return [ + entry + for entry in log_entries + if pattern.lower() in entry.get("message", "").lower() + ] def _matches_pattern(message: str, pattern: str) -> bool: @@ -361,21 +400,21 @@ def _matches_pattern(message: str, pattern: str) -> bool: def _clean_log_entry(entry: Dict[str, Any]) -> Dict[str, Any]: """Clean up a log entry for structured output formats.""" cleaned_entry = entry.copy() - cleaned_entry['severity'] = _parse_log_level(entry.get('level', 'INFO')) - cleaned_entry['message'] = _clean_message(entry.get('message', '')) - cleaned_entry.pop('level', None) + cleaned_entry["severity"] = _parse_log_level(entry.get("level", "INFO")) + cleaned_entry["message"] = _clean_message(entry.get("message", "")) + cleaned_entry.pop("level", None) return cleaned_entry def _display_text_log_entry(entry: Dict[str, Any]) -> None: """Display a single log entry in text format.""" - timestamp = _format_timestamp(entry.get('timestamp', '')) - raw_level = entry.get('level', 'INFO') + timestamp = _format_timestamp(entry.get("timestamp", "")) + raw_level = entry.get("level", "INFO") level = _parse_log_level(raw_level) - message = _clean_message(entry.get('message', '')) - + message = _clean_message(entry.get("message", "")) + level_style = 
_get_level_style(level) - + console.print( f"[bright_black not bold]{timestamp}[/bright_black not bold] " f"[{level_style}]{level:7}[/{level_style}] " @@ -383,11 +422,13 @@ def _display_text_log_entry(entry: Dict[str, Any]) -> None: ) -def _display_logs(log_entries: List[Dict[str, Any]], title: str = "Logs", format: str = "text") -> None: +def _display_logs( + log_entries: List[Dict[str, Any]], title: str = "Logs", format: str = "text" +) -> None: """Display logs in the specified format.""" if not log_entries: return - + if format == "json": cleaned_entries = [_clean_log_entry(entry) for entry in log_entries] print(json.dumps(cleaned_entries, indent=2)) @@ -397,7 +438,7 @@ def _display_logs(log_entries: List[Dict[str, Any]], title: str = "Logs", format else: # text format (default) if title: console.print(f"[bold blue]{title}[/bold blue]\n") - + for entry in log_entries: _display_text_log_entry(entry) @@ -426,20 +467,20 @@ def _format_timestamp(timestamp_str: str) -> str: try: if timestamp_str: # Parse UTC timestamp and convert to local time - dt_utc = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00')) + dt_utc = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00")) dt_local = dt_utc.astimezone() - return dt_local.strftime('%H:%M:%S') - return datetime.now().strftime('%H:%M:%S') + return dt_local.strftime("%H:%M:%S") + return datetime.now().strftime("%H:%M:%S") except (ValueError, TypeError): return timestamp_str[:8] if len(timestamp_str) >= 8 else timestamp_str def _parse_log_level(level: str) -> str: """Parse log level from API format to clean display format.""" - if level.startswith('LOG_LEVEL_'): - clean_level = level.replace('LOG_LEVEL_', '') - if clean_level == 'UNSPECIFIED': - return 'UNKNOWN' + if level.startswith("LOG_LEVEL_"): + clean_level = level.replace("LOG_LEVEL_", "") + if clean_level == "UNSPECIFIED": + return "UNKNOWN" return clean_level return level.upper() @@ -447,29 +488,36 @@ def _parse_log_level(level: str) -> str: def _clean_message(message: str) -> str: """Remove redundant log level prefix from message if present.""" prefixes = [ - 'ERROR:', 'WARNING:', 'INFO:', 'DEBUG:', 'TRACE:', - 'WARN:', 'FATAL:', 'UNKNOWN:', 'UNSPECIFIED:' + "ERROR:", + "WARNING:", + "INFO:", + "DEBUG:", + "TRACE:", + "WARN:", + "FATAL:", + "UNKNOWN:", + "UNSPECIFIED:", ] - + for prefix in prefixes: if message.startswith(prefix): - return message[len(prefix):].lstrip() - + return message[len(prefix) :].lstrip() + return message def _get_level_style(level: str) -> str: """Get Rich style for log level.""" level = level.upper() - if level in ['ERROR', 'FATAL']: + if level in ["ERROR", "FATAL"]: return "red bold" - elif level in ['WARN', 'WARNING']: + elif level in ["WARN", "WARNING"]: return "yellow bold" - elif level == 'INFO': + elif level == "INFO": return "blue" - elif level in ['DEBUG', 'TRACE']: + elif level in ["DEBUG", "TRACE"]: return "dim" - elif level in ['UNKNOWN', 'UNSPECIFIED']: + elif level in ["UNKNOWN", "UNSPECIFIED"]: return "magenta" else: return "white" diff --git a/src/mcp_agent/cli/cloud/commands/logger/utils.py b/src/mcp_agent/cli/cloud/commands/logger/utils.py index 6c9a8c2a7..3e05bad20 100644 --- a/src/mcp_agent/cli/cloud/commands/logger/utils.py +++ b/src/mcp_agent/cli/cloud/commands/logger/utils.py @@ -9,35 +9,37 @@ from mcp_agent.cli.core.constants import DEFAULT_API_BASE_URL -def parse_app_identifier(identifier: str) -> Tuple[Optional[str], Optional[str], Optional[str]]: +def parse_app_identifier( + identifier: str, +) -> Tuple[Optional[str], 
Optional[str], Optional[str]]: """Parse app identifier to extract app ID, config ID, and server URL.""" - + # Check if it's a URL - if identifier.startswith(('http://', 'https://')): + if identifier.startswith(("http://", "https://")): return None, None, identifier - + # Check if it's an MCPAppConfig ID (starts with apcnf_) - if identifier.startswith('apcnf_'): + if identifier.startswith("apcnf_"): return None, identifier, None - + # Check if it's an MCPApp ID (starts with app_) - if identifier.startswith('app_'): + if identifier.startswith("app_"): return identifier, None, None - + # If no specific prefix, assume it's an app ID for backward compatibility return identifier, None, None async def resolve_server_url( app_id: Optional[str], - config_id: Optional[str], + config_id: Optional[str], credentials: UserCredentials, ) -> str: """Resolve server URL from app ID or configuration ID.""" - + if not app_id and not config_id: raise CLIError("Either app_id or config_id must be provided") - + # Determine the endpoint and payload based on identifier type if app_id: endpoint = "/mcp_app/get_app" @@ -57,38 +59,42 @@ async def resolve_server_url( no_url_msg = f"No server URL found for app configuration '{config_id}'" offline_msg = f"App configuration '{config_id}' server is offline" api_error_msg = "Failed to get app configuration" - + api_base = DEFAULT_API_BASE_URL headers = { "Authorization": f"Bearer {credentials.api_key}", "Content-Type": "application/json", } - + try: async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(f"{api_base}{endpoint}", json=payload, headers=headers) - + response = await client.post( + f"{api_base}{endpoint}", json=payload, headers=headers + ) + if response.status_code == 404: raise CLIError(not_found_msg) elif response.status_code != 200: - raise CLIError(f"{api_error_msg}: {response.status_code} {response.text}") - + raise CLIError( + f"{api_error_msg}: {response.status_code} {response.text}" + ) + data = response.json() resource_info = data.get(response_key, {}) server_info = resource_info.get("appServerInfo") - + if not server_info: raise CLIError(not_deployed_msg) - + server_url = server_info.get("serverUrl") if not server_url: raise CLIError(no_url_msg) - + status = server_info.get("status", "APP_SERVER_STATUS_UNSPECIFIED") if status == "APP_SERVER_STATUS_OFFLINE": raise CLIError(offline_msg) - + return server_url - + except httpx.RequestError as e: - raise CLIError(f"Failed to connect to API: {e}") \ No newline at end of file + raise CLIError(f"Failed to connect to API: {e}") diff --git a/src/mcp_agent/cli/cloud/main.py b/src/mcp_agent/cli/cloud/main.py index 7c5fdbf94..392bec91d 100644 --- a/src/mcp_agent/cli/cloud/main.py +++ b/src/mcp_agent/cli/cloud/main.py @@ -13,7 +13,13 @@ from rich.panel import Panel from typer.core import TyperGroup -from mcp_agent.cli.cloud.commands import configure_app, deploy_config, login, logout, whoami +from mcp_agent.cli.cloud.commands import ( + configure_app, + deploy_config, + login, + logout, + whoami, +) from mcp_agent.cli.cloud.commands.logger import tail_logs from mcp_agent.cli.cloud.commands.app import ( delete_app, @@ -173,7 +179,9 @@ def invoke(self, ctx): # Add sub-typers to cloud app_cmd_cloud.add_typer(app_cmd_cloud_auth, name="auth", help="Authentication commands") -app_cmd_cloud.add_typer(app_cmd_cloud_logger, name="logger", help="Logging and observability") +app_cmd_cloud.add_typer( + app_cmd_cloud_logger, name="logger", help="Logging and observability" +) # Register cloud 
commands app.add_typer(app_cmd_cloud, name="cloud", help="Cloud operations and management") # Top-level auth commands that map to cloud auth commands diff --git a/src/mcp_agent/core/context.py b/src/mcp_agent/core/context.py index 462e62653..d449c938a 100644 --- a/src/mcp_agent/core/context.py +++ b/src/mcp_agent/core/context.py @@ -89,6 +89,10 @@ class Context(BaseModel): # Token counting and cost tracking token_counter: Optional[TokenCounter] = None + # Dynamic gateway configuration (per-run overrides via Temporal memo) + gateway_url: str | None = None + gateway_token: str | None = None + model_config = ConfigDict( extra="allow", arbitrary_types_allowed=True, # Tell Pydantic to defer type evaluation diff --git a/src/mcp_agent/executor/temporal/__init__.py b/src/mcp_agent/executor/temporal/__init__.py index 34984508e..3840d4944 100644 --- a/src/mcp_agent/executor/temporal/__init__.py +++ b/src/mcp_agent/executor/temporal/__init__.py @@ -32,10 +32,13 @@ from mcp_agent.config import TemporalSettings from mcp_agent.executor.executor import Executor, ExecutorConfig, R + from mcp_agent.executor.temporal.workflow_signal import TemporalSignalHandler from mcp_agent.executor.workflow_signal import SignalHandler from mcp_agent.logging.logger import get_logger from mcp_agent.utils.common import unwrap +from mcp_agent.executor.temporal.interceptor import ContextPropagationInterceptor +from mcp_agent.executor.temporal.system_activities import SystemActivities if TYPE_CHECKING: from mcp_agent.app import MCPApp @@ -119,7 +122,7 @@ async def run_task(task: Callable[..., R] | Coroutine[Any, Any, R]) -> R: return await task(*args, **kwargs) else: # Check if we're in a Temporal workflow context - if workflow._Runtime.current(): + if workflow.in_workflow(): wrapped_task = functools.partial(task, *args, **kwargs) result = wrapped_task() else: @@ -196,7 +199,7 @@ async def execute( """Execute multiple tasks (activities) in parallel.""" # Must be called from within a workflow - if not workflow._Runtime.current(): + if not workflow.in_workflow(): raise RuntimeError( "TemporalExecutor.execute must be called from within a workflow" ) @@ -214,7 +217,7 @@ async def execute_many( """Execute multiple tasks (activities) in parallel.""" # Must be called from within a workflow - if not workflow._Runtime.current(): + if not workflow.in_workflow(): raise RuntimeError( "TemporalExecutor.execute must be called from within a workflow" ) @@ -232,7 +235,7 @@ async def execute_streaming( *args, **kwargs, ) -> AsyncIterator[R | BaseException]: - if not workflow._Runtime.current(): + if not workflow.in_workflow(): raise RuntimeError( "TemporalExecutor.execute_streaming must be called from within a workflow" ) @@ -263,9 +266,9 @@ async def ensure_client(self): api_key=self.config.api_key, tls=self.config.tls, data_converter=pydantic_data_converter, - interceptors=[TracingInterceptor()] + interceptors=[TracingInterceptor(), ContextPropagationInterceptor()] if self.context.tracing_enabled - else [], + else [ContextPropagationInterceptor()], rpc_metadata=self.config.rpc_metadata or {}, ) @@ -278,6 +281,7 @@ async def start_workflow( wait_for_result: bool = False, workflow_id: str | None = None, task_queue: str | None = None, + workflow_memo: Dict[str, Any] | None = None, **kwargs: Any, ) -> WorkflowHandle: """ @@ -369,6 +373,7 @@ async def start_workflow( task_queue=task_queue, id_reuse_policy=id_reuse_policy, rpc_metadata=self.config.rpc_metadata or {}, + memo=workflow_memo or {}, ) else: handle: WorkflowHandle = await 
self.client.start_workflow( @@ -377,6 +382,7 @@ async def start_workflow( task_queue=task_queue, id_reuse_policy=id_reuse_policy, rpc_metadata=self.config.rpc_metadata or {}, + memo=workflow_memo or {}, ) # Wait for the result if requested @@ -497,6 +503,15 @@ async def create_temporal_worker_for_app(app: "MCPApp"): # Collect activities from the global registry activity_registry = running_app.context.task_registry + # Register system activities (logging, human input proxy, generic relays) + system_activities = SystemActivities(context=running_app.context) + app.workflow_task(name="mcp_forward_log")(system_activities.forward_log) + app.workflow_task(name="mcp_request_user_input")( + system_activities.request_user_input + ) + app.workflow_task(name="mcp_relay_notify")(system_activities.relay_notify) + app.workflow_task(name="mcp_relay_request")(system_activities.relay_request) + for name in activity_registry.list_activities(): activities.append(activity_registry.get_activity(name)) @@ -508,6 +523,7 @@ async def create_temporal_worker_for_app(app: "MCPApp"): task_queue=running_app.executor.config.task_queue, activities=activities, workflows=workflows, + interceptors=[ContextPropagationInterceptor()], ) try: diff --git a/src/mcp_agent/executor/temporal/interceptor.py b/src/mcp_agent/executor/temporal/interceptor.py new file mode 100644 index 000000000..c680fdaae --- /dev/null +++ b/src/mcp_agent/executor/temporal/interceptor.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import Any, Mapping, Protocol, Type + +import temporalio.activity +import temporalio.api.common.v1 +import temporalio.client +import temporalio.converter +import temporalio.worker +import temporalio.workflow +from mcp_agent.logging.logger import get_logger +from mcp_agent.executor.temporal.temporal_context import ( + EXECUTION_ID_KEY, + get_execution_id, + set_execution_id, +) + + +class _InputWithHeaders(Protocol): + headers: Mapping[str, temporalio.api.common.v1.Payload] + + +logger = get_logger(__name__) + + +def set_header_from_context( + input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter +) -> None: + execution_id_val = get_execution_id() + + if execution_id_val: + input.headers = { + **input.headers, + EXECUTION_ID_KEY: payload_converter.to_payload(execution_id_val), + } + + +@contextmanager +def context_from_header( + input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter +): + prev_exec_id = get_execution_id() + execution_id_payload = input.headers.get(EXECUTION_ID_KEY) + execution_id_from_header = ( + payload_converter.from_payload(execution_id_payload, str) + if execution_id_payload + else None + ) + set_execution_id(execution_id_from_header if execution_id_from_header else None) + + try: + yield + finally: + set_execution_id(prev_exec_id) + + +class ContextPropagationInterceptor( + temporalio.client.Interceptor, temporalio.worker.Interceptor +): + """Interceptor that propagates a value through client, workflow and activity calls. 
+ + This interceptor implements methods `temporalio.client.Interceptor` and `temporalio.worker.Interceptor` so that + + (1) an execution ID key is taken from context by the client code and sent in a header field with outbound requests + (2) workflows take this value from their task input, set it in context, and propagate it into the header field of + their outbound calls + (3) activities similarly take the value from their task input and set it in context so that it's available for their + outbound calls + """ + + def __init__( + self, + payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter, + ) -> None: + self._payload_converter = payload_converter + + def intercept_client( + self, next: temporalio.client.OutboundInterceptor + ) -> temporalio.client.OutboundInterceptor: + return _ContextPropagationClientOutboundInterceptor( + next, self._payload_converter + ) + + def intercept_activity( + self, next: temporalio.worker.ActivityInboundInterceptor + ) -> temporalio.worker.ActivityInboundInterceptor: + return _ContextPropagationActivityInboundInterceptor(next) + + def workflow_interceptor_class( + self, input: temporalio.worker.WorkflowInterceptorClassInput + ) -> Type[_ContextPropagationWorkflowInboundInterceptor]: + return _ContextPropagationWorkflowInboundInterceptor + + +class _ContextPropagationClientOutboundInterceptor( + temporalio.client.OutboundInterceptor +): + def __init__( + self, + next: temporalio.client.OutboundInterceptor, + payload_converter: temporalio.converter.PayloadConverter, + ) -> None: + super().__init__(next) + self._payload_converter = payload_converter + + async def start_workflow( + self, input: temporalio.client.StartWorkflowInput + ) -> temporalio.client.WorkflowHandle[Any, Any]: + set_header_from_context(input, self._payload_converter) + return await super().start_workflow(input) + + async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> Any: + set_header_from_context(input, self._payload_converter) + return await super().query_workflow(input) + + async def signal_workflow( + self, input: temporalio.client.SignalWorkflowInput + ) -> None: + set_header_from_context(input, self._payload_converter) + await super().signal_workflow(input) + + async def start_workflow_update( + self, input: temporalio.client.StartWorkflowUpdateInput + ) -> temporalio.client.WorkflowUpdateHandle[Any]: + set_header_from_context(input, self._payload_converter) + return await self.next.start_workflow_update(input) + + +class _ContextPropagationActivityInboundInterceptor( + temporalio.worker.ActivityInboundInterceptor +): + async def execute_activity( + self, input: temporalio.worker.ExecuteActivityInput + ) -> Any: + with context_from_header(input, temporalio.activity.payload_converter()): + return await self.next.execute_activity(input) + + +class _ContextPropagationWorkflowInboundInterceptor( + temporalio.worker.WorkflowInboundInterceptor +): + def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None: + self.next.init(_ContextPropagationWorkflowOutboundInterceptor(outbound)) + + async def execute_workflow( + self, input: temporalio.worker.ExecuteWorkflowInput + ) -> Any: + with context_from_header(input, temporalio.workflow.payload_converter()): + return await self.next.execute_workflow(input) + + async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: + with context_from_header(input, temporalio.workflow.payload_converter()): + return await 
self.next.handle_signal(input) + + async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: + with context_from_header(input, temporalio.workflow.payload_converter()): + return await self.next.handle_query(input) + + def handle_update_validator( + self, input: temporalio.worker.HandleUpdateInput + ) -> None: + with context_from_header(input, temporalio.workflow.payload_converter()): + self.next.handle_update_validator(input) + + async def handle_update_handler( + self, input: temporalio.worker.HandleUpdateInput + ) -> Any: + with context_from_header(input, temporalio.workflow.payload_converter()): + return await self.next.handle_update_handler(input) + + +class _ContextPropagationWorkflowOutboundInterceptor( + temporalio.worker.WorkflowOutboundInterceptor +): + async def signal_child_workflow( + self, input: temporalio.worker.SignalChildWorkflowInput + ) -> None: + set_header_from_context(input, temporalio.workflow.payload_converter()) + return await self.next.signal_child_workflow(input) + + async def signal_external_workflow( + self, input: temporalio.worker.SignalExternalWorkflowInput + ) -> None: + set_header_from_context(input, temporalio.workflow.payload_converter()) + return await self.next.signal_external_workflow(input) + + def start_activity( + self, input: temporalio.worker.StartActivityInput + ) -> temporalio.workflow.ActivityHandle: + set_header_from_context(input, temporalio.workflow.payload_converter()) + return self.next.start_activity(input) + + async def start_child_workflow( + self, input: temporalio.worker.StartChildWorkflowInput + ) -> temporalio.workflow.ChildWorkflowHandle: + set_header_from_context(input, temporalio.workflow.payload_converter()) + return await self.next.start_child_workflow(input) + + def start_local_activity( + self, input: temporalio.worker.StartLocalActivityInput + ) -> temporalio.workflow.ActivityHandle: + set_header_from_context(input, temporalio.workflow.payload_converter()) + return self.next.start_local_activity(input) diff --git a/src/mcp_agent/executor/temporal/session_proxy.py b/src/mcp_agent/executor/temporal/session_proxy.py new file mode 100644 index 000000000..b0fa213a4 --- /dev/null +++ b/src/mcp_agent/executor/temporal/session_proxy.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Type + +import anyio +import mcp.types as types +from anyio.streams.memory import ( + MemoryObjectReceiveStream, + MemoryObjectSendStream, +) +from temporalio import workflow as _twf + +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import ServerMessageMetadata + +from mcp_agent.core.context import Context +from mcp_agent.executor.temporal.system_activities import SystemActivities +from mcp_agent.executor.temporal.temporal_context import get_execution_id + + +class SessionProxy(ServerSession): + """ + SessionProxy acts like an MCP `ServerSession` for code running under the + Temporal engine. It forwards server->client messages through the MCPApp + gateway so that logs, notifications, and requests reach the original + upstream MCP client. + + Behavior: + - Inside a Temporal workflow (deterministic scope), all network I/O is + performed via registered Temporal activities. + - Outside a workflow (e.g., inside an activity or plain asyncio code), + calls are executed directly using the SystemActivities helpers. 
+ + This keeps workflow logic deterministic while remaining a drop-in proxy + for the common ServerSession methods used by the agent runtime. + """ + + def __init__(self, *, executor, context: Context) -> None: + # Create inert in-memory streams to satisfy base constructor. We do not + # use these streams; all communication is proxied via HTTP gateway. + send_read, recv_read = anyio.create_memory_object_stream(0) + send_write, recv_write = anyio.create_memory_object_stream(0) + + init_opts = InitializationOptions( + server_name="mcp_agent_proxy", + server_version="0.0.0", + capabilities=types.ServerCapabilities(), + instructions=None, + ) + # Initialize base class in stateless mode to skip handshake state + super().__init__( + recv_read, # type: ignore[arg-type] + send_write, # type: ignore[arg-type] + init_opts, + stateless=True, + ) + + # Keep references so streams aren't GC'd + self._dummy_streams: tuple[ + MemoryObjectSendStream[Any], + MemoryObjectReceiveStream[Any], + MemoryObjectSendStream[Any], + MemoryObjectReceiveStream[Any], + ] = (send_read, recv_read, send_write, recv_write) + + self._executor = executor + self._context = context + # Local helper used when we're not inside a workflow runtime + self._system_activities = SystemActivities(context) + # Provide a low-level RPC facade similar to real ServerSession + self.rpc = _RPC(self) + + # ---------------------- + # Generic passthroughs + # ---------------------- + async def notify(self, method: str, params: Dict[str, Any] | None = None) -> bool: + """Send a server->client notification via the gateway. + + Returns True on best-effort success. + """ + exec_id = get_execution_id() + if not exec_id: + return False + + if _in_workflow_runtime(): + try: + act = self._context.task_registry.get_activity("mcp_relay_notify") + await self._executor.execute(act, exec_id, method, params or {}) + return True + except Exception: + return False + # Non-workflow (activity/asyncio) + return bool( + await self._system_activities.relay_notify(exec_id, method, params or {}) + ) + + async def request( + self, method: str, params: Dict[str, Any] | None = None + ) -> Dict[str, Any]: + """Send a server->client request and return the client's response. + The result is a plain JSON-serializable dict. 
+ """ + exec_id = get_execution_id() + if not exec_id: + return {"error": "missing_execution_id"} + + if _in_workflow_runtime(): + act = self._context.task_registry.get_activity("mcp_relay_request") + return await self._executor.execute(act, exec_id, method, params or {}) + return await self._system_activities.relay_request( + exec_id, method, params or {} + ) + + async def send_notification( + self, + notification: types.ServerNotification, + related_request_id: types.RequestId | None = None, + ) -> None: + root = notification.root + params: Dict[str, Any] | None = None + try: + if getattr(root, "params", None) is not None: + params = root.params.model_dump(by_alias=True, mode="json") # type: ignore[attr-defined] + else: + params = {} + except Exception: + params = {} + # Best-effort pass-through of related_request_id when provided + if related_request_id is not None: + params = dict(params or {}) + params["related_request_id"] = related_request_id + await self.notify(root.method, params) # type: ignore[attr-defined] + + async def send_request( + self, + request: types.ServerRequest, + result_type: Type[Any], + metadata: ServerMessageMetadata | None = None, + ) -> Any: + root = request.root + params: Dict[str, Any] | None = None + try: + if getattr(root, "params", None) is not None: + params = root.params.model_dump(by_alias=True, mode="json") # type: ignore[attr-defined] + else: + params = {} + except Exception: + params = {} + # Note: metadata (e.g., related_request_id) is handled server-side where applicable + payload = await self.request(root.method, params) # type: ignore[attr-defined] + # Attempt to validate into the requested result type + try: + return result_type.model_validate(payload) # type: ignore[attr-defined] + except Exception: + return payload + + async def send_log_message( + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, + ) -> None: + """Best-effort log forwarding to the client's UI.""" + # Prefer activity-based forwarding inside workflow for determinism + exec_id = get_execution_id() + if _in_workflow_runtime() and exec_id: + try: + act = self._context.task_registry.get_activity("mcp_forward_log") + namespace = ( + (data or {}).get("namespace") + if isinstance(data, dict) + else (logger or "mcp_agent") + ) + message = (data or {}).get("message") if isinstance(data, dict) else "" + await self._executor.execute( + act, + exec_id, + str(level), + namespace or (logger or "mcp_agent"), + message or "", + (data or {}), + ) + return + except Exception: + # Fall back to notify path below + pass + + params: Dict[str, Any] = {"level": str(level), "data": data, "logger": logger} + if related_request_id is not None: + params["related_request_id"] = related_request_id + await self.notify("notifications/message", params) + + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + related_request_id: str | None = None, + ) -> None: + params: Dict[str, Any] = { + "progressToken": progress_token, + "progress": progress, + } + if total is not None: + params["total"] = total + if message is not None: + params["message"] = message + if related_request_id is not None: + params["related_request_id"] = related_request_id + await self.notify("notifications/progress", params) + + async def send_resource_updated(self, uri: types.AnyUrl) -> None: + await self.notify("notifications/resources/updated", {"uri": 
str(uri)}) + + async def send_resource_list_changed(self) -> None: + await self.notify("notifications/resources/list_changed", {}) + + async def send_tool_list_changed(self) -> None: + await self.notify("notifications/tools/list_changed", {}) + + async def send_prompt_list_changed(self) -> None: + await self.notify("notifications/prompts/list_changed", {}) + + async def send_ping(self) -> types.EmptyResult: + result = await self.request("ping", {}) + return types.EmptyResult.model_validate(result) + + async def list_roots(self) -> types.ListRootsResult: + result = await self.request("roots/list", {}) + return types.ListRootsResult.model_validate(result) + + async def create_message( + self, + messages: List[types.SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: types.IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: List[str] | None = None, + metadata: Dict[str, Any] | None = None, + model_preferences: types.ModelPreferences | None = None, + related_request_id: types.RequestId | None = None, + ) -> types.CreateMessageResult: + params: Dict[str, Any] = { + "messages": [m.model_dump(by_alias=True, mode="json") for m in messages], + "maxTokens": max_tokens, + } + if system_prompt is not None: + params["systemPrompt"] = system_prompt + if include_context is not None: + params["includeContext"] = include_context + if temperature is not None: + params["temperature"] = temperature + if stop_sequences is not None: + params["stopSequences"] = stop_sequences + if metadata is not None: + params["metadata"] = metadata + if model_preferences is not None: + params["modelPreferences"] = model_preferences.model_dump( + by_alias=True, mode="json" + ) + if related_request_id is not None: + # Threading ID through JSON-RPC metadata is handled by gateway; include for completeness + params["related_request_id"] = related_request_id + + result = await self.request("sampling/createMessage", params) + return types.CreateMessageResult.model_validate(result) + + async def elicit( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + params: Dict[str, Any] = { + "message": message, + "requestedSchema": requestedSchema, + } + if related_request_id is not None: + params["related_request_id"] = related_request_id + result = await self.request("elicitation/create", params) + return types.ElicitResult.model_validate(result) + + +def _in_workflow_runtime() -> bool: + """Return True if currently executing inside a Temporal workflow sandbox.""" + try: + return _twf.in_workflow() + except Exception: + return False + + +class _RPC: + """Lightweight facade to mimic the low-level RPC interface on sessions.""" + + def __init__(self, proxy: SessionProxy) -> None: + self._proxy = proxy + + async def notify(self, method: str, params: Dict[str, Any] | None = None) -> None: + await self._proxy.notify(method, params or {}) + + async def request( + self, method: str, params: Dict[str, Any] | None = None + ) -> Dict[str, Any]: + return await self._proxy.request(method, params or {}) diff --git a/src/mcp_agent/executor/temporal/system_activities.py b/src/mcp_agent/executor/temporal/system_activities.py new file mode 100644 index 000000000..b215ece79 --- /dev/null +++ b/src/mcp_agent/executor/temporal/system_activities.py @@ -0,0 +1,96 @@ +from typing import Any, Dict + +from temporalio import activity + +from mcp_agent.mcp.client_proxy import ( + 
log_via_proxy, + ask_via_proxy, + notify_via_proxy, + request_via_proxy, +) +from mcp_agent.core.context_dependent import ContextDependent + + +class SystemActivities(ContextDependent): + """Activities used by Temporal workflows to interact with the MCPApp gateway.""" + + @activity.defn(name="mcp_forward_log") + async def forward_log( + self, + execution_id: str, + level: str, + namespace: str, + message: str, + data: Dict[str, Any] | None = None, + ) -> bool: + registry = self.context.server_registry + gateway_url = getattr(self.context, "gateway_url", None) + gateway_token = getattr(self.context, "gateway_token", None) + return await log_via_proxy( + registry, + execution_id=execution_id, + level=level, + namespace=namespace, + message=message, + data=data or {}, + gateway_url=gateway_url, + gateway_token=gateway_token, + ) + + @activity.defn(name="mcp_request_user_input") + async def request_user_input( + self, + session_id: str, + workflow_id: str, + execution_id: str, + prompt: str, + signal_name: str = "human_input", + ) -> Dict[str, Any]: + # Reuse proxy ask API; returns {result} or {error} + registry = self.context.server_registry + gateway_url = getattr(self.context, "gateway_url", None) + gateway_token = getattr(self.context, "gateway_token", None) + return await ask_via_proxy( + registry, + execution_id=execution_id, + prompt=prompt, + metadata={ + "session_id": session_id, + "workflow_id": workflow_id, + "signal_name": signal_name, + }, + gateway_url=gateway_url, + gateway_token=gateway_token, + ) + + @activity.defn(name="mcp_relay_notify") + async def relay_notify( + self, execution_id: str, method: str, params: Dict[str, Any] | None = None + ) -> bool: + registry = self.context.server_registry + gateway_url = getattr(self.context, "gateway_url", None) + gateway_token = getattr(self.context, "gateway_token", None) + return await notify_via_proxy( + registry, + execution_id=execution_id, + method=method, + params=params or {}, + gateway_url=gateway_url, + gateway_token=gateway_token, + ) + + @activity.defn(name="mcp_relay_request") + async def relay_request( + self, execution_id: str, method: str, params: Dict[str, Any] | None = None + ) -> Dict[str, Any]: + registry = self.context.server_registry + gateway_url = getattr(self.context, "gateway_url", None) + gateway_token = getattr(self.context, "gateway_token", None) + return await request_via_proxy( + registry, + execution_id=execution_id, + method=method, + params=params or {}, + gateway_url=gateway_url, + gateway_token=gateway_token, + ) diff --git a/src/mcp_agent/executor/temporal/temporal_context.py b/src/mcp_agent/executor/temporal/temporal_context.py new file mode 100644 index 000000000..fa1cbf49b --- /dev/null +++ b/src/mcp_agent/executor/temporal/temporal_context.py @@ -0,0 +1,49 @@ +from typing import Optional + +EXECUTION_ID_KEY = "__execution_id" + +# Fallback global for non-Temporal contexts. This is best-effort only and +# used when neither workflow nor activity runtime is available. +_EXECUTION_ID: Optional[str] = None + + +def set_execution_id(execution_id: Optional[str]) -> None: + global _EXECUTION_ID + _EXECUTION_ID = execution_id + + +def get_execution_id() -> Optional[str]: + """Return the current Temporal run identifier to use for gateway routing. 
+ + Priority: + - If inside a Temporal workflow, return workflow.info().run_id + - Else if inside a Temporal activity, return activity.info().workflow_run_id + - Else fall back to the global (best-effort) + """ + # Try workflow runtime first + try: + from temporalio import workflow # type: ignore + + try: + if workflow.in_workflow(): + return workflow.info().run_id + except Exception: + pass + except Exception: + pass + + # Then try activity runtime + try: + from temporalio import activity # type: ignore + + try: + info = activity.info() + if info is not None and getattr(info, "workflow_run_id", None): + return info.workflow_run_id + except Exception: + pass + except Exception: + pass + + # Fallback to module-global (primarily for non-Temporal contexts) + return _EXECUTION_ID diff --git a/src/mcp_agent/executor/temporal/workflow_signal.py b/src/mcp_agent/executor/temporal/workflow_signal.py index be4fb6801..30f6f4835 100644 --- a/src/mcp_agent/executor/temporal/workflow_signal.py +++ b/src/mcp_agent/executor/temporal/workflow_signal.py @@ -91,7 +91,7 @@ async def wait_for_signal( TimeoutError: If timeout is reached ValueError: If no value exists for the signal after waiting """ - if not workflow._Runtime.current(): + if not workflow.in_workflow(): raise RuntimeError("wait_for_signal must be called from within a workflow") # Get the mailbox safely from ContextVar @@ -156,7 +156,7 @@ async def signal(self, signal: Signal[SignalValueT]) -> None: # Validate the signal (already checks workflow_id is not None) self.validate_signal(signal) - if workflow._Runtime.current() is not None: + if workflow.in_workflow(): workflow_info = workflow.info() if ( signal.workflow_id == workflow_info.workflow_id diff --git a/src/mcp_agent/executor/workflow.py b/src/mcp_agent/executor/workflow.py index 05bae32a6..7e0eed92d 100644 --- a/src/mcp_agent/executor/workflow.py +++ b/src/mcp_agent/executor/workflow.py @@ -16,7 +16,10 @@ from pydantic import BaseModel, ConfigDict, Field from mcp_agent.core.context_dependent import ContextDependent -from mcp_agent.executor.workflow_signal import Signal, SignalMailbox +from mcp_agent.executor.workflow_signal import ( + Signal, + SignalMailbox, +) from mcp_agent.logging.logger import get_logger if TYPE_CHECKING: @@ -215,6 +218,7 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": # Using __mcp_agent_ prefix to avoid conflicts with user parameters provided_workflow_id = kwargs.pop("__mcp_agent_workflow_id", None) provided_task_queue = kwargs.pop("__mcp_agent_task_queue", None) + workflow_memo = kwargs.pop("__mcp_agent_workflow_memo", None) self.update_status("scheduled") @@ -232,6 +236,7 @@ async def run_async(self, *args, **kwargs) -> "WorkflowExecution": *args, workflow_id=provided_workflow_id, task_queue=provided_task_queue, + workflow_memo=workflow_memo, **kwargs, ) self._workflow_id = handle.id @@ -734,6 +739,66 @@ async def initialize(self): "Signal handler not attached: Temporal support unavailable" ) + # Read memo (if any) and set gateway overrides on context for activities + try: + from temporalio import workflow as _twf + + # Preferred API: direct memo mapping from Temporal runtime + memo_map = None + try: + memo_map = _twf.memo() + except Exception: + # Fallback to info().memo if available + try: + _info = _twf.info() + memo_map = getattr(_info, "memo", None) + except Exception: + memo_map = None + + if isinstance(memo_map, dict): + gateway_url = memo_map.get("gateway_url") + gateway_token = memo_map.get("gateway_token") + + self._logger.debug( + 
f"Proxy parameters: gateway_url={gateway_url}, gateway_token={gateway_token}" + ) + + if gateway_url: + try: + self.context.gateway_url = gateway_url + except Exception: + pass + if gateway_token: + try: + self.context.gateway_token = gateway_token + except Exception: + pass + except Exception: + # Safe to ignore if called outside workflow sandbox or memo unavailable + pass + + # Expose a virtual upstream session (passthrough) bound to this run via activities + # This lets any code use context.upstream_session like a real session. + try: + from mcp_agent.executor.temporal.session_proxy import SessionProxy + + upstream_session = getattr(self.context, "upstream_session", None) + + if upstream_session is None: + self.context.upstream_session = SessionProxy( + executor=self.executor, + context=self.context, + ) + + app = self.context.app + if app: + # Ensure the app's logger is bound to the current context with upstream_session + if app._logger and hasattr(app._logger, "_bound_context"): + app._logger._bound_context = self.context + except Exception: + # Non-fatal if context is immutable early; will be set after run_id assignment in run_async + pass + self._initialized = True self.state.updated_at = datetime.now(timezone.utc).timestamp() diff --git a/src/mcp_agent/logging/listeners.py b/src/mcp_agent/logging/listeners.py index 0abc4329f..051baf075 100644 --- a/src/mcp_agent/logging/listeners.py +++ b/src/mcp_agent/logging/listeners.py @@ -249,7 +249,7 @@ async def handle_matched_event(self, event: Event) -> None: ) if upstream_session is None: - # No upstream_session available, event cannot be forwarded + # No upstream_session available; silently skip return # Map our EventType to MCP LoggingLevel; fold progress -> info diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py index e69327d06..194c0cc45 100644 --- a/src/mcp_agent/logging/logger.py +++ b/src/mcp_agent/logging/logger.py @@ -8,6 +8,7 @@ """ import asyncio +from datetime import timedelta import threading import time @@ -15,6 +16,7 @@ from contextlib import asynccontextmanager, contextmanager + from mcp_agent.logging.events import ( Event, EventContext, @@ -72,18 +74,130 @@ def _emit_event(self, event: Event): asyncio.create_task(self.event_bus.emit(event)) else: # If no loop is running, run it until the emit completes + # Detect Temporal workflow runtime without hard dependency + # If inside Temporal workflow sandbox, avoid run_until_complete and use workflow-safe forwarding + in_temporal_workflow = False try: - loop.run_until_complete(self.event_bus.emit(event)) - except NotImplementedError: - # Handle Temporal workflow environment where run_until_complete() is not implemented - # In Temporal, we can't block on async operations, so we'll need to avoid this - # Simply log to stdout/stderr as a fallback - import sys - - print( - f"[{event.type}] {event.namespace}: {event.message}", - file=sys.stderr, - ) + from temporalio import workflow as _wf # type: ignore + + try: + in_temporal_workflow = bool(_wf.in_workflow()) + except Exception: + in_temporal_workflow = False + except Exception: + in_temporal_workflow = False + + if in_temporal_workflow: + # Prefer forwarding via the upstream session proxy using a workflow task, if available. 
diff --git a/src/mcp_agent/logging/logger.py b/src/mcp_agent/logging/logger.py
index e69327d06..194c0cc45 100644
--- a/src/mcp_agent/logging/logger.py
+++ b/src/mcp_agent/logging/logger.py
@@ -8,6 +8,7 @@
 import asyncio
+from datetime import timedelta
 import threading
 import time
@@ -15,6 +16,7 @@
 from contextlib import asynccontextmanager, contextmanager

+
 from mcp_agent.logging.events import (
     Event,
     EventContext,
@@ -72,18 +74,130 @@ def _emit_event(self, event: Event):
             asyncio.create_task(self.event_bus.emit(event))
         else:
             # If no loop is running, run it until the emit completes
+            # Detect the Temporal workflow runtime without a hard dependency.
+            # Inside the Temporal workflow sandbox, avoid run_until_complete and use workflow-safe forwarding.
+            in_temporal_workflow = False
             try:
-                loop.run_until_complete(self.event_bus.emit(event))
-            except NotImplementedError:
-                # Handle Temporal workflow environment where run_until_complete() is not implemented
-                # In Temporal, we can't block on async operations, so we'll need to avoid this
-                # Simply log to stdout/stderr as a fallback
-                import sys
-
-                print(
-                    f"[{event.type}] {event.namespace}: {event.message}",
-                    file=sys.stderr,
-                )
+                from temporalio import workflow as _wf  # type: ignore
+
+                try:
+                    in_temporal_workflow = bool(_wf.in_workflow())
+                except Exception:
+                    in_temporal_workflow = False
+            except Exception:
+                in_temporal_workflow = False
+
+            if in_temporal_workflow:
+                # Prefer forwarding via the upstream session proxy using a workflow task, if available.
+                try:
+                    from mcp_agent.executor.temporal.temporal_context import (
+                        get_execution_id as _get_exec_id,
+                    )
+
+                    upstream = getattr(event, "upstream_session", None)
+                    if (
+                        upstream is None
+                        and getattr(self, "_bound_context", None) is not None
+                    ):
+                        try:
+                            upstream = getattr(
+                                self._bound_context, "upstream_session", None
+                            )
+                        except Exception:
+                            upstream = None
+
+                    # Construct payload
+                    async def _forward_via_proxy():
+                        # If we have an upstream session, use it first
+                        if upstream is not None:
+                            try:
+                                level_map = {
+                                    "debug": "debug",
+                                    "info": "info",
+                                    "warning": "warning",
+                                    "error": "error",
+                                    "progress": "info",
+                                }
+                                level = level_map.get(event.type, "info")
+                                logger_name = (
+                                    event.namespace
+                                    if not event.name
+                                    else f"{event.namespace}.{event.name}"
+                                )
+                                data = {
+                                    "message": event.message,
+                                    "namespace": event.namespace,
+                                    "name": event.name,
+                                    "timestamp": event.timestamp.isoformat(),
+                                }
+                                if event.data:
+                                    data["data"] = event.data
+                                if event.trace_id or event.span_id:
+                                    data["trace"] = {
+                                        "trace_id": event.trace_id,
+                                        "span_id": event.span_id,
+                                    }
+                                if event.context is not None:
+                                    try:
+                                        data["context"] = event.context.dict()
+                                    except Exception:
+                                        pass
+
+                                await upstream.send_log_message(  # type: ignore[attr-defined]
+                                    level=level, data=data, logger=logger_name
+                                )
+                                return
+                            except Exception:
+                                pass
+
+                        # Fallback: use the activity gateway directly if an execution_id is available
+                        try:
+                            exec_id = _get_exec_id()
+                            if exec_id:
+                                level = {
+                                    "debug": "debug",
+                                    "info": "info",
+                                    "warning": "warning",
+                                    "error": "error",
+                                    "progress": "info",
+                                }.get(event.type, "info")
+                                ns = event.namespace
+                                msg = event.message
+                                data = event.data or {}
+                                # Call by activity name to align with worker registration.
+                                # The Temporal SDK accepts multiple arguments only via args=[...].
+                                await _wf.execute_activity(
+                                    "mcp_forward_log",
+                                    args=[exec_id, level, ns, msg, data],
+                                    schedule_to_close_timeout=timedelta(seconds=5),
+                                )
+                                return
+                        except Exception:
+                            pass
+
+                        # If all else fails, fall back to stderr transport
+                        self.event_bus.emit_with_stderr_transport(event)
+
+                    try:
+                        # asyncio.create_task is the workflow-safe way to spawn a task in the sandbox
+                        asyncio.create_task(_forward_via_proxy())
+                        return
+                    except Exception:
+                        # Could not create the task, fall through to stderr transport
+                        pass
+                except Exception:
+                    # If the Temporal workflow module is unavailable or any error occurs, fall through
+                    pass
+
+                # As a last resort, log to stdout/stderr as a fallback
+                self.event_bus.emit_with_stderr_transport(event)
+            else:
+                try:
+                    loop.run_until_complete(self.event_bus.emit(event))
+                except NotImplementedError:
+                    # run_until_complete is unavailable here; fall back to the stderr transport
+                    self.event_bus.emit_with_stderr_transport(event)

     def event(
         self,
@@ -107,6 +221,7 @@ def event(
         # can forward reliably, regardless of the current task context.
         # 1) Prefer logger-bound app context (set at creation or refreshed by caller)
         extra_event_fields: Dict[str, Any] = {}
+
         try:
             upstream = (
                 getattr(self._bound_context, "upstream_session", None)
@@ -117,6 +232,19 @@ def event(
                 extra_event_fields["upstream_session"] = upstream
         except Exception:
             pass
+        # Fallback to the default bound context if the logger wasn't explicitly bound
+        if (
+            "upstream_session" not in extra_event_fields
+            and _default_bound_context is not None
+        ):
+            try:
+                upstream = getattr(_default_bound_context, "upstream_session", None)
+                if upstream is not None:
+                    extra_event_fields["upstream_session"] = upstream
+            except Exception:
+                pass
+
+        # Do not use global context fallbacks beyond this; they are unsafe under concurrency.
# No further fallbacks; upstream forwarding must be enabled by passing # a bound context when creating the logger or by server code attaching @@ -382,6 +510,7 @@ async def managed(cls, **config_kwargs): _logger_lock = threading.Lock() _loggers: Dict[str, Logger] = {} +_default_bound_context: Any | None = None def get_logger(namespace: str, session_id: str | None = None, context=None) -> Logger: @@ -401,15 +530,20 @@ def get_logger(namespace: str, session_id: str | None = None, context=None) -> L with _logger_lock: existing = _loggers.get(namespace) if existing is None: - logger = Logger(namespace, session_id, bound_context=context) + bound_ctx = context if context is not None else _default_bound_context + logger = Logger(namespace, session_id, bound_ctx) _loggers[namespace] = logger return logger + # Update session_id/bound context if caller provides them if session_id is not None: existing.session_id = session_id if context is not None: - try: - existing._bound_context = context - except Exception: - pass + existing._bound_context = context + return existing + + +def set_default_bound_context(ctx: Any | None) -> None: + global _default_bound_context + _default_bound_context = ctx diff --git a/src/mcp_agent/logging/transport.py b/src/mcp_agent/logging/transport.py index 2bf78a968..c795c47a8 100644 --- a/src/mcp_agent/logging/transport.py +++ b/src/mcp_agent/logging/transport.py @@ -8,6 +8,7 @@ import json import uuid import datetime +import sys from abc import ABC, abstractmethod from typing import Dict, List, Protocol from pathlib import Path @@ -432,6 +433,22 @@ async def emit(self, event: Event): # Then queue for listeners await self._queue.put(event) + def emit_with_stderr_transport(self, event: Event): + print( + f"[{event.type}] {event.namespace}: {event.message}", + file=sys.stderr, + ) + + # Initialize queue and start processing if needed + if not hasattr(self, "_queue"): + self.init_queue() + # Auto-start the event processing task if not running + if not self._running: + self._running = True + self._task = asyncio.create_task(self._process_events()) + + self._queue.put_nowait(event) + async def _send_to_transport(self, event: Event): """Send event to transport with error handling.""" try: diff --git a/src/mcp_agent/mcp/client_proxy.py b/src/mcp_agent/mcp/client_proxy.py new file mode 100644 index 000000000..af7d8f34b --- /dev/null +++ b/src/mcp_agent/mcp/client_proxy.py @@ -0,0 +1,178 @@ +from typing import Any, Dict, Optional + +import os +import httpx + +from mcp_agent.mcp.mcp_server_registry import ServerRegistry +from urllib.parse import quote + + +def _resolve_gateway_url( + server_registry: Optional[ServerRegistry] = None, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, +) -> str: + # Highest precedence: explicit override + if gateway_url: + return gateway_url.rstrip("/") + + # Next: environment variable + env_url = os.environ.get("MCP_GATEWAY_URL") + if env_url: + return env_url.rstrip("/") + + # Next: a registry entry (if provided) + if server_registry and server_name: + cfg = server_registry.get_server_config(server_name) + if cfg and getattr(cfg, "url", None): + return cfg.url.rstrip("/") + + # Fallback: default local server + return "http://127.0.0.1:8000" + + +async def log_via_proxy( + server_registry: Optional[ServerRegistry], + execution_id: str, + level: str, + namespace: str, + message: str, + data: Dict[str, Any] | None = None, + *, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, + gateway_token: Optional[str] = 
None, +) -> bool: + base = _resolve_gateway_url(server_registry, server_name, gateway_url) + url = f"{base}/internal/workflows/log" + headers: Dict[str, str] = {} + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) + try: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, + json={ + "execution_id": execution_id, + "level": level, + "namespace": namespace, + "message": message, + "data": data or {}, + }, + headers=headers, + ) + except httpx.RequestError: + return False + if r.status_code >= 400: + return False + try: + resp = r.json() if r.content else {"ok": True} + except ValueError: + resp = {"ok": True} + return bool(resp.get("ok", True)) + + +async def ask_via_proxy( + server_registry: Optional[ServerRegistry], + execution_id: str, + prompt: str, + metadata: Dict[str, Any] | None = None, + *, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, +) -> Dict[str, Any]: + base = _resolve_gateway_url(server_registry, server_name, gateway_url) + url = f"{base}/internal/human/prompts" + headers: Dict[str, str] = {} + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) + try: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, + json={ + "execution_id": execution_id, + "prompt": {"text": prompt}, + "metadata": metadata or {}, + }, + headers=headers, + ) + except httpx.RequestError: + return {"error": "request_failed"} + if r.status_code >= 400: + return {"error": r.text} + try: + return r.json() if r.content else {"error": "invalid_response"} + except ValueError: + return {"error": "invalid_response"} + + +async def notify_via_proxy( + server_registry: Optional[ServerRegistry], + execution_id: str, + method: str, + params: Dict[str, Any] | None = None, + *, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, +) -> bool: + base = _resolve_gateway_url(server_registry, server_name, gateway_url) + url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/notify" + headers: Dict[str, str] = {} + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + timeout = float(os.environ.get("MCP_GATEWAY_TIMEOUT", "10")) + + try: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, json={"method": method, "params": params or {}}, headers=headers + ) + except httpx.RequestError: + return False + if r.status_code >= 400: + return False + try: + resp = r.json() if r.content else {"ok": True} + except ValueError: + resp = {"ok": True} + return bool(resp.get("ok", True)) + + +async def request_via_proxy( + server_registry: Optional[ServerRegistry], + execution_id: str, + method: str, + params: Dict[str, Any] | None = None, + *, + server_name: Optional[str] = None, + gateway_url: Optional[str] = None, + gateway_token: Optional[str] = None, +) -> Dict[str, Any]: + base = _resolve_gateway_url(server_registry, server_name, gateway_url) + url = f"{base}/internal/session/by-run/{quote(execution_id, safe='')}/request" + headers: Dict[str, str] = {} + tok = gateway_token or os.environ.get("MCP_GATEWAY_TOKEN") + if tok: + headers["X-MCP-Gateway-Token"] = tok + timeout = 
float(os.environ.get("MCP_GATEWAY_TIMEOUT", "20")) + try: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post( + url, json={"method": method, "params": params or {}}, headers=headers + ) + except httpx.RequestError: + return {"error": "request_failed"} + if r.status_code >= 400: + return {"error": r.text} + try: + return r.json() if r.content else {"error": "invalid_response"} + except ValueError: + return {"error": "invalid_response"} diff --git a/src/mcp_agent/mcp/mcp_aggregator.py b/src/mcp_agent/mcp/mcp_aggregator.py index ca36c0e0e..f57e67b48 100644 --- a/src/mcp_agent/mcp/mcp_aggregator.py +++ b/src/mcp_agent/mcp/mcp_aggregator.py @@ -345,26 +345,37 @@ async def load_server(self, server_name: str): # Process tools async with self._tool_map_lock: self._server_to_tool_map[server_name] = [] - + # Get server configuration to check for tool filtering allowed_tools = None disabled_tool_count = 0 - if (self.context is None or self.context.server_registry is None - or not hasattr(self.context.server_registry, "get_server_config")): - logger.warning(f"No config found for server '{server_name}', no tool filter will be applied...") + if ( + self.context is None + or self.context.server_registry is None + or not hasattr(self.context.server_registry, "get_server_config") + ): + logger.warning( + f"No config found for server '{server_name}', no tool filter will be applied..." + ) else: - allowed_tools = self.context.server_registry.get_server_config(server_name).allowed_tools + allowed_tools = self.context.server_registry.get_server_config( + server_name + ).allowed_tools if allowed_tools is not None and len(allowed_tools) == 0: - logger.warning(f"Allowed tool list is explicitly empty for server '{server_name}'") - + logger.warning( + f"Allowed tool list is explicitly empty for server '{server_name}'" + ) + for tool in tools: # Apply tool filtering if configured - O(1) lookup with set if allowed_tools is not None and tool.name not in allowed_tools: - logger.debug(f"Filtering out tool '{tool.name}' from server '{server_name}' (not in allowed_tools)") + logger.debug( + f"Filtering out tool '{tool.name}' from server '{server_name}' (not in allowed_tools)" + ) disabled_tool_count += 1 continue - + namespaced_tool_name = f"{server_name}{SEP}{tool.name}" namespaced_tool = NamespacedTool( tool=tool, diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index e6d0fc60a..9d0a38433 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -7,8 +7,13 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING +import os +import secrets +import asyncio from mcp.server.fastmcp import Context as MCPContext, FastMCP +from starlette.requests import Request +from starlette.responses import JSONResponse from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import Tool as FastTool @@ -20,6 +25,7 @@ WorkflowRegistry, InMemoryWorkflowRegistry, ) + from mcp_agent.logging.logger import get_logger from mcp_agent.logging.logger import LoggingConfig from mcp_agent.mcp.mcp_server_registry import ServerRegistry @@ -28,6 +34,33 @@ from mcp_agent.core.context import Context logger = get_logger(__name__) +# Simple in-memory registry mapping workflow execution_id -> upstream session handle. +# Allows external workers (e.g., Temporal) to relay logs/prompts through MCPApp. 
+_RUN_SESSION_REGISTRY: Dict[str, Any] = {} +_RUN_EXECUTION_ID_REGISTRY: Dict[str, str] = {} +_RUN_SESSION_LOCK = asyncio.Lock() +_PENDING_PROMPTS: Dict[str, Dict[str, Any]] = {} +_PENDING_PROMPTS_LOCK = asyncio.Lock() +_IDEMPOTENCY_KEYS_SEEN: Dict[str, Set[str]] = {} +_IDEMPOTENCY_KEYS_LOCK = asyncio.Lock() + + +async def _register_session(run_id: str, execution_id: str, session: Any) -> None: + async with _RUN_SESSION_LOCK: + _RUN_SESSION_REGISTRY[execution_id] = session + _RUN_EXECUTION_ID_REGISTRY[run_id] = execution_id + + +async def _unregister_session(run_id: str) -> None: + async with _RUN_SESSION_LOCK: + execution_id = _RUN_EXECUTION_ID_REGISTRY.pop(run_id, None) + if execution_id: + _RUN_SESSION_REGISTRY.pop(execution_id, None) + + +async def _get_session(execution_id: str) -> Any | None: + async with _RUN_SESSION_LOCK: + return _RUN_SESSION_REGISTRY.get(execution_id) class ServerContext(ContextDependent): @@ -287,6 +320,259 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Don't clean up the MCPApp here - let the caller handle that pass + # Helper: install internal HTTP routes (not MCP tools) + def _install_internal_routes(mcp_server: FastMCP) -> None: + @mcp_server.custom_route( + "/internal/session/by-run/{execution_id}/notify", + methods=["POST"], + include_in_schema=False, + ) + async def _relay_notify(request: Request): + body = await request.json() + execution_id = request.path_params.get("execution_id") + method = body.get("method") + params = body.get("params") or {} + + # Optional shared-secret auth + gw_token = os.environ.get("MCP_GATEWAY_TOKEN") + if gw_token and not secrets.compare_digest( + request.headers.get("X-MCP-Gateway-Token", ""), gw_token + ): + return JSONResponse( + {"ok": False, "error": "unauthorized"}, status_code=401 + ) + + # Optional idempotency handling + idempotency_key = params.get("idempotency_key") + if idempotency_key: + async with _IDEMPOTENCY_KEYS_LOCK: + seen = _IDEMPOTENCY_KEYS_SEEN.setdefault(execution_id or "", set()) + if idempotency_key in seen: + return JSONResponse({"ok": True, "idempotent": True}) + seen.add(idempotency_key) + + session = await _get_session(execution_id) + if not session: + return JSONResponse( + {"ok": False, "error": "session_not_available"}, status_code=503 + ) + + try: + # Special-case the common logging notification helper + if method == "notifications/message": + level = str(params.get("level", "info")) + data = params.get("data") + logger_name = params.get("logger") + related_request_id = params.get("related_request_id") + await session.send_log_message( # type: ignore[attr-defined] + level=level, # type: ignore[arg-type] + data=data, + logger=logger_name, + related_request_id=related_request_id, + ) + elif method == "notifications/progress": + # Minimal support for progress relay + progress_token = params.get("progressToken") + progress = params.get("progress") + total = params.get("total") + message = params.get("message") + await session.send_progress_notification( # type: ignore[attr-defined] + progress_token=progress_token, + progress=progress, + total=total, + message=message, + ) + else: + # Generic passthrough using low-level RPC if available + rpc = getattr(session, "rpc", None) + if rpc and hasattr(rpc, "notify"): + await rpc.notify(method, params) + else: + return JSONResponse( + {"ok": False, "error": f"unsupported method: {method}"}, + status_code=400, + ) + + return JSONResponse({"ok": True}) + except Exception as e: + return JSONResponse({"ok": False, "error": 
str(e)}, status_code=500) + + @mcp_server.custom_route( + "/internal/session/by-run/{execution_id}/request", + methods=["POST"], + include_in_schema=False, + ) + async def _relay_request(request: Request): + from mcp.types import ( + CreateMessageRequest, + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequest, + ElicitRequestParams, + ElicitResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + EmptyResult, + ServerRequest, + ) + + body = await request.json() + execution_id = request.path_params.get("execution_id") + method = body.get("method") + params = body.get("params") or {} + + session = await _get_session(execution_id) + if not session: + return JSONResponse({"error": "session_not_available"}, status_code=503) + + try: + # Prefer generic request passthrough if available + rpc = getattr(session, "rpc", None) + if rpc and hasattr(rpc, "request"): + result = await rpc.request(method, params) + return JSONResponse(result) + # Fallback: Map a small set of supported server->client requests + if method == "sampling/createMessage": + req = ServerRequest( + CreateMessageRequest( + method="sampling/createMessage", + params=CreateMessageRequestParams(**params), + ) + ) + result = await session.send_request( # type: ignore[attr-defined] + request=req, + result_type=CreateMessageResult, + ) + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + elif method == "elicitation/create": + req = ServerRequest( + ElicitRequest( + method="elicitation/create", + params=ElicitRequestParams(**params), + ) + ) + result = await session.send_request( # type: ignore[attr-defined] + request=req, + result_type=ElicitResult, + ) + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + elif method == "roots/list": + req = ServerRequest(ListRootsRequest(method="roots/list")) + result = await session.send_request( # type: ignore[attr-defined] + request=req, + result_type=ListRootsResult, + ) + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + elif method == "ping": + req = ServerRequest(PingRequest(method="ping")) + result = await session.send_request( # type: ignore[attr-defined] + request=req, + result_type=EmptyResult, + ) + return JSONResponse( + result.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + else: + return JSONResponse( + {"error": f"unsupported method: {method}"}, status_code=400 + ) + except Exception as e: + return JSONResponse({"error": str(e)}, status_code=500) + + @mcp_server.custom_route( + "/internal/workflows/log", methods=["POST"], include_in_schema=False + ) + async def _internal_workflows_log(request: Request): + body = await request.json() + execution_id = body.get("execution_id") + level = str(body.get("level", "info")).lower() + namespace = body.get("namespace") or "mcp_agent" + message = body.get("message") or "" + data = body.get("data") or {} + + # Optional shared-secret auth + gw_token = os.environ.get("MCP_GATEWAY_TOKEN") + if gw_token and not secrets.compare_digest( + request.headers.get("X-MCP-Gateway-Token", ""), gw_token + ): + return JSONResponse( + {"ok": False, "error": "unauthorized"}, status_code=401 + ) + + session = await _get_session(execution_id) + if not session: + return JSONResponse( + {"ok": False, "error": "session_not_available"}, status_code=503 + ) + if level not in ("debug", "info", "warning", "error"): + level = "info" + try: + await session.send_log_message( + level=level, # type: ignore[arg-type] + 
data={ + "message": message, + "namespace": namespace, + "data": data, + }, + logger=namespace, + ) + return JSONResponse({"ok": True}) + except Exception as e: + return JSONResponse({"ok": False, "error": str(e)}, status_code=500) + + @mcp_server.custom_route( + "/internal/human/prompts", methods=["POST"], include_in_schema=False + ) + async def _internal_human_prompts(request: Request): + body = await request.json() + execution_id = body.get("execution_id") + prompt = body.get("prompt") or {} + metadata = body.get("metadata") or {} + + # Optional shared-secret auth + gw_token = os.environ.get("MCP_GATEWAY_TOKEN") + if gw_token and not secrets.compare_digest( + request.headers.get("X-MCP-Gateway-Token", ""), gw_token + ): + return JSONResponse({"error": "unauthorized"}, status_code=401) + + session = await _get_session(execution_id) + if not session: + return JSONResponse({"error": "session_not_available"}, status_code=503) + import uuid + + request_id = str(uuid.uuid4()) + payload = { + "kind": "human_input_request", + "request_id": request_id, + "prompt": prompt if isinstance(prompt, dict) else {"text": str(prompt)}, + "metadata": metadata, + } + try: + # Store pending prompt correlation for submit tool + async with _PENDING_PROMPTS_LOCK: + _PENDING_PROMPTS[request_id] = { + "workflow_id": metadata.get("workflow_id"), + "execution_id": execution_id, + "signal_name": metadata.get("signal_name", "human_input"), + "session_id": metadata.get("session_id"), + } + await session.send_log_message( + level="info", # type: ignore[arg-type] + data=payload, + logger="mcp_agent.human", + ) + return JSONResponse({"request_id": request_id}) + except Exception as e: + return JSONResponse({"error": str(e)}, status_code=500) + # Create or attach FastMCP server if app.mcp: # Using an externally provided FastMCP instance: attach app and context @@ -305,6 +591,11 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: create_workflow_tools(mcp, server_context) # Register function-declared tools (from @app.tool/@app.async_tool) create_declared_function_tools(mcp, server_context) + # Install internal HTTP routes + try: + _install_internal_routes(mcp) + except Exception: + pass else: mcp = FastMCP( name=app.name or "mcp_agent_server", @@ -318,6 +609,11 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Store the server on the app so it's discoverable and can be extended further app.mcp = mcp setattr(mcp, "_mcp_agent_app", app) + # Install internal HTTP routes + try: + _install_internal_routes(mcp) + except Exception: + pass # Register logging/setLevel handler so client can adjust verbosity dynamically # This enables MCP logging capability in InitializeResult.capabilities.logging @@ -347,6 +643,11 @@ def list_workflows(ctx: MCPContext) -> Dict[str, Dict[str, Any]]: Returns information about each workflow type including name, description, and parameters. This helps in making an informed decision about which workflow to run. """ + # Ensure upstream session is set for any logs emitted during this call + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass result: Dict[str, Dict[str, Any]] = {} workflows, _ = _resolve_workflows_and_context(ctx) workflows = workflows or {} @@ -391,6 +692,12 @@ async def list_workflow_runs(ctx: MCPContext) -> List[Dict[str, Any]]: Returns: A dictionary mapping workflow instance IDs to their detailed status information. 
""" + # Ensure upstream session is set for any logs emitted during this call + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass + server_context = getattr( ctx.request_context, "lifespan_context", None ) or _get_attached_server_context(ctx.fastmcp) @@ -423,6 +730,11 @@ async def run_workflow( A dict with workflow_id and run_id for the started workflow run, can be passed to workflows/get_status, workflows/resume, and workflows/cancel. """ + # Ensure upstream session is set before starting the workflow + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass return await _workflow_run(ctx, workflow_name, run_parameters, **kwargs) @mcp.tool(name="workflows-get_status") @@ -444,6 +756,11 @@ async def get_workflow_status( Returns: A dictionary with comprehensive information about the workflow status. """ + # Ensure upstream session is available for any status-related logs + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass return await _workflow_status(ctx, run_id=run_id, workflow_name=workflow_id) @mcp.tool(name="workflows-resume") @@ -471,6 +788,11 @@ async def resume_workflow( Returns: True if the workflow was resumed, False otherwise. """ + # Ensure upstream session is available for any status-related logs + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass server_context: ServerContext = ctx.request_context.lifespan_context workflow_registry = server_context.workflow_registry @@ -513,6 +835,11 @@ async def cancel_workflow( Returns: True if the workflow was cancelled, False otherwise. """ + # Ensure upstream session is available for any status-related logs + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass server_context: ServerContext = ctx.request_context.lifespan_context workflow_registry = server_context.workflow_registry @@ -528,6 +855,40 @@ async def cancel_workflow( else: logger.error(f"Failed to cancel workflow {workflow_name} with ID {run_id}") + # region Proxy tools for external runners (e.g., Temporal workers) + + # Removed MCP tools for internal proxying in favor of private HTTP routes + + @mcp.tool(name="human_input.submit") + async def human_input_submit(request_id: str, text: str) -> Dict[str, Any]: + """Client replies to a human_input_request; signal the Temporal workflow.""" + app_ref = _get_attached_app(mcp) + if app_ref is None or app_ref.context is None: + return {"ok": False, "error": "server not ready"} + async with _PENDING_PROMPTS_LOCK: + info = _PENDING_PROMPTS.pop(request_id, None) + if not info: + return {"ok": False, "error": "unknown request_id"} + try: + from mcp_agent.executor.temporal import TemporalExecutor + + executor = app_ref.context.executor + if not isinstance(executor, TemporalExecutor): + return {"ok": False, "error": "temporal executor not active"} + client = await executor.ensure_client() + handle = client.get_workflow_handle( + workflow_id=info.get("workflow_id"), run_id=info.get("run_id") + ) + await handle.signal( + info.get("signal_name", "human_input"), + {"request_id": request_id, "text": text}, + ) + return {"ok": True} + except Exception as e: + return {"ok": False, "error": str(e)} + + # endregion + # endregion return mcp @@ -910,6 +1271,7 @@ async def run( ctx: MCPContext, run_parameters: Dict[str, Any] | None = None, ) -> Dict[str, str]: + _set_upstream_from_request_ctx_if_available(ctx) return await _workflow_run(ctx, workflow_name, run_parameters) @mcp.tool( @@ -922,6 +1284,7 
             """,
     )
     async def get_status(ctx: MCPContext, run_id: str) -> Dict[str, Any]:
+        _set_upstream_from_request_ctx_if_available(ctx)
         return await _workflow_status(ctx, run_id=run_id, workflow_name=workflow_name)

@@ -976,6 +1339,9 @@ async def _workflow_run(
     run_parameters: Dict[str, Any] | None = None,
     **kwargs: Any,
 ) -> Dict[str, str]:
+    # Use the Temporal run_id as the routing key for gateway callbacks.
+    # It is not known until after the workflow starts; the mapping is registered post-start.
+
     # Resolve workflows and app context irrespective of startup mode
     # This now returns a context with upstream_session already set
     workflows_dict, app_context = _resolve_workflows_and_context(ctx)
@@ -1016,16 +1382,76 @@
     if task_queue:
         run_parameters["__mcp_agent_task_queue"] = task_queue

+    # Build a memo for Temporal runs if gateway info is available
+    workflow_memo = None
+    try:
+        # Prefer explicit kwargs, else infer from request headers/environment.
+        # FastMCP keeps the raw request under ctx.request_context.request if available.
+        gateway_url = kwargs.get("gateway_url")
+        gateway_token = kwargs.get("gateway_token")
+
+        if gateway_url is None:
+            try:
+                req = getattr(ctx.request_context, "request", None)
+                if req is not None:
+                    # Custom header if present
+                    h = req.headers
+                    gateway_url = h.get("X-MCP-Gateway-URL") or h.get("X-Forwarded-Url")
+                    # Best-effort reconstruction if only proto/host are provided
+                    # (X-Forwarded-Proto alone is not a URL, so it is only used here)
+                    if gateway_url is None:
+                        proto = h.get("X-Forwarded-Proto") or "http"
+                        host = h.get("X-Forwarded-Host") or h.get("Host")
+                        if host:
+                            gateway_url = f"{proto}://{host}"
+            except Exception:
+                pass
+
+        if gateway_token is None:
+            try:
+                req = getattr(ctx.request_context, "request", None)
+                if req is not None:
+                    gateway_token = req.headers.get("X-MCP-Gateway-Token")
+            except Exception:
+                pass
+
+        if gateway_url or gateway_token:
+            workflow_memo = {
+                "gateway_url": gateway_url,
+                "gateway_token": gateway_token,
+            }
+    except Exception:
+        workflow_memo = None
+
     # Run the workflow asynchronously and get its ID
-    execution = await workflow.run_async(**run_parameters)
+    execution = await workflow.run_async(
+        __mcp_agent_workflow_memo=workflow_memo,
+        **run_parameters,
+    )
+    execution_id = execution.run_id

     logger.info(
-        f"Workflow {workflow_name} started with workflow ID {execution.workflow_id} and run ID {execution.run_id}. Parameters: {run_parameters}"
+        f"Workflow {workflow_name} started execution {execution_id} for workflow ID {execution.workflow_id}, "
+        f"run ID {execution.run_id}. Parameters: {run_parameters}"
     )

+    # Register the upstream session for this run so external workers can proxy logs/prompts
+    try:
+        await _register_session(
+            run_id=execution.run_id,
+            execution_id=execution_id,
+            session=getattr(ctx, "session", None),
+        )
+    except Exception:
+        pass
+
     return {
         "workflow_id": execution.workflow_id,
         "run_id": execution.run_id,
+        "execution_id": execution_id,
     }

 except Exception as e:
@@ -1037,6 +1463,10 @@ async def _workflow_status(
     ctx: MCPContext, run_id: str, workflow_name: str | None = None
 ) -> Dict[str, Any]:
     # Ensure upstream session so status-related logs are forwarded
+    try:
+        _set_upstream_from_request_ctx_if_available(ctx)
+    except Exception:
+        pass
     workflow_registry: WorkflowRegistry | None = _resolve_workflow_registry(ctx)

     if not workflow_registry:
@@ -1049,6 +1479,17 @@ async def _workflow_status(
         run_id=run_id, workflow_id=workflow_id
     )

+    # Clean up the run registry on terminal states
+    try:
+        state = str(status.get("status", "")).lower()
+        if state in ("completed", "error", "cancelled"):
+            try:
+                await _unregister_session(run_id)
+            except Exception:
+                pass
+    except Exception:
+        pass
+
     return status
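A rough way to exercise the `/internal/workflows/log` route added above by hand (all values are placeholders; the route answers 503 unless the execution_id was registered by `_workflow_run`, and the token header is only enforced when `MCP_GATEWAY_TOKEN` is set on the server):

```python
import asyncio

import httpx


async def main() -> None:
    # Mirrors the payload that log_via_proxy sends to the gateway route.
    async with httpx.AsyncClient(timeout=10) as client:
        r = await client.post(
            "http://127.0.0.1:8000/internal/workflows/log",
            json={
                "execution_id": "run-123",  # must match a registered run
                "level": "info",
                "namespace": "example",
                "message": "manual smoke test",
                "data": {},
            },
            headers={"X-MCP-Gateway-Token": "dev-secret"},  # optional shared secret
        )
        print(r.status_code, r.json())


asyncio.run(main())
```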
Parameters: {run_parameters}" ) + # Register upstream session for this run so external workers can proxy logs/prompts + try: + await _register_session( + run_id=execution.run_id, + execution_id=execution_id, + session=getattr(ctx, "session", None), + ) + except Exception: + pass + return { "workflow_id": execution.workflow_id, "run_id": execution.run_id, + "execution_id": execution_id, } except Exception as e: @@ -1037,6 +1463,10 @@ async def _workflow_status( ctx: MCPContext, run_id: str, workflow_name: str | None = None ) -> Dict[str, Any]: # Ensure upstream session so status-related logs are forwarded + try: + _set_upstream_from_request_ctx_if_available(ctx) + except Exception: + pass workflow_registry: WorkflowRegistry | None = _resolve_workflow_registry(ctx) if not workflow_registry: @@ -1049,6 +1479,17 @@ async def _workflow_status( run_id=run_id, workflow_id=workflow_id ) + # Cleanup run registry on terminal states + try: + state = str(status.get("status", "")).lower() + if state in ("completed", "error", "cancelled"): + try: + await _unregister_session(run_id) + except Exception: + pass + except Exception: + pass + return status diff --git a/src/mcp_agent/tracing/token_counter.py b/src/mcp_agent/tracing/token_counter.py index ec794a746..4a1bb3728 100644 --- a/src/mcp_agent/tracing/token_counter.py +++ b/src/mcp_agent/tracing/token_counter.py @@ -827,7 +827,7 @@ async def record_usage( try: from temporalio import workflow as _twf # type: ignore - if _twf._Runtime.current(): # type: ignore[attr-defined] + if _twf.in_workflow(): if _twf.unsafe.is_replaying(): # type: ignore[attr-defined] return except Exception: diff --git a/src/mcp_agent/tracing/token_tracking_decorator.py b/src/mcp_agent/tracing/token_tracking_decorator.py index 491050e41..35833397e 100644 --- a/src/mcp_agent/tracing/token_tracking_decorator.py +++ b/src/mcp_agent/tracing/token_tracking_decorator.py @@ -38,7 +38,7 @@ async def wrapper(self, *args, **kwargs) -> T: try: from temporalio import workflow as _twf # type: ignore - if _twf._Runtime.current(): # type: ignore[attr-defined] + if _twf.in_workflow(): is_temporal_replay = _twf.unsafe.is_replaying() # type: ignore[attr-defined] except Exception: is_temporal_replay = False diff --git a/src/mcp_agent/workflows/deep_orchestrator/README.md b/src/mcp_agent/workflows/deep_orchestrator/README.md index 132e8d195..f765fd214 100644 --- a/src/mcp_agent/workflows/deep_orchestrator/README.md +++ b/src/mcp_agent/workflows/deep_orchestrator/README.md @@ -35,11 +35,11 @@ flowchart TB B --> C{Execute Tasks} C --> D[Extract Knowledge] D --> E{Objective Complete?} + E -->|Yes| G E -->|No| F{Check Policy} F -->|Replan| B F -->|Continue| C F -->|Stop| G[Synthesize Results] - E -->|Yes| G G --> H[Final Result] style B fill:#e1f5fe diff --git a/tests/executor/temporal/test_execution_id_and_interceptor.py b/tests/executor/temporal/test_execution_id_and_interceptor.py new file mode 100644 index 000000000..7aa5f5cb5 --- /dev/null +++ b/tests/executor/temporal/test_execution_id_and_interceptor.py @@ -0,0 +1,112 @@ +import pytest +from unittest.mock import patch + + +@pytest.mark.asyncio +@patch("temporalio.workflow.info") +@patch("temporalio.workflow.in_workflow", return_value=True) +def test_get_execution_id_in_workflow(_mock_in_wf, mock_info): + from mcp_agent.executor.temporal.temporal_context import get_execution_id + + mock_info.return_value.run_id = "run-123" + assert get_execution_id() == "run-123" + + +@pytest.mark.asyncio +@patch("temporalio.activity.info") +def 
test_get_execution_id_in_activity(mock_act_info): + from mcp_agent.executor.temporal.temporal_context import get_execution_id + + mock_act_info.return_value.workflow_run_id = "run-aaa" + assert get_execution_id() == "run-aaa" + + +def test_interceptor_restores_prev_value(): + from mcp_agent.executor.temporal.interceptor import context_from_header + from mcp_agent.executor.temporal.temporal_context import ( + EXECUTION_ID_KEY, + set_execution_id, + get_execution_id, + ) + import temporalio.converter + + payload_converter = temporalio.converter.default().payload_converter + + class Input: + headers = {} + + set_execution_id("prev") + input = Input() + # simulate header with new value + input.headers[EXECUTION_ID_KEY] = payload_converter.to_payload("new") + + assert get_execution_id() == "prev" + with context_from_header(input, payload_converter): + # inside scope we should get header value + assert get_execution_id() == "new" + # restored + assert get_execution_id() == "prev" + + +@pytest.mark.asyncio +async def test_http_proxy_helpers_happy_and_error_paths(monkeypatch): + from mcp_agent.mcp import client_proxy + + class Resp: + def __init__(self, status_code, json_data=None, text=""): + self.status_code = status_code + self._json = json_data or {} + self.text = text + self.content = b"x" if json_data is not None else b"" + + def json(self): + return self._json + + class Client: + def __init__(self, rcodes_iter): + self._rcodes = rcodes_iter + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, json=None, headers=None): + code, body = next(self._rcodes) + if body is None: + return Resp(code) + return Resp(code, body) + + # log_via_proxy ok, then error + rcodes = iter( + [ + (200, {"ok": True}), + (500, None), + (200, {"ok": True}), + (401, None), + (200, {"ok": True}), + (400, None), + ] + ) + + monkeypatch.setattr( + client_proxy.httpx, "AsyncClient", lambda timeout: Client(rcodes) + ) + + ok = await client_proxy.log_via_proxy(None, "run", "info", "ns", "msg") + assert ok is True + ok = await client_proxy.log_via_proxy(None, "run", "info", "ns", "msg") + assert ok is False + + # notify ok, then error + ok = await client_proxy.notify_via_proxy(None, "run", "m", {}) + assert ok is True + ok = await client_proxy.notify_via_proxy(None, "run", "m", {}) + assert ok is False + + # request ok, then error + res = await client_proxy.request_via_proxy(None, "run", "m", {}) + assert isinstance(res, dict) and res.get("ok", True) in (True,) + res = await client_proxy.request_via_proxy(None, "run", "m", {}) + assert isinstance(res, dict) and "error" in res diff --git a/tests/executor/temporal/test_signal_handler.py b/tests/executor/temporal/test_signal_handler.py index 2b898bc30..6459aea09 100644 --- a/tests/executor/temporal/test_signal_handler.py +++ b/tests/executor/temporal/test_signal_handler.py @@ -53,8 +53,8 @@ def test_attach_to_workflow(handler, mock_workflow): @pytest.mark.asyncio -@patch("temporalio.workflow._Runtime.current", return_value=MagicMock()) -async def test_wait_for_signal(mock_runtime, handler, mock_workflow): +@patch("temporalio.workflow.in_workflow", return_value=True) +async def test_wait_for_signal(_mock_in_wf, handler, mock_workflow): handler.attach_to_workflow(mock_workflow) # Patch the handler's ContextVar to point to the mock_workflow's mailbox handler._mailbox_ref.set(mock_workflow._signal_mailbox) @@ -66,7 +66,7 @@ async def test_wait_for_signal(mock_runtime, handler, mock_workflow): 
@pytest.mark.asyncio -@patch("temporalio.workflow._Runtime.current", return_value=None) +@patch("temporalio.workflow.in_workflow", return_value=False) @patch( "temporalio.workflow.get_external_workflow_handle", side_effect=__import__("temporalio.workflow").workflow._NotInWorkflowEventLoopError( @@ -74,7 +74,7 @@ async def test_wait_for_signal(mock_runtime, handler, mock_workflow): ), ) async def test_signal_outside_workflow( - mock_get_external, mock_runtime, handler, mock_executor + mock_get_external, _mock_in_wf, handler, mock_executor ): signal = Signal( name="test_signal", diff --git a/tests/executor/test_workflow.py b/tests/executor/test_workflow.py index 71df7c544..bcb11bff1 100644 --- a/tests/executor/test_workflow.py +++ b/tests/executor/test_workflow.py @@ -346,7 +346,10 @@ async def test_run_async_with_temporal_custom_params(self, mock_context): # Verify start_workflow was called with correct parameters workflow.executor.start_workflow.assert_called_once_with( - "TestWorkflow", workflow_id=custom_workflow_id, task_queue=custom_task_queue + "TestWorkflow", + workflow_id=custom_workflow_id, + task_queue=custom_task_queue, + workflow_memo=None, ) # Verify execution uses the handle's ID diff --git a/tests/mcp/test_mcp_aggregator.py b/tests/mcp/test_mcp_aggregator.py index d5ef11299..a592743fc 100644 --- a/tests/mcp/test_mcp_aggregator.py +++ b/tests/mcp/test_mcp_aggregator.py @@ -868,27 +868,28 @@ async def get_prompt(self, name, arguments=None): # ============================================================================= - class MockServerConfig: """Mock server configuration for testing""" + def __init__(self, allowed_tools=None): self.allowed_tools = allowed_tools class DummyContextWithServerRegistry: """Extended dummy context with server registry for tool filtering tests""" + def __init__(self, server_configs=None): self.tracer = None self.tracing_enabled = False self.server_configs = server_configs or {} - + class MockServerRegistry: def __init__(self, configs): self.configs = configs - + def get_server_config(self, server_name): return self.configs.get(server_name, MockServerConfig()) - + def start_server(self, server_name, client_session_factory=None): class DummyCtxMgr: async def __aenter__(self): @@ -896,13 +897,16 @@ class DummySession: async def initialize(self): class InitResult: capabilities = {"tools": True} + return InitResult() + return DummySession() - + async def __aexit__(self, exc_type, exc_val, exc_tb): pass + return DummyCtxMgr() - + self.server_registry = MockServerRegistry(self.server_configs) self._mcp_connection_manager_lock = asyncio.Lock() self._mcp_connection_manager_ref_count = 0 @@ -912,43 +916,59 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): async def test_tool_filtering_with_allowed_tools(): """Test that tools are filtered correctly when allowed_tools is configured""" # Setup server config with allowed tools - server_configs = { - "test_server": MockServerConfig(allowed_tools={"tool1", "tool3"}) - } + server_configs = {"test_server": MockServerConfig(allowed_tools={"tool1", "tool3"})} context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + # Mock tools that would be returned from server mock_tools = [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), # Should be included - Tool(name="tool2", description="Description for tool2", 
inputSchema={"type": "object"}), # Should be filtered out - Tool(name="tool3", description="Description for tool3", inputSchema={"type": "object"}), # Should be included - Tool(name="tool4", description="Description for tool4", inputSchema={"type": "object"}), # Should be filtered out + Tool( + name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), # Should be included + Tool( + name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), # Should be filtered out + Tool( + name="tool3", + description="Description for tool3", + inputSchema={"type": "object"}, + ), # Should be included + Tool( + name="tool4", + description="Description for tool4", + inputSchema={"type": "object"}, + ), # Should be filtered out ] - + # Mock _fetch_capabilities to return our test tools async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) # capabilities, tools, prompts, resources - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Verify only allowed tools were added server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 2 - + tool_names = [tool.tool.name for tool in server_tools] assert "tool1" in tool_names assert "tool3" in tool_names assert "tool2" not in tool_names assert "tool4" not in tool_names - + # Verify namespaced tools map assert "test_server_tool1" in aggregator._namespaced_tool_map assert "test_server_tool3" in aggregator._namespaced_tool_map @@ -960,34 +980,46 @@ async def mock_fetch_capabilities(server_name): async def test_tool_filtering_no_filtering_when_none(): """Test that all tools are included when allowed_tools is None""" # Setup server config with no filtering - server_configs = { - "test_server": MockServerConfig(allowed_tools=None) - } + server_configs = {"test_server": MockServerConfig(allowed_tools=None)} context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + mock_tools = [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), - Tool(name="tool2", description="Description for tool2", inputSchema={"type": "object"}), - Tool(name="tool3", description="Description for tool3", inputSchema={"type": "object"}), + Tool( + name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), + Tool( + name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), + Tool( + name="tool3", + description="Description for tool3", + inputSchema={"type": "object"}, + ), ] - + async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Verify all tools were added server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 3 - + tool_names = [tool.tool.name for tool in server_tools] assert "tool1" in tool_names assert "tool2" in tool_names @@ -998,33 +1030,41 @@ async def mock_fetch_capabilities(server_name): async def 
test_tool_filtering_empty_allowed_tools(): """Test behavior when allowed_tools is empty set (should filter out all tools)""" # Setup server config with empty allowed tools - server_configs = { - "test_server": MockServerConfig(allowed_tools=set()) - } + server_configs = {"test_server": MockServerConfig(allowed_tools=set())} context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + mock_tools = [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), - Tool(name="tool2", description="Description for tool2", inputSchema={"type": "object"}), + Tool( + name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), + Tool( + name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), ] - + async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Verify no tools were added server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 0 - + # Verify namespaced tools map is empty for this server assert "test_server_tool1" not in aggregator._namespaced_tool_map assert "test_server_tool2" not in aggregator._namespaced_tool_map @@ -1035,29 +1075,39 @@ async def test_tool_filtering_no_server_registry(): """Test fallback behavior when server registry is not available""" # Setup context without proper server registry context = DummyContext() # Original dummy context without server registry - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + mock_tools = [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), - Tool(name="tool2", description="Description for tool2", inputSchema={"type": "object"}), + Tool( + name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), + Tool( + name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), ] - + async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Should include all tools when no server registry is available server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 2 - + tool_names = [tool.tool.name for tool in server_tools] assert "tool1" in tool_names assert "tool2" in tool_names @@ -1073,40 +1123,70 @@ async def test_tool_filtering_multiple_servers(): "server3": MockServerConfig(allowed_tools=None), # No filtering } context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["server1", "server2", "server3"], connection_persistence=False, context=context, name="test_agent", ) - + # Different tools for each server server_tools = { "server1": [ - Tool(name="tool1", description="Description for tool1", inputSchema={"type": "object"}), - 
Tool(name="tool2", description="Description for tool2", inputSchema={"type": "object"}), - Tool(name="tool_extra", description="Description for tool_extra", inputSchema={"type": "object"}) + Tool( + name="tool1", + description="Description for tool1", + inputSchema={"type": "object"}, + ), + Tool( + name="tool2", + description="Description for tool2", + inputSchema={"type": "object"}, + ), + Tool( + name="tool_extra", + description="Description for tool_extra", + inputSchema={"type": "object"}, + ), ], "server2": [ - Tool(name="tool3", description="Description for tool3", inputSchema={"type": "object"}), - Tool(name="tool_filtered", description="Description for tool_filtered", inputSchema={"type": "object"}) + Tool( + name="tool3", + description="Description for tool3", + inputSchema={"type": "object"}, + ), + Tool( + name="tool_filtered", + description="Description for tool_filtered", + inputSchema={"type": "object"}, + ), ], "server3": [ - Tool(name="toolA", description="Description for toolA", inputSchema={"type": "object"}), - Tool(name="toolB", description="Description for toolB", inputSchema={"type": "object"}) + Tool( + name="toolA", + description="Description for toolA", + inputSchema={"type": "object"}, + ), + Tool( + name="toolB", + description="Description for toolB", + inputSchema={"type": "object"}, + ), ], } - + async def mock_fetch_capabilities(server_name): tools = server_tools.get(server_name, []) return (None, tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("server1") - await aggregator.load_server("server2") + await aggregator.load_server("server2") await aggregator.load_server("server3") - + # Check server1 filtering server1_tools = aggregator._server_to_tool_map.get("server1", []) assert len(server1_tools) == 2 @@ -1114,21 +1194,21 @@ async def mock_fetch_capabilities(server_name): assert "tool1" in server1_names assert "tool2" in server1_names assert "tool_extra" not in server1_names - + # Check server2 filtering server2_tools = aggregator._server_to_tool_map.get("server2", []) assert len(server2_tools) == 1 server2_names = [tool.tool.name for tool in server2_tools] assert "tool3" in server2_names assert "tool_filtered" not in server2_names - + # Check server3 (no filtering) server3_tools = aggregator._server_to_tool_map.get("server3", []) assert len(server3_tools) == 2 server3_names = [tool.tool.name for tool in server3_tools] assert "toolA" in server3_names assert "toolB" in server3_names - + # Check namespaced tools map assert "server1_tool1" in aggregator._namespaced_tool_map assert "server1_tool2" in aggregator._namespaced_tool_map @@ -1139,39 +1219,56 @@ async def mock_fetch_capabilities(server_name): assert "server3_toolB" in aggregator._namespaced_tool_map - -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_tool_filtering_edge_case_exact_match(): """Test that tool filtering requires exact name matches""" server_configs = { "test_server": MockServerConfig(allowed_tools={"tool", "tool_exact"}) } context = DummyContextWithServerRegistry(server_configs) - + aggregator = mcp_aggregator_mod.MCPAggregator( server_names=["test_server"], connection_persistence=False, context=context, name="test_agent", ) - + mock_tools = [ - Tool(name="tool", description="Description for tool", inputSchema={"type": "object"}), # Should be included (exact match) - Tool(name="tool_exact", 
description="Description for tool_exact", inputSchema={"type": "object"}), # Should be included (exact match) - Tool(name="tool_similar", description="Description for tool_similar", inputSchema={"type": "object"}), # Should be filtered (not exact match) - Tool(name="my_tool", description="Description for my_tool", inputSchema={"type": "object"}), # Should be filtered (not exact match) + Tool( + name="tool", + description="Description for tool", + inputSchema={"type": "object"}, + ), # Should be included (exact match) + Tool( + name="tool_exact", + description="Description for tool_exact", + inputSchema={"type": "object"}, + ), # Should be included (exact match) + Tool( + name="tool_similar", + description="Description for tool_similar", + inputSchema={"type": "object"}, + ), # Should be filtered (not exact match) + Tool( + name="my_tool", + description="Description for my_tool", + inputSchema={"type": "object"}, + ), # Should be filtered (not exact match) ] - + async def mock_fetch_capabilities(server_name): return (None, mock_tools, [], []) - - with patch.object(aggregator, '_fetch_capabilities', side_effect=mock_fetch_capabilities): + + with patch.object( + aggregator, "_fetch_capabilities", side_effect=mock_fetch_capabilities + ): await aggregator.load_server("test_server") - + # Verify only exact matches were included server_tools = aggregator._server_to_tool_map.get("test_server", []) assert len(server_tools) == 2 - + tool_names = [tool.tool.name for tool in server_tools] assert "tool" in tool_names assert "tool_exact" in tool_names diff --git a/uv.lock b/uv.lock index 218056688..09b0c7340 100644 --- a/uv.lock +++ b/uv.lock @@ -2040,7 +2040,7 @@ wheels = [ [[package]] name = "mcp-agent" -version = "0.1.15" +version = "0.1.17" source = { editable = "." } dependencies = [ { name = "aiohttp" },