10
10
from langgraph .runtime import Runtime
11
11
12
12
13
- def test_model_request_tools_are_strings () -> None :
14
- """Test that ModelRequest.tools contains tool names as strings, not tool objects."""
13
+ def test_model_request_tools_are_base_tools () -> None :
14
+ """Test that ModelRequest.tools contains BaseTool objects."""
15
15
captured_requests : list [ModelRequest ] = []
16
16
17
17
@tool
@@ -43,16 +43,15 @@ def modify_model_request(
43
43
# Verify that at least one request was captured
44
44
assert len (captured_requests ) > 0
45
45
46
- # Check that tools in the request are strings (tool names)
46
+ # Check that tools in the request are BaseTool objects
47
47
request = captured_requests [0 ]
48
48
assert isinstance (request .tools , list )
49
49
assert len (request .tools ) == 2
50
- assert all (isinstance (tool_name , str ) for tool_name in request .tools )
51
- assert set (request .tools ) == {"search_tool" , "calculator" }
50
+ assert {t .name for t in request .tools } == {"search_tool" , "calculator" }
52
51
53
52
54
- def test_middleware_can_modify_tool_names () -> None :
55
- """Test that middleware can modify the list of tool names in ModelRequest."""
53
+ def test_middleware_can_modify_tools () -> None :
54
+ """Test that middleware can modify the list of tools in ModelRequest."""
56
55
57
56
@tool
58
57
def tool_a (input : str ) -> str :
@@ -74,7 +73,7 @@ def modify_model_request(
74
73
self , request : ModelRequest , state : AgentState , runtime : Runtime
75
74
) -> ModelRequest :
76
75
# Only allow tool_a and tool_b
77
- request .tools = ["tool_a" , "tool_b" ]
76
+ request .tools = [t for t in request . tools if t . name in [ "tool_a" , "tool_b" ] ]
78
77
return request
79
78
80
79
# Model will try to call tool_a
@@ -98,20 +97,26 @@ def modify_model_request(
98
97
assert tool_messages [0 ].name == "tool_a"
99
98
100
99
101
- def test_unknown_tool_name_raises_error () -> None :
102
- """Test that using an unknown tool name in ModelRequest raises a clear error."""
100
+ def test_unknown_tool_raises_error () -> None :
101
+ """Test that using an unknown tool in ModelRequest raises a clear error."""
102
+ from langchain_core .tools import BaseTool
103
103
104
104
@tool
105
105
def known_tool (input : str ) -> str :
106
106
"""A known tool."""
107
107
return "result"
108
108
109
+ @tool
110
+ def unknown_tool (input : str ) -> str :
111
+ """An unknown tool not passed to create_agent."""
112
+ return "unknown"
113
+
109
114
class BadMiddleware (AgentMiddleware ):
110
115
def modify_model_request (
111
116
self , request : ModelRequest , state : AgentState , runtime : Runtime
112
117
) -> ModelRequest :
113
- # Add an unknown tool name
114
- request .tools = [ "known_tool" , " unknown_tool" ]
118
+ # Add an unknown tool
119
+ request .tools = request . tools + [ unknown_tool ]
115
120
return request
116
121
117
122
agent = create_agent (
@@ -149,7 +154,7 @@ def modify_model_request(
149
154
) -> ModelRequest :
150
155
# Remove admin_tool if not admin
151
156
if not state .get ("is_admin" , False ):
152
- request .tools = [name for name in request .tools if name != "admin_tool" ]
157
+ request .tools = [t for t in request .tools if t . name != "admin_tool" ]
153
158
return request
154
159
155
160
model = FakeToolCallingModel ()
@@ -224,20 +229,20 @@ class FirstMiddleware(AgentMiddleware):
224
229
def modify_model_request (
225
230
self , request : ModelRequest , state : AgentState , runtime : Runtime
226
231
) -> ModelRequest :
227
- modification_order .append (request .tools . copy () )
232
+ modification_order .append ([ t . name for t in request .tools ] )
228
233
# Remove tool_c
229
- request .tools = [name for name in request .tools if name != "tool_c" ]
234
+ request .tools = [t for t in request .tools if t . name != "tool_c" ]
230
235
return request
231
236
232
237
class SecondMiddleware (AgentMiddleware ):
233
238
def modify_model_request (
234
239
self , request : ModelRequest , state : AgentState , runtime : Runtime
235
240
) -> ModelRequest :
236
- modification_order .append (request .tools . copy () )
241
+ modification_order .append ([ t . name for t in request .tools ] )
237
242
# Should not see tool_c here
238
- assert "tool_c" not in request .tools
243
+ assert all ( t . name != "tool_c" for t in request .tools )
239
244
# Remove tool_b
240
- request .tools = [name for name in request .tools if name != "tool_b" ]
245
+ request .tools = [t for t in request .tools if t . name != "tool_b" ]
241
246
return request
242
247
243
248
agent = create_agent (
0 commit comments