From 0edd193c5843a6c8dfb25910f5780fdab9242278 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 8 Aug 2025 23:00:03 -0400 Subject: [PATCH] Router LLM should extend AugmentedLLM to enable composability --- src/mcp_agent/workflows/llm/augmented_llm.py | 17 ++++ src/mcp_agent/workflows/router/router_llm.py | 90 +++++++++++++++++-- .../workflows/router/router_llm_anthropic.py | 2 + .../workflows/router/router_llm_openai.py | 2 + 4 files changed, 102 insertions(+), 9 deletions(-) diff --git a/src/mcp_agent/workflows/llm/augmented_llm.py b/src/mcp_agent/workflows/llm/augmented_llm.py index 8f53ce65f..50fefa9b8 100644 --- a/src/mcp_agent/workflows/llm/augmented_llm.py +++ b/src/mcp_agent/workflows/llm/augmented_llm.py @@ -20,6 +20,7 @@ CallToolResult, CreateMessageRequestParams, CreateMessageResult, + ListToolsResult, SamplingMessage, TextContent, PromptMessage, @@ -679,3 +680,19 @@ def _gen_name(self, name: str | None, prefix: str | None) -> str: identifier = str(self.context.executor.uuid()) return f"{prefix}-{identifier}" + + # --- Agent convenience proxies ------------------------------------------------- + async def list_tools(self, server_name: str | None = None) -> ListToolsResult: + """Proxy to the underlying agent's list_tools for a simpler API.""" + return await self.agent.list_tools(server_name=server_name) + + async def close(self): + """Close underlying agent connections.""" + await self.agent.close() + + async def __aenter__(self): + await self.agent.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.agent.__aexit__(exc_type, exc_val, exc_tb) diff --git a/src/mcp_agent/workflows/router/router_llm.py b/src/mcp_agent/workflows/router/router_llm.py index 9165ef4cd..5fa639524 100644 --- a/src/mcp_agent/workflows/router/router_llm.py +++ b/src/mcp_agent/workflows/router/router_llm.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Literal, Optional, TYPE_CHECKING +from typing import Any, Callable, List, Literal, Optional, TYPE_CHECKING from opentelemetry import trace from pydantic import BaseModel @@ -6,7 +6,7 @@ from mcp_agent.agents.agent import Agent from mcp_agent.tracing.semconv import GEN_AI_REQUEST_TOP_K from mcp_agent.tracing.telemetry import get_tracer -from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM +from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM, RequestParams from mcp_agent.workflows.router.router_base import ResultT, Router, RouterResult from mcp_agent.logging.logger import get_logger @@ -79,9 +79,14 @@ class StructuredResponse(BaseModel): """A list of categories to route the input to.""" -class LLMRouter(Router): +class LLMRouter(AugmentedLLM[Any, Any], Router): """ - A router that uses an LLM to route an input to a specific category. + A router workflow that also behaves like an AugmentedLLM. + + - As a Router: provides route/route_to_* APIs that return routing targets. + - As an AugmentedLLM: generate/generate_str/generate_structured delegate to routing + and return the routing outputs in unstructured or structured forms, enabling + composition with other AugmentedLLM-based workflows (Parallel, Evaluator/Optimizer, etc.). """ def __init__( @@ -94,7 +99,9 @@ def __init__( context: Optional["Context"] = None, **kwargs, ): - super().__init__( + # Initialize Router side (category discovery, etc.) + Router.__init__( + self, server_names=server_names, agents=agents, functions=functions, @@ -103,7 +110,21 @@ def __init__( **kwargs, ) - self.llm = llm + # Initialize AugmentedLLM side for workflow composition + # We do not use this class itself to call a provider; we delegate to the + # provided classifier LLM. Still, initializing allows uniform tracing/name. + AugmentedLLM.__init__( + self, + name=(f"{llm.name}-router" if getattr(llm, "name", None) else None), + instruction="You are a router workflow that returns categories.", + context=context, + **kwargs, + ) + + # Inner LLM that makes the routing decision + self.classifier_llm: AugmentedLLM = llm + # Back-compat alias + self.llm: AugmentedLLM = llm @classmethod async def create( @@ -248,8 +269,8 @@ async def _route_with_llm( context=context, request=request, top_k=top_k ) - # Get routes from LLM - response = await self.llm.generate_structured( + # Get routes from the inner/classifier LLM + response = await self.classifier_llm.generate_structured( message=prompt, response_model=StructuredResponse, ) @@ -312,7 +333,8 @@ def _annotate_span_for_route_request( return span.set_attribute("request", request) span.set_attribute(GEN_AI_REQUEST_TOP_K, top_k) - span.set_attribute("llm", self.llm.name) + if getattr(self.classifier_llm, "name", None): + span.set_attribute("llm", self.classifier_llm.name) span.set_attribute( "agents", [a.name for a in self.agents] if self.agents else [] ) @@ -372,3 +394,53 @@ def _generate_context( idx += 1 return "\n\n".join(context_list) + + # --- AugmentedLLM interface ------------------------------------------------- + async def generate( + self, + message: str | Any | List[Any], + request_params: RequestParams | None = None, + ) -> List[Any]: + """Return routing results as a list for composition with other workflows. + + The return value is a list of dicts: [{"category": name, "confidence": str, "reasoning": str?}] + """ + results = await self._route_with_llm(str(message), top_k=5) + payload = [ + { + "category": ( + r.result + if isinstance(r.result, str) + else ( + r.result.name + if isinstance(r.result, Agent) + else getattr(r.result, "__name__", str(r.result)) + ) + ), + "confidence": r.confidence, + "reasoning": r.reasoning, + } + for r in results + ] + return payload # type: ignore[return-value] + + async def generate_str( + self, + message: str | Any | List[Any], + request_params: RequestParams | None = None, + ) -> str: + """Return routing results as JSON string.""" + import json + + payload = await self.generate(message=message, request_params=request_params) + return json.dumps({"categories": payload}) + + async def generate_structured( + self, + message: str | Any | List[Any], + response_model: type[StructuredResponse], + request_params: RequestParams | None = None, + ) -> StructuredResponse: + """Return routing results as a StructuredResponse Pydantic model.""" + txt = await self.generate_str(message=message, request_params=request_params) + return response_model.model_validate_json(txt) diff --git a/src/mcp_agent/workflows/router/router_llm_anthropic.py b/src/mcp_agent/workflows/router/router_llm_anthropic.py index 5b37136f3..79f72057a 100644 --- a/src/mcp_agent/workflows/router/router_llm_anthropic.py +++ b/src/mcp_agent/workflows/router/router_llm_anthropic.py @@ -1,6 +1,7 @@ from typing import Callable, List, Optional, TYPE_CHECKING from mcp_agent.agents.agent import Agent +from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM from mcp_agent.workflows.llm.augmented_llm_anthropic import AnthropicAugmentedLLM from mcp_agent.workflows.router.router_llm import LLMRouter @@ -49,6 +50,7 @@ def __init__( @classmethod async def create( cls, + llm: AugmentedLLM | None = None, server_names: List[str] | None = None, agents: List[Agent] | None = None, functions: List[Callable] | None = None, diff --git a/src/mcp_agent/workflows/router/router_llm_openai.py b/src/mcp_agent/workflows/router/router_llm_openai.py index 6c4d456c9..50ade163a 100644 --- a/src/mcp_agent/workflows/router/router_llm_openai.py +++ b/src/mcp_agent/workflows/router/router_llm_openai.py @@ -1,6 +1,7 @@ from typing import Callable, List, Optional, TYPE_CHECKING from mcp_agent.agents.agent import Agent +from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM from mcp_agent.workflows.router.router_llm import LLMRouter @@ -49,6 +50,7 @@ def __init__( @classmethod async def create( cls, + llm: AugmentedLLM | None = None, server_names: List[str] | None = None, agents: List[Agent] | None = None, functions: List[Callable] | None = None,