From 6d00b541d0fde0d2dbfa83c163d2b71019b9d590 Mon Sep 17 00:00:00 2001 From: rholinshead <5060851+rholinshead@users.noreply.github.com> Date: Mon, 25 Aug 2025 13:17:16 -0400 Subject: [PATCH] Track registered workflow tools for mcp server --- src/mcp_agent/server/app_server.py | 28 ++++- tests/server/test_app_server.py | 180 ++++++++++++++++++++++++++++- 2 files changed, 203 insertions(+), 5 deletions(-) diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 5d3146340..c3b5a6bb8 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -6,7 +6,7 @@ import json from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING from mcp.server.fastmcp import Context as MCPContext, FastMCP from mcp.server.fastmcp.exceptions import ToolError @@ -37,6 +37,12 @@ def __init__(self, mcp: FastMCP, context: "Context", **kwargs): self.mcp = mcp self.active_agents: Dict[str, Agent] = {} + # Maintain a list of registered workflow tools to avoid re-registration + # when server context is recreated for the same FastMCP instance (e.g. during + # FastMCP sse request handling) + if not hasattr(self.mcp, "_registered_workflow_tools"): + setattr(self.mcp, "_registered_workflow_tools", set()) + # Initialize workflow registry if not already present if not self.context.workflow_registry: if self.context.config.execution_engine == "asyncio": @@ -62,8 +68,11 @@ def register_workflow(self, workflow_name: str, workflow_cls: Type[Workflow]): """Register a workflow class.""" if workflow_name not in self.context.workflows: self.workflows[workflow_name] = workflow_cls - # Create tools for this workflow - create_workflow_specific_tools(self.mcp, workflow_name, workflow_cls) + # Create tools for this workflow if not already registered + registered_workflow_tools = _get_registered_workflow_tools(self.mcp) + if workflow_name not in registered_workflow_tools: + create_workflow_specific_tools(self.mcp, workflow_name, workflow_cls) + registered_workflow_tools.add(workflow_name) @property def app(self) -> MCPApp: @@ -86,6 +95,11 @@ def _get_attached_app(mcp: FastMCP) -> MCPApp | None: return getattr(mcp, "_mcp_agent_app", None) +def _get_registered_workflow_tools(mcp: FastMCP) -> Set[str]: + """Return the set of registered workflow tools for the FastMCP server, if any.""" + return getattr(mcp, "_registered_workflow_tools", set()) + + def _get_attached_server_context(mcp: FastMCP) -> ServerContext | None: """Return the ServerContext attached to the FastMCP server, if any.""" return getattr(mcp, "_mcp_agent_server_context", None) @@ -386,8 +400,14 @@ def create_workflow_tools(mcp: FastMCP, server_context: ServerContext): logger.warning("Server config not available for creating workflow tools") return + registered_workflow_tools = _get_registered_workflow_tools(mcp) + for workflow_name, workflow_cls in server_context.workflows.items(): - create_workflow_specific_tools(mcp, workflow_name, workflow_cls) + if workflow_name not in registered_workflow_tools: + create_workflow_specific_tools(mcp, workflow_name, workflow_cls) + registered_workflow_tools.add(workflow_name) + + setattr(mcp, "_registered_workflow_tools", registered_workflow_tools) def create_workflow_specific_tools( diff --git a/tests/server/test_app_server.py b/tests/server/test_app_server.py index e5e0b8126..a974b2b73 100644 --- a/tests/server/test_app_server.py +++ b/tests/server/test_app_server.py @@ -1,7 +1,11 @@ import pytest from unittest.mock import AsyncMock, MagicMock from types import SimpleNamespace -from mcp_agent.server.app_server import _workflow_run +from mcp_agent.server.app_server import ( + _workflow_run, + ServerContext, + create_workflow_tools, +) from mcp_agent.executor.workflow import WorkflowExecution @@ -252,3 +256,177 @@ async def test_workflow_run_preserves_user_params_with_similar_names( # The "__mcp_agent_workflow_id" from user params should not override system param assert call_kwargs["__mcp_agent_workflow_id"] != "should-not-happen" + + +def test_workflow_tools_idempotent_registration(): + """Test that workflow tools are only registered once per workflow""" + # Create mock FastMCP and context + mock_mcp = MagicMock() + mock_app = MagicMock() + mock_context = MagicMock(app=mock_app) + + # Ensure the mcp mock doesn't have _registered_workflow_tools initially + # so ServerContext.__init__ will create it + if hasattr(mock_mcp, "_registered_workflow_tools"): + delattr(mock_mcp, "_registered_workflow_tools") + + mock_app.workflows = {} + # Need to mock the config and workflow_registry for ServerContext init + mock_context.workflow_registry = None + mock_context.config = MagicMock() + mock_context.config.execution_engine = "asyncio" + + server_context = ServerContext(mcp=mock_mcp, context=mock_context) + + # Mock workflows + mock_workflow_class = MagicMock() + mock_workflow_class.__doc__ = "Test workflow" + mock_run = MagicMock() + mock_run.__name__ = "run" + mock_workflow_class.run = mock_run + + mock_app.workflows = { + "workflow1": mock_workflow_class, + "workflow2": mock_workflow_class, + } + + tools_created = [] + + def track_tool_calls(*args, **kwargs): + def decorator(func): + tools_created.append(kwargs.get("name", args[0] if args else "unknown")) + return func + + return decorator + + mock_mcp.tool = track_tool_calls + + # First call to create_workflow_tools + create_workflow_tools(mock_mcp, server_context) + + # Verify tools were created for both workflows + expected_tools = [ + "workflows-workflow1-run", + "workflows-workflow1-get_status", + "workflows-workflow2-run", + "workflows-workflow2-get_status", + ] + + assert len(tools_created) == 4 + for expected_tool in expected_tools: + assert expected_tool in tools_created + + # Verify the registered workflow tools are tracked on the MCP instance + assert hasattr(mock_mcp, "_registered_workflow_tools") + assert mock_mcp._registered_workflow_tools == {"workflow1", "workflow2"} + + # Reset tools and call create_workflow_tools again + tools_created.clear() + create_workflow_tools(mock_mcp, server_context) + + # Verify no additional tools were created (idempotent) + assert len(tools_created) == 0 + assert mock_mcp._registered_workflow_tools == {"workflow1", "workflow2"} + + # Test register_workflow with a new workflow + new_workflow_class = MagicMock() + new_workflow_class.__doc__ = "New workflow" + new_mock_run = MagicMock() + new_mock_run.__name__ = "run" + new_workflow_class.run = new_mock_run + + server_context.register_workflow("workflow3", new_workflow_class) + + # Verify the new workflow was added and its tools created + assert "workflow3" in server_context.workflows + assert "workflow3" in mock_mcp._registered_workflow_tools + assert len(tools_created) == 2 # run and get_status for workflow3 + assert "workflows-workflow3-run" in tools_created + assert "workflows-workflow3-get_status" in tools_created + + # Test registering the same workflow again (should be idempotent) + tools_created.clear() + server_context.register_workflow("workflow3", new_workflow_class) + + # Should not create duplicate tools or add to workflows again + assert len(tools_created) == 0 + assert mock_mcp._registered_workflow_tools == { + "workflow1", + "workflow2", + "workflow3", + } + + +def test_workflow_tools_persistent_across_sse_requests(): + """Test that workflow tools registration persists across SSE request context recreation""" + # Create mock FastMCP instance (this persists across requests) + mock_mcp = MagicMock() + + # Ensure the mcp mock doesn't have _registered_workflow_tools initially + if hasattr(mock_mcp, "_registered_workflow_tools"): + delattr(mock_mcp, "_registered_workflow_tools") + + # Mock workflows + mock_workflow_class = MagicMock() + mock_workflow_class.__doc__ = "Test workflow" + mock_run = MagicMock() + mock_run.__name__ = "run" + mock_workflow_class.run = mock_run + + tools_created = [] + + def track_tool_calls(*args, **kwargs): + def decorator(func): + tools_created.append(kwargs.get("name", args[0] if args else "unknown")) + return func + + return decorator + + mock_mcp.tool = track_tool_calls + + # Simulate first SSE request - create new ServerContext + mock_app1 = MagicMock() + mock_context1 = MagicMock(app=mock_app1) + mock_context1.workflow_registry = None + mock_context1.config = MagicMock() + mock_context1.config.execution_engine = "asyncio" + mock_app1.workflows = {"workflow1": mock_workflow_class} + server_context1 = ServerContext(mcp=mock_mcp, context=mock_context1) + + # Register tools in first request + create_workflow_tools(mock_mcp, server_context1) + + # Verify tools were created + assert len(tools_created) == 2 # run and get_status + assert "workflows-workflow1-run" in tools_created + assert "workflows-workflow1-get_status" in tools_created + assert hasattr(mock_mcp, "_registered_workflow_tools") + assert "workflow1" in mock_mcp._registered_workflow_tools + + # Reset tools tracker + tools_created.clear() + + # Simulate second SSE request - create NEW ServerContext (simulates fastmcp behavior) + mock_app2 = MagicMock() + mock_context2 = MagicMock(app=mock_app2) + mock_context2.workflow_registry = None + mock_context2.config = MagicMock() + mock_context2.config.execution_engine = "asyncio" + mock_app2.workflows = {"workflow1": mock_workflow_class} # Same workflow + server_context2 = ServerContext(mcp=mock_mcp, context=mock_context2) # NEW context! + + # The MCP instance should still have the registration from the first context + assert hasattr(mock_mcp, "_registered_workflow_tools") + assert isinstance( + mock_mcp._registered_workflow_tools, set + ) # Should be a real set now + + # But the FastMCP instance should still have the persistent registration + assert mock_mcp._registered_workflow_tools == {"workflow1"} + + # Call create_workflow_tools again - should be idempotent due to persistent storage + create_workflow_tools(mock_mcp, server_context2) + + # Verify NO additional tools were created (idempotent) + assert len(tools_created) == 0 + assert mock_mcp._registered_workflow_tools == {"workflow1"}