diff --git a/src/mcp_agent/cli/cloud/commands/workflows/runs/main.py b/src/mcp_agent/cli/cloud/commands/workflows/runs/main.py
index 507fc30c8..3bb4a6744 100644
--- a/src/mcp_agent/cli/cloud/commands/workflows/runs/main.py
+++ b/src/mcp_agent/cli/cloud/commands/workflows/runs/main.py
@@ -71,7 +71,9 @@ async def _list_workflow_runs_async(
         workflows = workflows_data
         if status:
             status_filter = _get_status_filter(status)
-            workflows = [w for w in workflows if _matches_status(w, status_filter)]
+            workflows = [
+                w for w in workflows if _matches_status(w, status_filter)
+            ]
 
         if limit:
             workflows = workflows[:limit]
@@ -144,7 +146,7 @@ def _get_status_filter(status: str) -> str:
 
 def _matches_status(workflow: dict, status_filter: str) -> bool:
     """Check if workflow matches the status filter.
-    
+
     Note: We use string-based matching instead of protobuf enum values because
     the MCP tool response format returns status as strings, not enum objects.
     This approach is more flexible and doesn't require maintaining sync with
@@ -158,9 +160,7 @@ def _matches_status(workflow: dict, status_filter: str) -> bool:
 
 def _print_workflows_text(workflows, status_filter, server_id_or_url):
     """Print workflows in text format."""
-    console.print(
-        f"\n[bold blue]📊 Workflow Runs ({len(workflows)})[/bold blue]"
-    )
+    console.print(f"\n[bold blue]📊 Workflow Runs ({len(workflows)})[/bold blue]")
 
     if not workflows:
         print_info("No workflow runs found for this server.")
@@ -169,7 +169,7 @@ def _print_workflows_text(workflows, status_filter, server_id_or_url):
     for i, workflow in enumerate(workflows):
         if i > 0:
             console.print()
-        
+
         if isinstance(workflow, dict):
             workflow_id = workflow.get("workflow_id", "N/A")
             name = workflow.get("name", "N/A")
@@ -184,27 +184,28 @@ def _print_workflows_text(workflows, status_filter, server_id_or_url):
             run_id = getattr(workflow, "run_id", "N/A")
             created_at = getattr(workflow, "created_at", "N/A")
             principal_id = getattr(workflow, "principal_id", "N/A")
-        
+
         status_display = _get_status_display(execution_status)
-        
+
         if created_at and created_at != "N/A":
             if hasattr(created_at, "strftime"):
                 created_display = created_at.strftime("%Y-%m-%d %H:%M:%S")
             else:
                 try:
                     from datetime import datetime
+
                     dt = datetime.fromisoformat(str(created_at).replace("Z", "+00:00"))
                     created_display = dt.strftime("%Y-%m-%d %H:%M:%S")
                 except (ValueError, TypeError):
                     created_display = str(created_at)
         else:
             created_display = "N/A"
-        
+
         console.print(f"[bold cyan]{name or 'Unnamed'}[/bold cyan] {status_display}")
         console.print(f"  Workflow ID: {workflow_id}")
         console.print(f"  Run ID: {run_id}")
         console.print(f"  Created: {created_display}")
-        
+
         if principal_id and principal_id != "N/A":
             console.print(f"  Principal: {principal_id}")
 
@@ -228,13 +229,17 @@ def _workflow_to_dict(workflow):
     """Convert workflow dict to standardized dictionary format."""
     if isinstance(workflow, dict):
         return workflow
-    
+
     return {
         "workflow_id": getattr(workflow, "workflow_id", None),
        "run_id": getattr(workflow, "run_id", None),
        "name": getattr(workflow, "name", None),
-        "created_at": getattr(workflow, "created_at", None).isoformat() if getattr(workflow, "created_at", None) else None,
-        "execution_status": getattr(workflow, "execution_status", None).value if getattr(workflow, "execution_status", None) else None,
+        "created_at": getattr(workflow, "created_at", None).isoformat()
+        if getattr(workflow, "created_at", None)
+        else None,
+        "execution_status": getattr(workflow, "execution_status", None).value
+        if getattr(workflow, "execution_status", None)
+        else None,
     }
 
 
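The string-based matching described in the `_matches_status` docstring above boils down to a case-insensitive substring test against whatever status string the MCP tool response carries. A minimal sketch of the idea; the helper name and the status values shown are hypothetical, inferred from this diff rather than taken from the codebase:

```python
def matches_status(workflow: dict, status_filter: str) -> bool:
    # Case-insensitive substring match; works for both bare strings
    # ("running") and enum-style names ("WORKFLOW_EXECUTION_STATUS_RUNNING").
    status = str(workflow.get("execution_status", "")).lower()
    return status_filter in status

runs = [
    {"name": "a", "execution_status": "WORKFLOW_EXECUTION_STATUS_RUNNING"},
    {"name": "b", "execution_status": "completed"},
]
assert [w["name"] for w in runs if matches_status(w, "running")] == ["a"]
```
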
@@ -242,9 +247,9 @@ def _get_status_display(status):
     """Convert status to display string with emoji."""
     if not status:
         return "❓ Unknown"
-    
+
     status_str = str(status).lower()
-    
+
     if "running" in status_str:
         return "[green]🟢 Running[/green]"
     elif "completed" in status_str:
diff --git a/src/mcp_agent/cli/mcp_app/api_client.py b/src/mcp_agent/cli/mcp_app/api_client.py
index 48e289635..9631c2d3d 100644
--- a/src/mcp_agent/cli/mcp_app/api_client.py
+++ b/src/mcp_agent/cli/mcp_app/api_client.py
@@ -63,7 +63,6 @@ class CanDoActionsResponse(BaseModel):
     canDoActions: Optional[List[CanDoActionCheck]] = []
 
 
-
 APP_ID_PREFIX = "app_"
 APP_CONFIG_ID_PREFIX = "apcnf_"
 
@@ -467,7 +466,6 @@ async def list_app_configurations(
         response = await self.post("/mcp_app/list_app_configurations", payload)
         return ListAppConfigurationsResponse(**response.json())
 
-
     async def delete_app(self, app_id: str) -> str:
         """Delete an MCP App via the API.
 
diff --git a/src/mcp_agent/mcp/mcp_aggregator.py b/src/mcp_agent/mcp/mcp_aggregator.py
index f57e67b48..80a3e1f12 100644
--- a/src/mcp_agent/mcp/mcp_aggregator.py
+++ b/src/mcp_agent/mcp/mcp_aggregator.py
@@ -237,28 +237,19 @@ async def close(self):
                 and self.context._mcp_connection_manager
                 == self._persistent_connection_manager
             ):
-                # Add timeout protection for the disconnect operation
+                # Close via manager's thread-aware close()
                 try:
                     await asyncio.wait_for(
-                        self._persistent_connection_manager.disconnect_all(),
+                        self._persistent_connection_manager.close(),
                         timeout=5.0,
                     )
                 except asyncio.TimeoutError:
                     logger.warning(
-                        "Timeout during disconnect_all(), forcing shutdown"
+                        "Timeout during connection manager close(), forcing shutdown"
                     )
-
-                # Avoid calling __aexit__ directly across threads; mark inactive
-                try:
-                    if hasattr(
-                        self._persistent_connection_manager,
-                        "_tg_active",
-                    ):
-                        self._persistent_connection_manager._tg_active = False
-                        self._persistent_connection_manager._tg = None
                 except Exception as e:
                     logger.warning(
-                        f"Error during connection manager state cleanup: {e}"
+                        f"Error during connection manager close(): {e}"
                     )
 
             # Clean up the connection manager from the context
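The aggregator-side change keeps the old timeout discipline: every await on the connection manager is bounded, so a wedged `close()` cannot hang agent shutdown. The same pattern, extracted as a standalone helper (the helper name is illustrative, not part of the codebase):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def close_with_deadline(manager, timeout: float = 5.0) -> None:
    """Bound an async cleanup call so shutdown can always make progress."""
    try:
        await asyncio.wait_for(manager.close(), timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning("close() exceeded %.1fs; continuing shutdown", timeout)
```
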
diff --git a/src/mcp_agent/mcp/mcp_connection_manager.py b/src/mcp_agent/mcp/mcp_connection_manager.py
index cae9c9ea0..5009d4f63 100644
--- a/src/mcp_agent/mcp/mcp_connection_manager.py
+++ b/src/mcp_agent/mcp/mcp_connection_manager.py
@@ -3,6 +3,7 @@
 """
 
 from datetime import timedelta
+import asyncio
 import threading
 from typing import (
     AsyncGenerator,
@@ -237,50 +238,175 @@ def __init__(
         self._tg_active = False
         # Track the thread this manager was created in to ensure TaskGroup cleanup
         self._thread_id = threading.get_ident()
+        # Event loop where the TaskGroup lives
+        self._loop: asyncio.AbstractEventLoop | None = None
+        # Owner task + coordination events for safe TaskGroup lifecycle
+        self._tg_owner_task: asyncio.Task | None = None
+        self._owner_tg: TaskGroup | None = None
+        self._tg_ready_event: Event = Event()
+        self._tg_close_event: Event = Event()
+        self._tg_closed_event: Event = Event()
+        # Ensure a single close sequence at a time on the origin loop
+        self._close_lock = Lock()
+        # Serialize owner startup to avoid races across tasks
+        self._owner_start_lock = Lock()
 
     async def __aenter__(self):
-        # We create a task group to manage all server lifecycle tasks
-        tg = create_task_group()
-        # Enter the task group context
-        await tg.__aenter__()
-        self._tg_active = True
-        self._tg = tg
+        # Start the TaskGroup owner task and wait until ready
+        await self._start_owner()
+        # Record the loop and thread where the TaskGroup is running
+        try:
+            self._loop = asyncio.get_running_loop()
+        except RuntimeError:
+            self._loop = None
         return self
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         """Ensure clean shutdown of all connections before exiting."""
+        await self.close(exc_type, exc_val, exc_tb)
+        # Close the owner TaskGroup in the same task that entered it
+        if self._owner_tg is not None:
+            try:
+                await self._owner_tg.__aexit__(exc_type, exc_val, exc_tb)
+            except Exception as e:
+                logger.warning(
+                    f"MCPConnectionManager: Error during owner TaskGroup cleanup: {e}"
+                )
+            finally:
+                self._owner_tg = None
+
+    async def close(self, exc_type=None, exc_val=None, exc_tb=None):
+        """Close all connections and tear down the internal TaskGroup safely.
+
+        This is thread-aware: if called from a different thread than the one where the
+        TaskGroup was created, it will signal the owner task on the original loop to
+        perform cleanup and await completion without violating task affinity.
+        """
         try:
-            # First request all servers to shutdown
-            logger.debug("MCPConnectionManager: shutting down all server tasks...")
-            await self.disconnect_all()
+            current_thread = threading.get_ident()
+            if current_thread == self._thread_id:
+                # Same thread: perform shutdown inline with exclusive access
+                async with self._close_lock:
+                    logger.debug(
+                        "MCPConnectionManager: shutting down all server tasks..."
+                    )
+                    await self.disconnect_all()
+                    await anyio.sleep(0.5)
+                    if self._tg_active:
+                        self._tg_close_event.set()
+                        # Wait for owner to report TaskGroup closed with an anyio timeout
+                        try:
+                            with anyio.fail_after(5.0):
+                                await self._tg_closed_event.wait()
+                        except TimeoutError:
+                            logger.warning(
+                                "MCPConnectionManager: Timeout waiting for TaskGroup owner to close"
+                            )
+                # Do not attempt to close the owner TaskGroup here; __aexit__ will handle it
+            else:
+                # Different thread – run entire shutdown on the original loop to avoid cross-thread Event.set
+                if self._loop is not None:
 
-            # Add a small delay to allow for clean shutdown of subprocess transports, etc.
-            await anyio.sleep(0.5)
+                    async def _shutdown_and_close():
+                        logger.debug(
+                            "MCPConnectionManager: shutting down all server tasks (origin loop)..."
+                        )
+                        async with self._close_lock:
+                            await self.disconnect_all()
+                            await anyio.sleep(0.5)
+                            if self._tg_active:
+                                self._tg_close_event.set()
+                                await self._tg_closed_event.wait()
 
-            # Then close the task group if it's active and we're in the same thread
-            if self._tg_active and self._tg:
-                current_thread = threading.get_ident()
-                if current_thread == self._thread_id:
                     try:
-                        await self._tg.__aexit__(exc_type, exc_val, exc_tb)
+                        cfut = asyncio.run_coroutine_threadsafe(
+                            _shutdown_and_close(), self._loop
+                        )
+                        # Wait in a worker thread to avoid blocking non-asyncio contexts
+                        try:
+                            with anyio.fail_after(5.0):
+                                await anyio.to_thread.run_sync(cfut.result)
+                        except TimeoutError:
+                            logger.warning(
+                                "MCPConnectionManager: Timeout during cross-thread shutdown/close"
+                            )
+                            try:
+                                cfut.cancel()
+                            except Exception:
+                                pass
                     except Exception as e:
                         logger.warning(
-                            f"MCPConnectionManager: Error during task group cleanup: {e}"
+                            f"MCPConnectionManager: Error scheduling cross-thread shutdown: {e}"
                         )
                 else:
-                    # Different thread – cannot safely cleanup anyio TaskGroup
                     logger.warning(
-                        f"MCPConnectionManager: Task group cleanup called from different thread "
-                        f"(created in {self._thread_id}, called from {current_thread}). Skipping cleanup."
+                        "MCPConnectionManager: No event loop recorded for cleanup; skipping TaskGroup close"
                     )
-
-            # Always mark as inactive and drop reference
-            self._tg_active = False
-            self._tg = None
         except AttributeError:
             # Handle missing `_exceptions`
             pass
         except Exception as e:
             logger.warning(f"MCPConnectionManager: Error during shutdown: {e}")
 
+    async def _start_owner(self):
+        """Start the TaskGroup owner task if not already running (task-safe)."""
+        async with self._owner_start_lock:
+            # If an owner is active or TaskGroup is already active, nothing to do
+            if (self._tg_owner_task and not self._tg_owner_task.done()) or (
+                self._tg_active and self._tg is not None
+            ):
+                return
+            # If previous owner exists but is done (possibly with error), log and restart
+            if self._tg_owner_task and self._tg_owner_task.done():
+                try:
+                    exc = self._tg_owner_task.exception()
+                    if exc:
+                        logger.warning(
+                            f"MCPConnectionManager: restarting owner after error: {exc}"
+                        )
+                except Exception:
+                    logger.warning(
+                        "MCPConnectionManager: restarting owner after unknown state"
+                    )
+            # Reset coordination events (safe here since no active owner/TG)
+            self._tg_ready_event = Event()
+            self._tg_close_event = Event()
+            self._tg_closed_event = Event()
+            # Record loop and thread
+            try:
+                self._loop = asyncio.get_running_loop()
+            except RuntimeError:
+                self._loop = None
+            self._thread_id = threading.get_ident()
+            # Create an owner TaskGroup and start the owner task within it
+            owner_tg = create_task_group()
+            await owner_tg.__aenter__()
+            self._owner_tg = owner_tg
+            owner_tg.start_soon(self._tg_owner)
+            # Wait until the TaskGroup is ready
+            await self._tg_ready_event.wait()
+
+    async def _tg_owner(self):
+        """Own the TaskGroup lifecycle so __aexit__ runs in the same task that entered it."""
+        try:
+            async with create_task_group() as tg:
+                self._tg = tg
+                self._tg_active = True
+                # Signal that TaskGroup is ready
+                self._tg_ready_event.set()
+                # Wait for close request
+                await self._tg_close_event.wait()
+        except Exception as e:
+            logger.warning(f"MCPConnectionManager: Error in TaskGroup owner: {e}")
+        finally:
+            # Mark closed and clear references
+            self._tg_active = False
+            self._tg = None
+            # Signal that TaskGroup has been closed
+            try:
+                self._tg_closed_event.set()
+            except Exception as e:
+                logger.warning(f"Failed to set _tg_closed_event: {e}")
+
     async def launch_server(
         self,
         server_name: str,
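The owner-task machinery above exists because anyio requires a TaskGroup's `__aenter__` and `__aexit__` to run in the same task; entering the group in one task and exiting it in another raises a `RuntimeError`. A self-contained sketch of the pattern, with hypothetical names:

```python
import anyio


async def main():
    ready = anyio.Event()
    close = anyio.Event()
    holder = {}

    async def owner():
        # This task both enters and exits the TaskGroup, satisfying
        # anyio's task-affinity rule.
        async with anyio.create_task_group() as tg:
            holder["tg"] = tg  # others may call holder["tg"].start_soon(...)
            ready.set()
            await close.wait()  # leaving the async-with closes the group

    async with anyio.create_task_group() as outer:
        outer.start_soon(owner)
        await ready.wait()
        # ... spawn work via holder["tg"] here ...
        close.set()  # ask the owner to tear the group down


anyio.run(main)
```
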
+ "MCPConnectionManager: No event loop recorded for cleanup; skipping TaskGroup close" ) - # Always mark as inactive and drop reference - self._tg_active = False - self._tg = None except AttributeError: # Handle missing `_exceptions` pass except Exception as e: logger.warning(f"MCPConnectionManager: Error during shutdown: {e}") + async def _start_owner(self): + """Start the TaskGroup owner task if not already running (task-safe).""" + async with self._owner_start_lock: + # If an owner is active or TaskGroup is already active, nothing to do + if (self._tg_owner_task and not self._tg_owner_task.done()) or ( + self._tg_active and self._tg is not None + ): + return + # If previous owner exists but is done (possibly with error), log and restart + if self._tg_owner_task and self._tg_owner_task.done(): + try: + exc = self._tg_owner_task.exception() + if exc: + logger.warning( + f"MCPConnectionManager: restarting owner after error: {exc}" + ) + except Exception: + logger.warning( + "MCPConnectionManager: restarting owner after unknown state" + ) + # Reset coordination events (safe here since no active owner/TG) + self._tg_ready_event = Event() + self._tg_close_event = Event() + self._tg_closed_event = Event() + # Record loop and thread + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = None + self._thread_id = threading.get_ident() + # Create an owner TaskGroup and start the owner task within it + owner_tg = create_task_group() + await owner_tg.__aenter__() + self._owner_tg = owner_tg + owner_tg.start_soon(self._tg_owner) + # Wait until the TaskGroup is ready + await self._tg_ready_event.wait() + + async def _tg_owner(self): + """Own the TaskGroup lifecycle so __aexit__ runs in the same task it was entered.""" + try: + async with create_task_group() as tg: + self._tg = tg + self._tg_active = True + # Signal that TaskGroup is ready + self._tg_ready_event.set() + # Wait for close request + await self._tg_close_event.wait() + except Exception as e: + logger.warning(f"MCPConnectionManager: Error in TaskGroup owner: {e}") + finally: + # Mark closed and clear references + self._tg_active = False + self._tg = None + # Signal that TaskGroup has been closed + try: + self._tg_closed_event.set() + except Exception as e: + logger.warning(f"Failed to set _tg_closed_event: {e}") + async def launch_server( self, server_name: str, @@ -295,12 +421,9 @@ async def launch_server( Connect to a server and return a RunningServer instance that will persist until explicitly disconnected. 
""" - # Create task group if it doesn't exist yet - make this method more resilient + # Ensure the TaskGroup owner is running - make this method more resilient if not self._tg_active: - tg = create_task_group() - await tg.__aenter__() - self._tg_active = True - self._tg = tg + await self._start_owner() logger.info( f"MCPConnectionManager: Auto-created task group for server: {server_name}" ) diff --git a/src/mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py b/src/mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py index 29ada7aea..6d9a6d270 100644 --- a/src/mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +++ b/src/mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py @@ -22,7 +22,7 @@ logger = get_logger(__name__) -class QualityRating(str, Enum): +class QualityRating(int, Enum): """Enum for evaluation quality ratings""" POOR = 0 # Major improvements needed diff --git a/tests/mcp/test_connection_manager_concurrency.py b/tests/mcp/test_connection_manager_concurrency.py new file mode 100644 index 000000000..fd52bd11f --- /dev/null +++ b/tests/mcp/test_connection_manager_concurrency.py @@ -0,0 +1,59 @@ +import asyncio +import threading + +import anyio +import pytest + +from mcp_agent.mcp.mcp_connection_manager import MCPConnectionManager + + +class DummyServerRegistry: + def __init__(self): + self.registry = {} + self.init_hooks = {} + + +@pytest.mark.anyio("asyncio") +async def test_concurrent_close_calls_same_and_cross_thread(): + mgr = MCPConnectionManager(server_registry=DummyServerRegistry()) + await mgr.__aenter__() + + # Run one close() on the event loop and one from a separate thread at the same time + thread_exc = [] + + def close_in_thread(): + async def _run(): + try: + # Exercise cross-thread shutdown path + await mgr.close() + except Exception as e: + thread_exc.append(e) + + asyncio.run(_run()) + + t = threading.Thread(target=close_in_thread, daemon=True) + + async with anyio.create_task_group() as tg: + # Start cross-thread close, then quickly start same-thread close + t.start() + # Add a tiny delay to improve overlap + await anyio.sleep(0.05) + + async def close_in_loop(): + await mgr.close() + + # Guard against hangs + with anyio.fail_after(6.0): + tg.start_soon(close_in_loop) + # Wait for thread to complete + await anyio.to_thread.run_sync(t.join) + + # Ensure no exceptions from thread + assert not thread_exc, f"Thread close failed: {thread_exc!r}" + + # Now exit context to close the owner TaskGroup on the origin loop + await mgr.__aexit__(None, None, None) + + # Verify TaskGroup cleared + assert getattr(mgr, "_tg", None) is None + assert getattr(mgr, "_tg_active", False) is False