diff --git a/examples/mcp_agent_server/asyncio/basic_agent_server.py b/examples/mcp_agent_server/asyncio/basic_agent_server.py index b4cdeccda..f54f171d9 100644 --- a/examples/mcp_agent_server/asyncio/basic_agent_server.py +++ b/examples/mcp_agent_server/asyncio/basic_agent_server.py @@ -13,6 +13,8 @@ import logging from typing import Dict, Any +from mcp.server.fastmcp import FastMCP + from mcp_agent.app import MCPApp from mcp_agent.server.app_server import create_mcp_server_for_app from mcp_agent.agents.agent import Agent @@ -28,8 +30,16 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Create a single FastMCPApp instance (which extends MCPApp) -app = MCPApp(name="basic_agent_server", description="Basic agent server example") +# Note: This is purely optional: +# if not provided, a default FastMCP server will be created by MCPApp using create_mcp_server_for_app() +mcp = FastMCP(name="basic_agent_server", description="My basic agent server example.") + +# Define the MCPApp instance +app = MCPApp( + name="basic_agent_server", + description="Basic agent server example", + mcp=mcp, +) @app.workflow @@ -169,6 +179,123 @@ async def run(self, input: str) -> WorkflowResult[str]: return WorkflowResult(value=result) +# Add custom tool to get token usage for a workflow +@mcp.tool( + name="get_token_usage", + structured_output=True, + description=""" +Get detailed token usage information for a specific workflow run. + +This provides a comprehensive breakdown of token usage including: +- Total tokens used across all LLM calls within the workflow +- Breakdown by model provider and specific models +- Hierarchical usage tree showing usage at each level (workflow -> agent -> llm) +- Total cost estimate based on model pricing + +Args: + workflow_id: Optional workflow ID (if multiple workflows have the same name) + run_id: Optional ID of the workflow run to get token usage for + workflow_name: Optional name of the workflow (used as fallback) + +Returns: + Detailed token usage information for the specific workflow run +""", +) +async def get_workflow_token_usage( + workflow_id: str | None = None, + run_id: str | None = None, + workflow_name: str | None = None, +) -> Dict[str, Any]: + """Get token usage information for a specific workflow run.""" + context = app.context + + if not context.token_counter: + return { + "error": "Token counter not available", + "message": "Token tracking is not enabled for this application", + } + + # Find the specific workflow node + workflow_node = await context.token_counter.get_workflow_node( + name=workflow_name, workflow_id=workflow_id, run_id=run_id + ) + + if not workflow_node: + return { + "error": "Workflow not found", + "message": f"Could not find workflow with run_id='{run_id}'", + } + + # Get the aggregated usage for this workflow + workflow_usage = workflow_node.aggregate_usage() + + # Calculate cost for this workflow + workflow_cost = context.token_counter._calculate_node_cost(workflow_node) + + # Build the response + result = { + "workflow": { + "name": workflow_node.name, + "run_id": workflow_node.metadata.get("run_id"), + "workflow_id": workflow_node.metadata.get("workflow_id"), + }, + "usage": { + "input_tokens": workflow_usage.input_tokens, + "output_tokens": workflow_usage.output_tokens, + "total_tokens": workflow_usage.total_tokens, + }, + "cost": round(workflow_cost, 4), + "model_breakdown": {}, + "usage_tree": workflow_node.to_dict(), + } + + # Get model breakdown for this workflow + model_usage = {} + + def collect_model_usage(node: TokenNode): + """Recursively collect model usage from a node tree""" + if node.usage.model_name: + model_name = node.usage.model_name + provider = node.usage.model_info.provider if node.usage.model_info else None + + # Use tuple as key to handle same model from different providers + model_key = (model_name, provider) + + if model_key not in model_usage: + model_usage[model_key] = { + "model_name": model_name, + "provider": provider, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + + model_usage[model_key]["input_tokens"] += node.usage.input_tokens + model_usage[model_key]["output_tokens"] += node.usage.output_tokens + model_usage[model_key]["total_tokens"] += node.usage.total_tokens + + for child in node.children: + collect_model_usage(child) + + collect_model_usage(workflow_node) + + # Calculate costs for each model and format for output + for (model_name, provider), usage in model_usage.items(): + cost = context.token_counter.calculate_cost( + model_name, usage["input_tokens"], usage["output_tokens"], provider + ) + + # Create display key with provider info if available + display_key = f"{model_name} ({provider})" if provider else model_name + + result["model_breakdown"][display_key] = { + **usage, + "cost": round(cost, 4), + } + + return result + + async def main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -202,124 +329,6 @@ async def main(): mcp_server = create_mcp_server_for_app(agent_app, **(fast_mcp_settings or {})) logger.info(f"MCP Server settings: {mcp_server.settings}") - # Add custom tool to get token usage for a workflow - @mcp_server.tool( - name="get_token_usage", - structured_output=True, - description=""" - Get detailed token usage information for a specific workflow run. - - This provides a comprehensive breakdown of token usage including: - - Total tokens used across all LLM calls within the workflow - - Breakdown by model provider and specific models - - Hierarchical usage tree showing usage at each level (workflow -> agent -> llm) - - Total cost estimate based on model pricing - - Args: - workflow_id: Optional workflow ID (if multiple workflows have the same name) - run_id: Optional ID of the workflow run to get token usage for - workflow_name: Optional name of the workflow (used as fallback) - - Returns: - Detailed token usage information for the specific workflow run - """, - ) - async def get_workflow_token_usage( - workflow_id: str | None = None, - run_id: str | None = None, - workflow_name: str | None = None, - ) -> Dict[str, Any]: - """Get token usage information for a specific workflow run.""" - if not context.token_counter: - return { - "error": "Token counter not available", - "message": "Token tracking is not enabled for this application", - } - - # Find the specific workflow node - workflow_node = await context.token_counter.get_workflow_node( - name=workflow_name, workflow_id=workflow_id, run_id=run_id - ) - - if not workflow_node: - return { - "error": "Workflow not found", - "message": f"Could not find workflow with run_id='{run_id}'", - } - - # Get the aggregated usage for this workflow - workflow_usage = workflow_node.aggregate_usage() - - # Calculate cost for this workflow - workflow_cost = context.token_counter._calculate_node_cost(workflow_node) - - # Build the response - result = { - "workflow": { - "name": workflow_node.name, - "run_id": workflow_node.metadata.get("run_id"), - "workflow_id": workflow_node.metadata.get("workflow_id"), - }, - "usage": { - "input_tokens": workflow_usage.input_tokens, - "output_tokens": workflow_usage.output_tokens, - "total_tokens": workflow_usage.total_tokens, - }, - "cost": round(workflow_cost, 4), - "model_breakdown": {}, - "usage_tree": workflow_node.to_dict(), - } - - # Get model breakdown for this workflow - model_usage = {} - - def collect_model_usage(node: TokenNode): - """Recursively collect model usage from a node tree""" - if node.usage.model_name: - model_name = node.usage.model_name - provider = ( - node.usage.model_info.provider - if node.usage.model_info - else None - ) - - # Use tuple as key to handle same model from different providers - model_key = (model_name, provider) - - if model_key not in model_usage: - model_usage[model_key] = { - "model_name": model_name, - "provider": provider, - "input_tokens": 0, - "output_tokens": 0, - "total_tokens": 0, - } - - model_usage[model_key]["input_tokens"] += node.usage.input_tokens - model_usage[model_key]["output_tokens"] += node.usage.output_tokens - model_usage[model_key]["total_tokens"] += node.usage.total_tokens - - for child in node.children: - collect_model_usage(child) - - collect_model_usage(workflow_node) - - # Calculate costs for each model and format for output - for (model_name, provider), usage in model_usage.items(): - cost = context.token_counter.calculate_cost( - model_name, usage["input_tokens"], usage["output_tokens"], provider - ) - - # Create display key with provider info if available - display_key = f"{model_name} ({provider})" if provider else model_name - - result["model_breakdown"][display_key] = { - **usage, - "cost": round(cost, 4), - } - - return result - # Run the server await mcp_server.run_stdio_async() diff --git a/src/mcp_agent/app.py b/src/mcp_agent/app.py index f5c0c02a9..cfccec0d4 100644 --- a/src/mcp_agent/app.py +++ b/src/mcp_agent/app.py @@ -7,6 +7,7 @@ from contextlib import asynccontextmanager from mcp import ServerSession +from mcp.server.fastmcp import FastMCP from mcp_agent.core.context import Context, initialize_context, cleanup_context from mcp_agent.config import Settings, get_settings from mcp_agent.executor.signal_registry import SignalRegistry @@ -57,12 +58,13 @@ def __init__( self, name: str = "mcp_application", description: str | None = None, - settings: Optional[Settings] | str = None, - human_input_callback: Optional[HumanInputCallback] = None, - elicitation_callback: Optional[ElicitationCallback] = None, - signal_notification: Optional[SignalWaitCallback] = None, + settings: Settings | str | None = None, + mcp: FastMCP | None = None, + human_input_callback: HumanInputCallback | None = None, + elicitation_callback: ElicitationCallback | None = None, + signal_notification: SignalWaitCallback | None = None, upstream_session: Optional["ServerSession"] = None, - model_selector: ModelSelector = None, + model_selector: ModelSelector | None = None, ): """ Initialize the application with a name and optional settings. @@ -72,6 +74,9 @@ def __init__( provide a detailed description, since it will be used as the server's description. settings: Application configuration - If unspecified, the settings are loaded from mcp_agent.config.yaml. If this is a string, it is treated as the path to the config file to load. + mcp: MCP server instance to use for the application to expose agents and workflows as tools. + If not provided, a default FastMCP server will be created by create_mcp_server_for_app(). + If provided, the MCPApp will add tools to the provided server instance. human_input_callback: Callback for handling human input signal_notification: Callback for getting notified on workflow signals/events. upstream_session: Upstream session if the MCPApp is running as a server to an MCP client. @@ -79,6 +84,7 @@ def __init__( """ self.name = name self.description = description or "MCP Agent Application" + self.mcp = mcp # We use these to initialize the context in initialize() if settings is None: diff --git a/src/mcp_agent/core/context.py b/src/mcp_agent/core/context.py index f509715d8..1b07baa04 100644 --- a/src/mcp_agent/core/context.py +++ b/src/mcp_agent/core/context.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, ConfigDict from mcp import ServerSession +from mcp.server.fastmcp import FastMCP from opentelemetry import trace @@ -88,6 +89,10 @@ class Context(BaseModel): arbitrary_types_allowed=True, # Tell Pydantic to defer type evaluation ) + @property + def mcp(self) -> FastMCP | None: + return self.app.mcp if self.app else None + async def configure_otel( config: "Settings", session_id: str | None = None diff --git a/src/mcp_agent/server/app_server.py b/src/mcp_agent/server/app_server.py index 4a8eec3a6..5d3146340 100644 --- a/src/mcp_agent/server/app_server.py +++ b/src/mcp_agent/server/app_server.py @@ -6,7 +6,7 @@ import json from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any, Dict, List, Type, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Tuple, Type, TYPE_CHECKING from mcp.server.fastmcp import Context as MCPContext, FastMCP from mcp.server.fastmcp.exceptions import ToolError @@ -81,6 +81,53 @@ def workflow_registry(self) -> WorkflowRegistry: return self.context.workflow_registry +def _get_attached_app(mcp: FastMCP) -> MCPApp | None: + """Return the MCPApp instance attached to the FastMCP server, if any.""" + return getattr(mcp, "_mcp_agent_app", None) + + +def _get_attached_server_context(mcp: FastMCP) -> ServerContext | None: + """Return the ServerContext attached to the FastMCP server, if any.""" + return getattr(mcp, "_mcp_agent_server_context", None) + + +def _resolve_workflows_and_context( + ctx: MCPContext, +) -> Tuple[Dict[str, Type["Workflow"]] | None, Optional["Context"]]: + """Resolve the workflows mapping and underlying app context regardless of startup mode. + + Tries lifespan ServerContext first (including compatible mocks), then attached app. + """ + # Try lifespan-provided ServerContext first + lifespan_ctx = getattr(ctx.request_context, "lifespan_context", None) + if ( + lifespan_ctx is not None + and hasattr(lifespan_ctx, "workflows") + and hasattr(lifespan_ctx, "context") + ): + return lifespan_ctx.workflows, lifespan_ctx.context + + # Fall back to app attached to FastMCP + app: MCPApp | None = _get_attached_app(ctx.fastmcp) + if app is not None: + return app.workflows, app.context + + return None, None + + +def _resolve_workflow_registry(ctx: MCPContext) -> WorkflowRegistry | None: + """Resolve the workflow registry regardless of startup mode.""" + lifespan_ctx = getattr(ctx.request_context, "lifespan_context", None) + if lifespan_ctx is not None and hasattr(lifespan_ctx, "workflow_registry"): + return lifespan_ctx.workflow_registry + + app: MCPApp | None = _get_attached_app(ctx.fastmcp) + if app is not None and app.context is not None: + return app.context.workflow_registry + + return None + + def create_mcp_server_for_app(app: MCPApp, **kwargs: Any) -> FastMCP: """ Create an MCP server for a given MCPApp instance. @@ -103,7 +150,7 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Create the server context which is available during the lifespan of the server server_context = ServerContext(mcp=mcp, context=app.context) - # Register initial workflow tools + # Register initial workflow tools when running with our managed lifespan create_workflow_tools(mcp, server_context) try: @@ -112,16 +159,35 @@ async def app_specific_lifespan(mcp: FastMCP) -> AsyncIterator[ServerContext]: # Don't clean up the MCPApp here - let the caller handle that pass - # Create FastMCP server with the app's name - mcp = FastMCP( - name=app.name or "mcp_agent_server", - # TODO: saqadri (MAC) - create a much more detailed description - # based on all the available agents and workflows, - # or use the MCPApp's description if available. - instructions=f"MCP server exposing {app.name} workflows and agents. Description: {app.description}", - lifespan=app_specific_lifespan, - **kwargs, - ) + # Create or attach FastMCP server + if app.mcp: + # Using an externally provided FastMCP instance: attach app and context + mcp = app.mcp + setattr(mcp, "_mcp_agent_app", app) + + # Create and attach a ServerContext since we don't control the server's lifespan + # This enables tools to access context via ctx.fastmcp._mcp_agent_server_context + if not hasattr(mcp, "_mcp_agent_server_context"): + server_context = ServerContext(mcp=mcp, context=app.context) + setattr(mcp, "_mcp_agent_server_context", server_context) + else: + server_context = getattr(mcp, "_mcp_agent_server_context") + + # Register per-workflow tools + create_workflow_tools(mcp, server_context) + else: + mcp = FastMCP( + name=app.name or "mcp_agent_server", + # TODO: saqadri (MAC) - create a much more detailed description + # based on all the available agents and workflows, + # or use the MCPApp's description if available. + instructions=f"MCP server exposing {app.name} workflows and agents. Description: {app.description}", + lifespan=app_specific_lifespan, + **kwargs, + ) + # Store the server on the app so it's discoverable and can be extended further + app.mcp = mcp + setattr(mcp, "_mcp_agent_app", app) # region Workflow Tools @@ -132,10 +198,10 @@ def list_workflows(ctx: MCPContext) -> Dict[str, Dict[str, Any]]: Returns information about each workflow type including name, description, and parameters. This helps in making an informed decision about which workflow to run. """ - server_context: ServerContext = ctx.request_context.lifespan_context - - result = {} - for workflow_name, workflow_cls in server_context.workflows.items(): + result: Dict[str, Dict[str, Any]] = {} + workflows, _ = _resolve_workflows_and_context(ctx) + workflows = workflows or {} + for workflow_name, workflow_cls in workflows.items(): # Get workflow documentation run_fn_tool = FastTool.from_function(workflow_cls.run) @@ -167,7 +233,11 @@ async def list_workflow_runs(ctx: MCPContext) -> List[Dict[str, Any]]: Returns: A dictionary mapping workflow instance IDs to their detailed status information. """ - server_context: ServerContext = ctx.request_context.lifespan_context + server_context = getattr( + ctx.request_context, "lifespan_context", None + ) or _get_attached_server_context(ctx.fastmcp) + if server_context is None or not hasattr(server_context, "workflow_registry"): + raise ToolError("Server context not available for MCPApp Server.") # Get all workflow statuses from the registry workflow_statuses = ( @@ -412,20 +482,21 @@ async def _workflow_run( run_parameters: Dict[str, Any] | None = None, **kwargs: Any, ) -> Dict[str, str]: - server_context: ServerContext = ctx.request_context.lifespan_context + # Resolve workflows and app context irrespective of startup mode + workflows_dict, app_context = _resolve_workflows_and_context(ctx) + if not workflows_dict or not app_context: + raise ToolError("Server context not available for MCPApp Server.") - if workflow_name not in server_context.workflows: + if workflow_name not in workflows_dict: raise ToolError(f"Workflow '{workflow_name}' not found.") # Get the workflow class - workflow_cls = server_context.workflows[workflow_name] + workflow_cls = workflows_dict[workflow_name] # Create and initialize the workflow instance using the factory method try: # Create workflow instance - workflow = await workflow_cls.create( - name=workflow_name, context=server_context.context - ) + workflow = await workflow_cls.create(name=workflow_name, context=app_context) run_parameters = run_parameters or {} @@ -459,8 +530,7 @@ async def _workflow_run( async def _workflow_status( ctx: MCPContext, run_id: str, workflow_name: str | None = None ) -> Dict[str, Any]: - server_context: ServerContext = ctx.request_context.lifespan_context - workflow_registry: WorkflowRegistry = server_context.workflow_registry + workflow_registry: WorkflowRegistry | None = _resolve_workflow_registry(ctx) if not workflow_registry: raise ToolError("Workflow registry not found for MCPApp Server.") diff --git a/tests/server/test_app_server.py b/tests/server/test_app_server.py index 3769adc9c..e5e0b8126 100644 --- a/tests/server/test_app_server.py +++ b/tests/server/test_app_server.py @@ -1,5 +1,6 @@ import pytest from unittest.mock import AsyncMock, MagicMock +from types import SimpleNamespace from mcp_agent.server.app_server import _workflow_run from mcp_agent.executor.workflow import WorkflowExecution @@ -7,12 +8,15 @@ @pytest.fixture def mock_server_context(): """Mock server context for testing""" - context = MagicMock() - context.request_context = MagicMock() - context.request_context.lifespan_context = MagicMock() - context.request_context.lifespan_context.workflows = {} - context.request_context.lifespan_context.context = MagicMock() - return context + # Build a minimal ctx object compatible with new resolution helpers + app_context = MagicMock() + server_context = SimpleNamespace(workflows={}, context=app_context) + + ctx = MagicMock() + ctx.request_context = SimpleNamespace(lifespan_context=server_context) + # Ensure no attached app path is used in tests; rely on lifespan path + ctx.fastmcp = SimpleNamespace(_mcp_agent_app=None) + return ctx @pytest.fixture