16
16
17
17
import pytest
18
18
from pydantic import BaseModel
19
-
20
- from toolbox_core .sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool
21
19
from toolbox_core .protocol import ParameterSchema as CoreParameterSchema
20
+ from toolbox_core .sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool
22
21
from toolbox_core .utils import params_to_pydantic_model
23
22
24
23
from toolbox_langchain .client import ToolboxClient
25
24
from toolbox_langchain .tools import ToolboxTool
26
25
27
26
URL = "http://test_url"
28
27
29
- def create_mock_core_sync_tool (name = "mock-sync-tool" , doc = "Mock sync description." , model_name = "MockSyncModel" , params = None ):
28
+
29
+ def create_mock_core_sync_tool (
30
+ name = "mock-sync-tool" ,
31
+ doc = "Mock sync description." ,
32
+ model_name = "MockSyncModel" ,
33
+ params = None ,
34
+ ):
30
35
mock_tool = Mock (spec = ToolboxCoreSyncTool )
31
36
mock_tool .__name__ = name
32
37
mock_tool .__doc__ = doc
33
38
mock_tool ._name = model_name
34
39
if params is None :
35
- mock_tool ._params = [CoreParameterSchema (name = "param1" , type = "string" , description = "Param 1" )]
40
+ mock_tool ._params = [
41
+ CoreParameterSchema (name = "param1" , type = "string" , description = "Param 1" )
42
+ ]
36
43
else :
37
44
mock_tool ._params = params
38
45
return mock_tool
39
46
40
- def assert_pydantic_models_equivalent (model_cls1 : type [BaseModel ], model_cls2 : type [BaseModel ], expected_model_name : str ):
47
+
48
+ def assert_pydantic_models_equivalent (
49
+ model_cls1 : type [BaseModel ], model_cls2 : type [BaseModel ], expected_model_name : str
50
+ ):
41
51
assert issubclass (model_cls1 , BaseModel ), "model_cls1 is not a Pydantic BaseModel"
42
52
assert issubclass (model_cls2 , BaseModel ), "model_cls2 is not a Pydantic BaseModel"
43
-
44
- assert model_cls1 .__name__ == expected_model_name , f"model_cls1 name mismatch: expected { expected_model_name } , got { model_cls1 .__name__ } "
45
- assert model_cls2 .__name__ == expected_model_name , f"model_cls2 name mismatch: expected { expected_model_name } , got { model_cls2 .__name__ } "
53
+
54
+ assert (
55
+ model_cls1 .__name__ == expected_model_name
56
+ ), f"model_cls1 name mismatch: expected { expected_model_name } , got { model_cls1 .__name__ } "
57
+ assert (
58
+ model_cls2 .__name__ == expected_model_name
59
+ ), f"model_cls2 name mismatch: expected { expected_model_name } , got { model_cls2 .__name__ } "
46
60
47
61
fields1 = model_cls1 .model_fields
48
62
fields2 = model_cls2 .model_fields
49
63
50
- assert fields1 .keys () == fields2 .keys (), \
51
- f"Field names mismatch: { fields1 .keys ()} != { fields2 .keys ()} "
64
+ assert (
65
+ fields1 .keys () == fields2 .keys ()
66
+ ), f"Field names mismatch: { fields1 .keys ()} != { fields2 .keys ()} "
52
67
53
68
for field_name in fields1 .keys ():
54
69
field_info1 = fields1 [field_name ]
55
70
field_info2 = fields2 [field_name ]
56
71
57
- assert field_info1 .annotation == field_info2 .annotation , \
58
- f"Field '{ field_name } ': Annotation mismatch ({ field_info1 .annotation } != { field_info2 .annotation } )"
59
- assert field_info1 .description == field_info2 .description , \
60
- f"Field '{ field_name } ': Description mismatch ('{ field_info1 .description } ' != '{ field_info2 .description } ')"
61
- is_required1 = field_info1 .is_required () if hasattr (field_info1 , 'is_required' ) else not field_info1 .is_nullable ()
62
- is_required2 = field_info2 .is_required () if hasattr (field_info2 , 'is_required' ) else not field_info2 .is_nullable ()
63
- assert is_required1 == is_required2 , \
64
- f"Field '{ field_name } ': Required status mismatch ({ is_required1 } != { is_required2 } )"
72
+ assert (
73
+ field_info1 .annotation == field_info2 .annotation
74
+ ), f"Field '{ field_name } ': Annotation mismatch ({ field_info1 .annotation } != { field_info2 .annotation } )"
75
+ assert (
76
+ field_info1 .description == field_info2 .description
77
+ ), f"Field '{ field_name } ': Description mismatch ('{ field_info1 .description } ' != '{ field_info2 .description } ')"
78
+ is_required1 = (
79
+ field_info1 .is_required ()
80
+ if hasattr (field_info1 , "is_required" )
81
+ else not field_info1 .is_nullable ()
82
+ )
83
+ is_required2 = (
84
+ field_info2 .is_required ()
85
+ if hasattr (field_info2 , "is_required" )
86
+ else not field_info2 .is_nullable ()
87
+ )
88
+ assert (
89
+ is_required1 == is_required2
90
+ ), f"Field '{ field_name } ': Required status mismatch ({ is_required1 } != { is_required2 } )"
65
91
66
92
67
93
class TestToolboxClient :
@@ -78,26 +104,29 @@ def test_load_tool(self, mock_core_load_tool, toolbox_client):
78
104
name = "test_tool_sync" ,
79
105
doc = "Sync tool description." ,
80
106
model_name = "TestToolSyncModel" ,
81
- params = [CoreParameterSchema (name = "sp1" , type = "integer" , description = "Sync Param 1" )]
107
+ params = [
108
+ CoreParameterSchema (
109
+ name = "sp1" , type = "integer" , description = "Sync Param 1"
110
+ )
111
+ ],
82
112
)
83
113
mock_core_load_tool .return_value = mock_core_tool_instance
84
-
114
+
85
115
langchain_tool = toolbox_client .load_tool ("test_tool" )
86
-
116
+
87
117
assert isinstance (langchain_tool , ToolboxTool )
88
118
assert langchain_tool .name == mock_core_tool_instance .__name__
89
119
assert langchain_tool .description == mock_core_tool_instance .__doc__
90
-
120
+
91
121
# Generate the expected schema once for comparison
92
122
expected_args_schema = params_to_pydantic_model (
93
- mock_core_tool_instance ._name ,
94
- mock_core_tool_instance ._params
123
+ mock_core_tool_instance ._name , mock_core_tool_instance ._params
95
124
)
96
-
125
+
97
126
assert_pydantic_models_equivalent (
98
- langchain_tool .args_schema ,
127
+ langchain_tool .args_schema ,
99
128
expected_args_schema ,
100
- mock_core_tool_instance ._name
129
+ mock_core_tool_instance ._name ,
101
130
)
102
131
103
132
mock_core_load_tool .assert_called_once_with (
@@ -106,8 +135,12 @@ def test_load_tool(self, mock_core_load_tool, toolbox_client):
106
135
107
136
@patch ("toolbox_core.sync_client.ToolboxSyncClient.load_toolset" )
108
137
def test_load_toolset (self , mock_core_load_toolset , toolbox_client ):
109
- mock_core_tool_instance1 = create_mock_core_sync_tool (name = "tool-0" , doc = "desc 0" , model_name = "T0Model" )
110
- mock_core_tool_instance2 = create_mock_core_sync_tool (name = "tool-1" , doc = "desc 1" , model_name = "T1Model" , params = [])
138
+ mock_core_tool_instance1 = create_mock_core_sync_tool (
139
+ name = "tool-0" , doc = "desc 0" , model_name = "T0Model"
140
+ )
141
+ mock_core_tool_instance2 = create_mock_core_sync_tool (
142
+ name = "tool-1" , doc = "desc 1" , model_name = "T1Model" , params = []
143
+ )
111
144
112
145
mock_core_load_toolset .return_value = [
113
146
mock_core_tool_instance1 ,
@@ -116,22 +149,21 @@ def test_load_toolset(self, mock_core_load_toolset, toolbox_client):
116
149
117
150
langchain_tools = toolbox_client .load_toolset ()
118
151
assert len (langchain_tools ) == 2
119
-
152
+
120
153
tool_instances_mocks = [mock_core_tool_instance1 , mock_core_tool_instance2 ]
121
154
for i , tool_instance_mock in enumerate (tool_instances_mocks ):
122
155
langchain_tool = langchain_tools [i ]
123
156
assert isinstance (langchain_tool , ToolboxTool )
124
157
assert langchain_tool .name == tool_instance_mock .__name__
125
158
assert langchain_tool .description == tool_instance_mock .__doc__
126
-
159
+
127
160
expected_args_schema = params_to_pydantic_model (
128
- tool_instance_mock ._name ,
129
- tool_instance_mock ._params
161
+ tool_instance_mock ._name , tool_instance_mock ._params
130
162
)
131
163
assert_pydantic_models_equivalent (
132
- langchain_tool .args_schema ,
164
+ langchain_tool .args_schema ,
133
165
expected_args_schema ,
134
- tool_instance_mock ._name
166
+ tool_instance_mock ._name ,
135
167
)
136
168
137
169
mock_core_load_toolset .assert_called_once_with (
@@ -144,7 +176,7 @@ async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client):
144
176
mock_core_sync_tool_instance = create_mock_core_sync_tool (
145
177
name = "test_async_loaded_tool" ,
146
178
doc = "Async loaded sync tool description." ,
147
- model_name = "AsyncTestToolModel"
179
+ model_name = "AsyncTestToolModel" ,
148
180
)
149
181
mock_sync_core_load_tool .return_value = mock_core_sync_tool_instance
150
182
@@ -153,26 +185,32 @@ async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client):
153
185
assert isinstance (langchain_tool , ToolboxTool )
154
186
assert langchain_tool .name == mock_core_sync_tool_instance .__name__
155
187
assert langchain_tool .description == mock_core_sync_tool_instance .__doc__
156
-
188
+
157
189
expected_args_schema = params_to_pydantic_model (
158
- mock_core_sync_tool_instance ._name ,
159
- mock_core_sync_tool_instance ._params
190
+ mock_core_sync_tool_instance ._name , mock_core_sync_tool_instance ._params
160
191
)
161
192
assert_pydantic_models_equivalent (
162
- langchain_tool .args_schema ,
193
+ langchain_tool .args_schema ,
163
194
expected_args_schema ,
164
- mock_core_sync_tool_instance ._name
195
+ mock_core_sync_tool_instance ._name ,
165
196
)
166
-
197
+
167
198
mock_sync_core_load_tool .assert_called_once_with (
168
199
name = "test_tool" , auth_token_getters = {}, bound_params = {}
169
200
)
170
201
171
202
@pytest .mark .asyncio
172
203
@patch ("toolbox_core.sync_client.ToolboxSyncClient.load_toolset" )
173
204
async def test_aload_toolset (self , mock_sync_core_load_toolset , toolbox_client ):
174
- mock_core_sync_tool1 = create_mock_core_sync_tool (name = "async-tool-0" , doc = "async desc 0" , model_name = "AT0Model" )
175
- mock_core_sync_tool2 = create_mock_core_sync_tool (name = "async-tool-1" , doc = "async desc 1" , model_name = "AT1Model" , params = [CoreParameterSchema (name = "p1" , type = "string" , description = "P1" )])
205
+ mock_core_sync_tool1 = create_mock_core_sync_tool (
206
+ name = "async-tool-0" , doc = "async desc 0" , model_name = "AT0Model"
207
+ )
208
+ mock_core_sync_tool2 = create_mock_core_sync_tool (
209
+ name = "async-tool-1" ,
210
+ doc = "async desc 1" ,
211
+ model_name = "AT1Model" ,
212
+ params = [CoreParameterSchema (name = "p1" , type = "string" , description = "P1" )],
213
+ )
176
214
177
215
mock_sync_core_load_toolset .return_value = [
178
216
mock_core_sync_tool1 ,
@@ -181,21 +219,20 @@ async def test_aload_toolset(self, mock_sync_core_load_toolset, toolbox_client):
181
219
182
220
langchain_tools = await toolbox_client .aload_toolset ()
183
221
assert len (langchain_tools ) == 2
184
-
222
+
185
223
tool_instances_mocks = [mock_core_sync_tool1 , mock_core_sync_tool2 ]
186
224
for i , tool_instance_mock in enumerate (tool_instances_mocks ):
187
225
langchain_tool = langchain_tools [i ]
188
226
assert isinstance (langchain_tool , ToolboxTool )
189
227
assert langchain_tool .name == tool_instance_mock .__name__
190
-
228
+
191
229
expected_args_schema = params_to_pydantic_model (
192
- tool_instance_mock ._name ,
193
- tool_instance_mock ._params
230
+ tool_instance_mock ._name , tool_instance_mock ._params
194
231
)
195
232
assert_pydantic_models_equivalent (
196
- langchain_tool .args_schema ,
233
+ langchain_tool .args_schema ,
197
234
expected_args_schema ,
198
- tool_instance_mock ._name
235
+ tool_instance_mock ._name ,
199
236
)
200
237
201
238
mock_sync_core_load_toolset .assert_called_once_with (
@@ -223,49 +260,64 @@ def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client):
223
260
assert len (record ) == 2
224
261
messages = sorted ([str (r .message ) for r in record ])
225
262
# Warning for auth_headers when auth_token_getters is also present
226
- assert "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." in messages
263
+ assert (
264
+ "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used."
265
+ in messages
266
+ )
227
267
# Warning for auth_tokens when auth_token_getters is also present
228
- assert "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used." in messages
229
-
268
+ assert (
269
+ "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used."
270
+ in messages
271
+ )
272
+
230
273
assert isinstance (tool , ToolboxTool )
231
274
mock_core_load_tool .assert_called_with (
232
275
name = "test_tool_name" ,
233
276
auth_token_getters = auth_token_getters ,
234
277
bound_params = bound_params ,
235
278
)
236
279
mock_core_load_tool .reset_mock ()
237
-
280
+
238
281
# Scenario 2: auth_tokens and auth_headers provided, auth_token_getters is default (empty initially)
239
282
with pytest .warns (DeprecationWarning ) as record :
240
283
toolbox_client .load_tool (
241
284
"test_tool_name_2" ,
242
- auth_tokens = auth_tokens_deprecated , # This will be used for auth_token_getters
243
- auth_headers = auth_headers_deprecated , # This will warn as auth_token_getters is now populated
285
+ auth_tokens = auth_tokens_deprecated , # This will be used for auth_token_getters
286
+ auth_headers = auth_headers_deprecated , # This will warn as auth_token_getters is now populated
244
287
bound_params = bound_params ,
245
288
)
246
289
assert len (record ) == 2
247
290
messages = sorted ([str (r .message ) for r in record ])
248
-
249
- assert messages [0 ] == "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead."
250
- assert messages [1 ] == "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used."
251
-
252
- expected_getters_for_call = auth_tokens_deprecated
253
-
291
+
292
+ assert (
293
+ messages [0 ]
294
+ == "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead."
295
+ )
296
+ assert (
297
+ messages [1 ]
298
+ == "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used."
299
+ )
300
+
301
+ expected_getters_for_call = auth_tokens_deprecated
302
+
254
303
mock_core_load_tool .assert_called_with (
255
304
name = "test_tool_name_2" ,
256
305
auth_token_getters = expected_getters_for_call ,
257
306
bound_params = bound_params ,
258
307
)
259
308
mock_core_load_tool .reset_mock ()
260
-
261
- with pytest .warns (DeprecationWarning , match = "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead." ) as record :
309
+
310
+ with pytest .warns (
311
+ DeprecationWarning ,
312
+ match = "Argument `auth_headers` is deprecated. Use `auth_token_getters` instead." ,
313
+ ) as record :
262
314
toolbox_client .load_tool (
263
315
"test_tool_name_3" ,
264
316
auth_headers = auth_headers_deprecated ,
265
317
bound_params = bound_params ,
266
318
)
267
319
assert len (record ) == 1
268
-
320
+
269
321
mock_core_load_tool .assert_called_with (
270
322
name = "test_tool_name_3" ,
271
323
auth_token_getters = auth_headers_deprecated ,
@@ -306,7 +358,9 @@ def test_load_toolset_with_args(self, mock_core_load_toolset, toolbox_client):
306
358
@pytest .mark .asyncio
307
359
@patch ("toolbox_core.sync_client.ToolboxSyncClient.load_tool" )
308
360
async def test_aload_tool_with_args (self , mock_sync_core_load_tool , toolbox_client ):
309
- mock_core_tool_instance = create_mock_core_sync_tool (model_name = "MyAsyncToolModel" )
361
+ mock_core_tool_instance = create_mock_core_sync_tool (
362
+ model_name = "MyAsyncToolModel"
363
+ )
310
364
mock_sync_core_load_tool .return_value = mock_core_tool_instance
311
365
312
366
auth_token_getters = {"token_getter1" : lambda : "value1" }
@@ -336,7 +390,9 @@ async def test_aload_tool_with_args(self, mock_sync_core_load_tool, toolbox_clie
336
390
async def test_aload_toolset_with_args (
337
391
self , mock_sync_core_load_toolset , toolbox_client
338
392
):
339
- mock_core_tool_instance = create_mock_core_sync_tool (model_name = "MyAsyncSetModel" )
393
+ mock_core_tool_instance = create_mock_core_sync_tool (
394
+ model_name = "MyAsyncSetModel"
395
+ )
340
396
mock_sync_core_load_toolset .return_value = [mock_core_tool_instance ]
341
397
342
398
auth_token_getters = {"token_getter1" : lambda : "value1" }
0 commit comments