Skip to content

Commit df2ecd9

Browse files
authored
feat(langchain_v1): add llm selection middleware (#33272)
* Add llm based tool selection middleware. * Note that we might want some form of caching for when the agent is inside an active tool calling loop as the tool selection isn't expected to change during that time. API: ```python class LLMToolSelectorMiddleware(AgentMiddleware): """Uses an LLM to select relevant tools before calling the main model. When an agent has many tools available, this middleware filters them down to only the most relevant ones for the user's query. This reduces token usage and helps the main model focus on the right tools. Examples: Limit to 3 tools: ```python from langchain.agents.middleware import LLMToolSelectorMiddleware middleware = LLMToolSelectorMiddleware(max_tools=3) agent = create_agent( model="openai:gpt-4o", tools=[tool1, tool2, tool3, tool4, tool5], middleware=[middleware], ) ``` Use a smaller model for selection: ```python middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2) ``` """ def __init__( self, *, model: str | BaseChatModel | None = None, system_prompt: str = DEFAULT_SYSTEM_PROMPT, max_tools: int | None = None, always_include: list[str] | None = None, ) -> None: """Initialize the tool selector. Args: model: Model to use for selection. If not provided, uses the agent's main model. Can be a model identifier string or BaseChatModel instance. system_prompt: Instructions for the selection model. max_tools: Maximum number of tools to select. If the model selects more, only the first max_tools will be used. No limit if not specified. always_include: Tool names to always include regardless of selection. These do not count against the max_tools limit. """ ``` ```python """Test script for LLM tool selection middleware.""" from langchain.agents import create_agent from langchain.agents.middleware import LLMToolSelectorMiddleware from langchain_core.tools import tool @tool def get_weather(location: str) -> str: """Get current weather for a location.""" return f"Weather in {location}: 72°F, sunny" @tool def search_web(query: str) -> str: """Search the web for information.""" return f"Search results for: {query}" @tool def calculate(expression: str) -> str: """Perform mathematical calculations.""" return f"Result of {expression}: 42" @tool def send_email(to: str, subject: str) -> str: """Send an email to someone.""" return f"Email sent to {to} with subject: {subject}" @tool def get_stock_price(symbol: str) -> str: """Get current stock price for a symbol.""" return f"Stock price for {symbol}: $150.25" @tool def translate_text(text: str, target_language: str) -> str: """Translate text to another language.""" return f"Translated '{text}' to {target_language}" @tool def set_reminder(task: str, time: str) -> str: """Set a reminder for a task.""" return f"Reminder set: {task} at {time}" @tool def get_news(topic: str) -> str: """Get latest news about a topic.""" return f"Latest news about {topic}" @tool def book_flight(destination: str, date: str) -> str: """Book a flight to a destination.""" return f"Flight booked to {destination} on {date}" @tool def get_restaurant_recommendations(city: str, cuisine: str) -> str: """Get restaurant recommendations.""" return f"Top {cuisine} restaurants in {city}" # Create agent with tool selection middleware middleware = LLMToolSelectorMiddleware( model="openai:gpt-4o-mini", max_tools=3, ) agent = create_agent( model="openai:gpt-4o", tools=[ get_weather, search_web, calculate, send_email, get_stock_price, translate_text, set_reminder, get_news, book_flight, get_restaurant_recommendations, ], middleware=[middleware], ) # Test with a query that should select specific tools response = agent.invoke( {"messages": [{"role": "user", "content": "I need to find restaurants"}]} ) print(response) ```
1 parent bdb7dbb commit df2ecd9

File tree

3 files changed

+893
-0
lines changed

3 files changed

+893
-0
lines changed

libs/langchain_v1/langchain/agents/middleware/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .prompt_caching import AnthropicPromptCachingMiddleware
77
from .summarization import SummarizationMiddleware
88
from .tool_call_limit import ToolCallLimitMiddleware
9+
from .tool_selection import LLMToolSelectorMiddleware
910
from .types import (
1011
AgentMiddleware,
1112
AgentState,
@@ -23,6 +24,7 @@
2324
# should move to langchain-anthropic if we decide to keep it
2425
"AnthropicPromptCachingMiddleware",
2526
"HumanInTheLoopMiddleware",
27+
"LLMToolSelectorMiddleware",
2628
"ModelRequest",
2729
"PIIDetectionError",
2830
"PIIMiddleware",
Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
"""LLM-based tool selector middleware."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from dataclasses import dataclass
7+
from typing import TYPE_CHECKING, Annotated, Literal, Union
8+
9+
from langchain_core.language_models.chat_models import BaseChatModel
10+
from langchain_core.messages import HumanMessage
11+
from pydantic import Field, TypeAdapter
12+
from typing_extensions import TypedDict
13+
14+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest, StateT
15+
from langchain.chat_models.base import init_chat_model
16+
17+
if TYPE_CHECKING:
18+
from langgraph.runtime import Runtime
19+
from langgraph.typing import ContextT
20+
21+
from langchain.tools import BaseTool
22+
23+
logger = logging.getLogger(__name__)
24+
25+
DEFAULT_SYSTEM_PROMPT = (
26+
"Your goal is to select the most relevant tools for answering the user's query."
27+
)
28+
29+
30+
@dataclass
31+
class _SelectionRequest:
32+
"""Prepared inputs for tool selection."""
33+
34+
available_tools: list[BaseTool]
35+
system_message: str
36+
last_user_message: HumanMessage
37+
model: BaseChatModel
38+
valid_tool_names: list[str]
39+
40+
41+
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
42+
"""Create a structured output schema for tool selection.
43+
44+
Args:
45+
tools: Available tools to include in the schema.
46+
47+
Returns:
48+
TypeAdapter for a schema where each tool name is a Literal with its description.
49+
"""
50+
if not tools:
51+
msg = "Invalid usage: tools must be non-empty"
52+
raise AssertionError(msg)
53+
54+
# Create a Union of Annotated Literal types for each tool name with description
55+
# Example: Union[Annotated[Literal["tool1"], Field(description="...")], ...] noqa: ERA001
56+
literals = [
57+
Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools
58+
]
59+
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
60+
61+
description = "Tools to use. Place the most relevant tools first."
62+
63+
class ToolSelectionResponse(TypedDict):
64+
"""Use to select relevant tools."""
65+
66+
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
67+
68+
return TypeAdapter(ToolSelectionResponse)
69+
70+
71+
def _render_tool_list(tools: list[BaseTool]) -> str:
72+
"""Format tools as markdown list.
73+
74+
Args:
75+
tools: Tools to format.
76+
77+
Returns:
78+
Markdown string with each tool on a new line.
79+
"""
80+
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
81+
82+
83+
class LLMToolSelectorMiddleware(AgentMiddleware):
84+
"""Uses an LLM to select relevant tools before calling the main model.
85+
86+
When an agent has many tools available, this middleware filters them down
87+
to only the most relevant ones for the user's query. This reduces token usage
88+
and helps the main model focus on the right tools.
89+
90+
Examples:
91+
Limit to 3 tools:
92+
```python
93+
from langchain.agents.middleware import LLMToolSelectorMiddleware
94+
95+
middleware = LLMToolSelectorMiddleware(max_tools=3)
96+
97+
agent = create_agent(
98+
model="openai:gpt-4o",
99+
tools=[tool1, tool2, tool3, tool4, tool5],
100+
middleware=[middleware],
101+
)
102+
```
103+
104+
Use a smaller model for selection:
105+
```python
106+
middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
107+
```
108+
"""
109+
110+
def __init__(
111+
self,
112+
*,
113+
model: str | BaseChatModel | None = None,
114+
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
115+
max_tools: int | None = None,
116+
always_include: list[str] | None = None,
117+
) -> None:
118+
"""Initialize the tool selector.
119+
120+
Args:
121+
model: Model to use for selection. If not provided, uses the agent's main model.
122+
Can be a model identifier string or BaseChatModel instance.
123+
system_prompt: Instructions for the selection model.
124+
max_tools: Maximum number of tools to select. If the model selects more,
125+
only the first max_tools will be used. No limit if not specified.
126+
always_include: Tool names to always include regardless of selection.
127+
These do not count against the max_tools limit.
128+
"""
129+
super().__init__()
130+
self.system_prompt = system_prompt
131+
self.max_tools = max_tools
132+
self.always_include = always_include or []
133+
134+
if isinstance(model, (BaseChatModel, type(None))):
135+
self.model: BaseChatModel | None = model
136+
else:
137+
self.model = init_chat_model(model)
138+
139+
def _prepare_selection_request(self, request: ModelRequest) -> _SelectionRequest | None:
140+
"""Prepare inputs for tool selection.
141+
142+
Returns:
143+
SelectionRequest with prepared inputs, or None if no selection is needed.
144+
"""
145+
# If no tools available, return None
146+
if not request.tools or len(request.tools) == 0:
147+
return None
148+
149+
# Validate that always_include tools exist
150+
if self.always_include:
151+
available_tool_names = {tool.name for tool in request.tools}
152+
missing_tools = [
153+
name for name in self.always_include if name not in available_tool_names
154+
]
155+
if missing_tools:
156+
msg = (
157+
f"Tools in always_include not found in request: {missing_tools}. "
158+
f"Available tools: {sorted(available_tool_names)}"
159+
)
160+
raise ValueError(msg)
161+
162+
# Separate tools that are always included from those available for selection
163+
available_tools = [tool for tool in request.tools if tool.name not in self.always_include]
164+
165+
# If no tools available for selection, return None
166+
if not available_tools:
167+
return None
168+
169+
system_message = self.system_prompt
170+
# If there's a max_tools limit, append instructions to the system prompt
171+
if self.max_tools is not None:
172+
system_message += (
173+
f"\nIMPORTANT: List the tool names in order of relevance, "
174+
f"with the most relevant first. "
175+
f"If you exceed the maximum number of tools, "
176+
f"only the first {self.max_tools} will be used."
177+
)
178+
179+
# Get the last user message from the conversation history
180+
last_user_message: HumanMessage
181+
for message in request.messages:
182+
if isinstance(message, HumanMessage):
183+
last_user_message = message
184+
break
185+
else:
186+
msg = "No user message found in request messages"
187+
raise AssertionError(msg)
188+
189+
model = self.model or request.model
190+
valid_tool_names = [tool.name for tool in available_tools]
191+
192+
return _SelectionRequest(
193+
available_tools=available_tools,
194+
system_message=system_message,
195+
last_user_message=last_user_message,
196+
model=model,
197+
valid_tool_names=valid_tool_names,
198+
)
199+
200+
def _process_selection_response(
201+
self,
202+
response: dict,
203+
available_tools: list[BaseTool],
204+
valid_tool_names: list[str],
205+
request: ModelRequest,
206+
) -> ModelRequest:
207+
"""Process the selection response and return filtered ModelRequest."""
208+
selected_tool_names: list[str] = []
209+
invalid_tool_selections = []
210+
211+
for tool_name in response["tools"]:
212+
if tool_name not in valid_tool_names:
213+
invalid_tool_selections.append(tool_name)
214+
continue
215+
216+
# Only add if not already selected and within max_tools limit
217+
if tool_name not in selected_tool_names and (
218+
self.max_tools is None or len(selected_tool_names) < self.max_tools
219+
):
220+
selected_tool_names.append(tool_name)
221+
222+
if invalid_tool_selections:
223+
msg = f"Model selected invalid tools: {invalid_tool_selections}"
224+
raise ValueError(msg)
225+
226+
# Filter tools based on selection and append always-included tools
227+
selected_tools = [tool for tool in available_tools if tool.name in selected_tool_names]
228+
always_included_tools = [tool for tool in request.tools if tool.name in self.always_include]
229+
selected_tools.extend(always_included_tools)
230+
request.tools = selected_tools
231+
return request
232+
233+
def modify_model_request(
234+
self,
235+
request: ModelRequest,
236+
state: StateT, # noqa: ARG002
237+
runtime: Runtime[ContextT], # noqa: ARG002
238+
) -> ModelRequest:
239+
"""Modify the model request to filter tools based on LLM selection."""
240+
selection_request = self._prepare_selection_request(request)
241+
if selection_request is None:
242+
return request
243+
244+
# Create dynamic response model with Literal enum of available tool names
245+
type_adapter = _create_tool_selection_response(selection_request.available_tools)
246+
schema = type_adapter.json_schema()
247+
structured_model = selection_request.model.with_structured_output(schema)
248+
249+
response = structured_model.invoke(
250+
[
251+
{"role": "system", "content": selection_request.system_message},
252+
selection_request.last_user_message,
253+
]
254+
)
255+
256+
# Response should be a dict since we're passing a schema (not a Pydantic model class)
257+
if not isinstance(response, dict):
258+
msg = f"Expected dict response, got {type(response)}"
259+
raise AssertionError(msg)
260+
return self._process_selection_response(
261+
response, selection_request.available_tools, selection_request.valid_tool_names, request
262+
)
263+
264+
async def amodify_model_request(
265+
self,
266+
request: ModelRequest,
267+
state: AgentState, # noqa: ARG002
268+
runtime: Runtime, # noqa: ARG002
269+
) -> ModelRequest:
270+
"""Modify the model request to filter tools based on LLM selection."""
271+
selection_request = self._prepare_selection_request(request)
272+
if selection_request is None:
273+
return request
274+
275+
# Create dynamic response model with Literal enum of available tool names
276+
type_adapter = _create_tool_selection_response(selection_request.available_tools)
277+
schema = type_adapter.json_schema()
278+
structured_model = selection_request.model.with_structured_output(schema)
279+
280+
response = await structured_model.ainvoke(
281+
[
282+
{"role": "system", "content": selection_request.system_message},
283+
selection_request.last_user_message,
284+
]
285+
)
286+
287+
# Response should be a dict since we're passing a schema (not a Pydantic model class)
288+
if not isinstance(response, dict):
289+
msg = f"Expected dict response, got {type(response)}"
290+
raise AssertionError(msg)
291+
return self._process_selection_response(
292+
response, selection_request.available_tools, selection_request.valid_tool_names, request
293+
)

0 commit comments

Comments
 (0)