Skip to content

Commit 46b87e4

Browse files
authored
chore(langchain_v1): change modifyModelRequest back to tools (#33270)
Seems like a much better devx with fairly little downside (we'll document that you can't register new tools)
1 parent 905c6d7 commit 46b87e4

File tree

3 files changed

+33
-30
lines changed

3 files changed

+33
-30
lines changed

libs/langchain_v1/langchain/agents/middleware/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class ModelRequest:
6363
system_prompt: str | None
6464
messages: list[AnyMessage] # excluding system prompt
6565
tool_choice: Any | None
66-
tools: list[str]
66+
tools: list[BaseTool]
6767
response_format: ResponseFormat | None
6868
model_settings: dict[str, Any] = field(default_factory=dict)
6969

libs/langchain_v1/langchain/agents/middleware_agent.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,11 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
395395
"""
396396
# Validate requested tools are available
397397
tools_by_name = {t.name: t for t in default_tools}
398-
unknown_tools = [name for name in request.tools if name not in tools_by_name]
399-
if unknown_tools:
398+
unknown_tool_names = [t.name for t in request.tools if t.name not in tools_by_name]
399+
if unknown_tool_names:
400400
available_tools = sorted(tools_by_name.keys())
401401
msg = (
402-
f"Middleware returned unknown tool names: {unknown_tools}\n\n"
402+
f"Middleware returned unknown tool names: {unknown_tool_names}\n\n"
403403
f"Available tools: {available_tools}\n\n"
404404
"To fix this issue:\n"
405405
"1. Ensure the tools are passed to create_agent() via "
@@ -411,8 +411,6 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
411411
)
412412
raise ValueError(msg)
413413

414-
requested_tools = [tools_by_name[name] for name in request.tools]
415-
416414
# Determine effective response format (auto-detect if needed)
417415
effective_response_format: ResponseFormat | None = request.response_format
418416
if (
@@ -432,7 +430,7 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
432430
kwargs = effective_response_format.to_model_kwargs()
433431
return (
434432
request.model.bind_tools(
435-
requested_tools, strict=True, **kwargs, **request.model_settings
433+
request.tools, strict=True, **kwargs, **request.model_settings
436434
),
437435
effective_response_format,
438436
)
@@ -442,16 +440,16 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
442440
tool_choice = "any" if structured_output_tools else request.tool_choice
443441
return (
444442
request.model.bind_tools(
445-
requested_tools, tool_choice=tool_choice, **request.model_settings
443+
request.tools, tool_choice=tool_choice, **request.model_settings
446444
),
447445
effective_response_format,
448446
)
449447

450448
# No structured output - standard model binding
451-
if requested_tools:
449+
if request.tools:
452450
return (
453451
request.model.bind_tools(
454-
requested_tools, tool_choice=request.tool_choice, **request.model_settings
452+
request.tools, tool_choice=request.tool_choice, **request.model_settings
455453
),
456454
None,
457455
)
@@ -461,7 +459,7 @@ def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, An
461459
"""Sync model request handler with sequential middleware processing."""
462460
request = ModelRequest(
463461
model=model,
464-
tools=[t.name for t in default_tools],
462+
tools=default_tools,
465463
system_prompt=system_prompt,
466464
response_format=initial_response_format,
467465
messages=state["messages"],
@@ -498,7 +496,7 @@ async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[
498496
"""Async model request handler with sequential middleware processing."""
499497
request = ModelRequest(
500498
model=model,
501-
tools=[t.name for t in default_tools],
499+
tools=default_tools,
502500
system_prompt=system_prompt,
503501
response_format=initial_response_format,
504502
messages=state["messages"],

libs/langchain_v1/tests/unit_tests/agents/test_middleware_tools.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from langgraph.runtime import Runtime
1111

1212

13-
def test_model_request_tools_are_strings() -> None:
14-
"""Test that ModelRequest.tools contains tool names as strings, not tool objects."""
13+
def test_model_request_tools_are_base_tools() -> None:
14+
"""Test that ModelRequest.tools contains BaseTool objects."""
1515
captured_requests: list[ModelRequest] = []
1616

1717
@tool
@@ -43,16 +43,15 @@ def modify_model_request(
4343
# Verify that at least one request was captured
4444
assert len(captured_requests) > 0
4545

46-
# Check that tools in the request are strings (tool names)
46+
# Check that tools in the request are BaseTool objects
4747
request = captured_requests[0]
4848
assert isinstance(request.tools, list)
4949
assert len(request.tools) == 2
50-
assert all(isinstance(tool_name, str) for tool_name in request.tools)
51-
assert set(request.tools) == {"search_tool", "calculator"}
50+
assert {t.name for t in request.tools} == {"search_tool", "calculator"}
5251

5352

54-
def test_middleware_can_modify_tool_names() -> None:
55-
"""Test that middleware can modify the list of tool names in ModelRequest."""
53+
def test_middleware_can_modify_tools() -> None:
54+
"""Test that middleware can modify the list of tools in ModelRequest."""
5655

5756
@tool
5857
def tool_a(input: str) -> str:
@@ -74,7 +73,7 @@ def modify_model_request(
7473
self, request: ModelRequest, state: AgentState, runtime: Runtime
7574
) -> ModelRequest:
7675
# Only allow tool_a and tool_b
77-
request.tools = ["tool_a", "tool_b"]
76+
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
7877
return request
7978

8079
# Model will try to call tool_a
@@ -98,20 +97,26 @@ def modify_model_request(
9897
assert tool_messages[0].name == "tool_a"
9998

10099

101-
def test_unknown_tool_name_raises_error() -> None:
102-
"""Test that using an unknown tool name in ModelRequest raises a clear error."""
100+
def test_unknown_tool_raises_error() -> None:
101+
"""Test that using an unknown tool in ModelRequest raises a clear error."""
102+
from langchain_core.tools import BaseTool
103103

104104
@tool
105105
def known_tool(input: str) -> str:
106106
"""A known tool."""
107107
return "result"
108108

109+
@tool
110+
def unknown_tool(input: str) -> str:
111+
"""An unknown tool not passed to create_agent."""
112+
return "unknown"
113+
109114
class BadMiddleware(AgentMiddleware):
110115
def modify_model_request(
111116
self, request: ModelRequest, state: AgentState, runtime: Runtime
112117
) -> ModelRequest:
113-
# Add an unknown tool name
114-
request.tools = ["known_tool", "unknown_tool"]
118+
# Add an unknown tool
119+
request.tools = request.tools + [unknown_tool]
115120
return request
116121

117122
agent = create_agent(
@@ -149,7 +154,7 @@ def modify_model_request(
149154
) -> ModelRequest:
150155
# Remove admin_tool if not admin
151156
if not state.get("is_admin", False):
152-
request.tools = [name for name in request.tools if name != "admin_tool"]
157+
request.tools = [t for t in request.tools if t.name != "admin_tool"]
153158
return request
154159

155160
model = FakeToolCallingModel()
@@ -224,20 +229,20 @@ class FirstMiddleware(AgentMiddleware):
224229
def modify_model_request(
225230
self, request: ModelRequest, state: AgentState, runtime: Runtime
226231
) -> ModelRequest:
227-
modification_order.append(request.tools.copy())
232+
modification_order.append([t.name for t in request.tools])
228233
# Remove tool_c
229-
request.tools = [name for name in request.tools if name != "tool_c"]
234+
request.tools = [t for t in request.tools if t.name != "tool_c"]
230235
return request
231236

232237
class SecondMiddleware(AgentMiddleware):
233238
def modify_model_request(
234239
self, request: ModelRequest, state: AgentState, runtime: Runtime
235240
) -> ModelRequest:
236-
modification_order.append(request.tools.copy())
241+
modification_order.append([t.name for t in request.tools])
237242
# Should not see tool_c here
238-
assert "tool_c" not in request.tools
243+
assert all(t.name != "tool_c" for t in request.tools)
239244
# Remove tool_b
240-
request.tools = [name for name in request.tools if name != "tool_b"]
245+
request.tools = [t for t in request.tools if t.name != "tool_b"]
241246
return request
242247

243248
agent = create_agent(

0 commit comments

Comments
 (0)