Skip to content

Commit 3637bba

Browse files
committed
chore: Delint
1 parent e6cfed2 commit 3637bba

File tree

1 file changed

+123
-67
lines changed

1 file changed

+123
-67
lines changed

packages/toolbox-langchain/tests/test_client.py

Lines changed: 123 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -16,52 +16,78 @@
1616

1717
import pytest
1818
from pydantic import BaseModel
19-
20-
from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool
2119
from toolbox_core.protocol import ParameterSchema as CoreParameterSchema
20+
from toolbox_core.sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool
2221
from toolbox_core.utils import params_to_pydantic_model
2322

2423
from toolbox_langchain.client import ToolboxClient
2524
from toolbox_langchain.tools import ToolboxTool
2625

2726
URL = "http://test_url"
2827

29-
def create_mock_core_sync_tool(name="mock-sync-tool", doc="Mock sync description.", model_name="MockSyncModel", params=None):
28+
29+
def create_mock_core_sync_tool(
30+
name="mock-sync-tool",
31+
doc="Mock sync description.",
32+
model_name="MockSyncModel",
33+
params=None,
34+
):
3035
mock_tool = Mock(spec=ToolboxCoreSyncTool)
3136
mock_tool.__name__ = name
3237
mock_tool.__doc__ = doc
3338
mock_tool._name = model_name
3439
if params is None:
35-
mock_tool._params = [CoreParameterSchema(name="param1", type="string", description="Param 1")]
40+
mock_tool._params = [
41+
CoreParameterSchema(name="param1", type="string", description="Param 1")
42+
]
3643
else:
3744
mock_tool._params = params
3845
return mock_tool
3946

40-
def assert_pydantic_models_equivalent(model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str):
47+
48+
def assert_pydantic_models_equivalent(
49+
model_cls1: type[BaseModel], model_cls2: type[BaseModel], expected_model_name: str
50+
):
4151
assert issubclass(model_cls1, BaseModel), "model_cls1 is not a Pydantic BaseModel"
4252
assert issubclass(model_cls2, BaseModel), "model_cls2 is not a Pydantic BaseModel"
43-
44-
assert model_cls1.__name__ == expected_model_name, f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}"
45-
assert model_cls2.__name__ == expected_model_name, f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}"
53+
54+
assert (
55+
model_cls1.__name__ == expected_model_name
56+
), f"model_cls1 name mismatch: expected {expected_model_name}, got {model_cls1.__name__}"
57+
assert (
58+
model_cls2.__name__ == expected_model_name
59+
), f"model_cls2 name mismatch: expected {expected_model_name}, got {model_cls2.__name__}"
4660

4761
fields1 = model_cls1.model_fields
4862
fields2 = model_cls2.model_fields
4963

50-
assert fields1.keys() == fields2.keys(), \
51-
f"Field names mismatch: {fields1.keys()} != {fields2.keys()}"
64+
assert (
65+
fields1.keys() == fields2.keys()
66+
), f"Field names mismatch: {fields1.keys()} != {fields2.keys()}"
5267

5368
for field_name in fields1.keys():
5469
field_info1 = fields1[field_name]
5570
field_info2 = fields2[field_name]
5671

57-
assert field_info1.annotation == field_info2.annotation, \
58-
f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})"
59-
assert field_info1.description == field_info2.description, \
60-
f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')"
61-
is_required1 = field_info1.is_required() if hasattr(field_info1, 'is_required') else not field_info1.is_nullable()
62-
is_required2 = field_info2.is_required() if hasattr(field_info2, 'is_required') else not field_info2.is_nullable()
63-
assert is_required1 == is_required2, \
64-
f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})"
72+
assert (
73+
field_info1.annotation == field_info2.annotation
74+
), f"Field '{field_name}': Annotation mismatch ({field_info1.annotation} != {field_info2.annotation})"
75+
assert (
76+
field_info1.description == field_info2.description
77+
), f"Field '{field_name}': Description mismatch ('{field_info1.description}' != '{field_info2.description}')"
78+
is_required1 = (
79+
field_info1.is_required()
80+
if hasattr(field_info1, "is_required")
81+
else not field_info1.is_nullable()
82+
)
83+
is_required2 = (
84+
field_info2.is_required()
85+
if hasattr(field_info2, "is_required")
86+
else not field_info2.is_nullable()
87+
)
88+
assert (
89+
is_required1 == is_required2
90+
), f"Field '{field_name}': Required status mismatch ({is_required1} != {is_required2})"
6591

6692

6793
class TestToolboxClient:
@@ -78,26 +104,29 @@ def test_load_tool(self, mock_core_load_tool, toolbox_client):
78104
name="test_tool_sync",
79105
doc="Sync tool description.",
80106
model_name="TestToolSyncModel",
81-
params=[CoreParameterSchema(name="sp1", type="integer", description="Sync Param 1")]
107+
params=[
108+
CoreParameterSchema(
109+
name="sp1", type="integer", description="Sync Param 1"
110+
)
111+
],
82112
)
83113
mock_core_load_tool.return_value = mock_core_tool_instance
84-
114+
85115
langchain_tool = toolbox_client.load_tool("test_tool")
86-
116+
87117
assert isinstance(langchain_tool, ToolboxTool)
88118
assert langchain_tool.name == mock_core_tool_instance.__name__
89119
assert langchain_tool.description == mock_core_tool_instance.__doc__
90-
120+
91121
# Generate the expected schema once for comparison
92122
expected_args_schema = params_to_pydantic_model(
93-
mock_core_tool_instance._name,
94-
mock_core_tool_instance._params
123+
mock_core_tool_instance._name, mock_core_tool_instance._params
95124
)
96-
125+
97126
assert_pydantic_models_equivalent(
98-
langchain_tool.args_schema,
127+
langchain_tool.args_schema,
99128
expected_args_schema,
100-
mock_core_tool_instance._name
129+
mock_core_tool_instance._name,
101130
)
102131

103132
mock_core_load_tool.assert_called_once_with(
@@ -106,8 +135,12 @@ def test_load_tool(self, mock_core_load_tool, toolbox_client):
106135

107136
@patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset")
108137
def test_load_toolset(self, mock_core_load_toolset, toolbox_client):
109-
mock_core_tool_instance1 = create_mock_core_sync_tool(name="tool-0", doc="desc 0", model_name="T0Model")
110-
mock_core_tool_instance2 = create_mock_core_sync_tool(name="tool-1", doc="desc 1", model_name="T1Model", params=[])
138+
mock_core_tool_instance1 = create_mock_core_sync_tool(
139+
name="tool-0", doc="desc 0", model_name="T0Model"
140+
)
141+
mock_core_tool_instance2 = create_mock_core_sync_tool(
142+
name="tool-1", doc="desc 1", model_name="T1Model", params=[]
143+
)
111144

112145
mock_core_load_toolset.return_value = [
113146
mock_core_tool_instance1,
@@ -116,22 +149,21 @@ def test_load_toolset(self, mock_core_load_toolset, toolbox_client):
116149

117150
langchain_tools = toolbox_client.load_toolset()
118151
assert len(langchain_tools) == 2
119-
152+
120153
tool_instances_mocks = [mock_core_tool_instance1, mock_core_tool_instance2]
121154
for i, tool_instance_mock in enumerate(tool_instances_mocks):
122155
langchain_tool = langchain_tools[i]
123156
assert isinstance(langchain_tool, ToolboxTool)
124157
assert langchain_tool.name == tool_instance_mock.__name__
125158
assert langchain_tool.description == tool_instance_mock.__doc__
126-
159+
127160
expected_args_schema = params_to_pydantic_model(
128-
tool_instance_mock._name,
129-
tool_instance_mock._params
161+
tool_instance_mock._name, tool_instance_mock._params
130162
)
131163
assert_pydantic_models_equivalent(
132-
langchain_tool.args_schema,
164+
langchain_tool.args_schema,
133165
expected_args_schema,
134-
tool_instance_mock._name
166+
tool_instance_mock._name,
135167
)
136168

137169
mock_core_load_toolset.assert_called_once_with(
@@ -144,7 +176,7 @@ async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client):
144176
mock_core_sync_tool_instance = create_mock_core_sync_tool(
145177
name="test_async_loaded_tool",
146178
doc="Async loaded sync tool description.",
147-
model_name="AsyncTestToolModel"
179+
model_name="AsyncTestToolModel",
148180
)
149181
mock_sync_core_load_tool.return_value = mock_core_sync_tool_instance
150182

@@ -153,26 +185,32 @@ async def test_aload_tool(self, mock_sync_core_load_tool, toolbox_client):
153185
assert isinstance(langchain_tool, ToolboxTool)
154186
assert langchain_tool.name == mock_core_sync_tool_instance.__name__
155187
assert langchain_tool.description == mock_core_sync_tool_instance.__doc__
156-
188+
157189
expected_args_schema = params_to_pydantic_model(
158-
mock_core_sync_tool_instance._name,
159-
mock_core_sync_tool_instance._params
190+
mock_core_sync_tool_instance._name, mock_core_sync_tool_instance._params
160191
)
161192
assert_pydantic_models_equivalent(
162-
langchain_tool.args_schema,
193+
langchain_tool.args_schema,
163194
expected_args_schema,
164-
mock_core_sync_tool_instance._name
195+
mock_core_sync_tool_instance._name,
165196
)
166-
197+
167198
mock_sync_core_load_tool.assert_called_once_with(
168199
name="test_tool", auth_token_getters={}, bound_params={}
169200
)
170201

171202
@pytest.mark.asyncio
172203
@patch("toolbox_core.sync_client.ToolboxSyncClient.load_toolset")
173204
async def test_aload_toolset(self, mock_sync_core_load_toolset, toolbox_client):
174-
mock_core_sync_tool1 = create_mock_core_sync_tool(name="async-tool-0", doc="async desc 0", model_name="AT0Model")
175-
mock_core_sync_tool2 = create_mock_core_sync_tool(name="async-tool-1", doc="async desc 1", model_name="AT1Model", params=[CoreParameterSchema(name="p1", type="string", description="P1")])
205+
mock_core_sync_tool1 = create_mock_core_sync_tool(
206+
name="async-tool-0", doc="async desc 0", model_name="AT0Model"
207+
)
208+
mock_core_sync_tool2 = create_mock_core_sync_tool(
209+
name="async-tool-1",
210+
doc="async desc 1",
211+
model_name="AT1Model",
212+
params=[CoreParameterSchema(name="p1", type="string", description="P1")],
213+
)
176214

177215
mock_sync_core_load_toolset.return_value = [
178216
mock_core_sync_tool1,
@@ -181,21 +219,20 @@ async def test_aload_toolset(self, mock_sync_core_load_toolset, toolbox_client):
181219

182220
langchain_tools = await toolbox_client.aload_toolset()
183221
assert len(langchain_tools) == 2
184-
222+
185223
tool_instances_mocks = [mock_core_sync_tool1, mock_core_sync_tool2]
186224
for i, tool_instance_mock in enumerate(tool_instances_mocks):
187225
langchain_tool = langchain_tools[i]
188226
assert isinstance(langchain_tool, ToolboxTool)
189227
assert langchain_tool.name == tool_instance_mock.__name__
190-
228+
191229
expected_args_schema = params_to_pydantic_model(
192-
tool_instance_mock._name,
193-
tool_instance_mock._params
230+
tool_instance_mock._name, tool_instance_mock._params
194231
)
195232
assert_pydantic_models_equivalent(
196-
langchain_tool.args_schema,
233+
langchain_tool.args_schema,
197234
expected_args_schema,
198-
tool_instance_mock._name
235+
tool_instance_mock._name,
199236
)
200237

201238
mock_sync_core_load_toolset.assert_called_once_with(
@@ -223,49 +260,64 @@ def test_load_tool_with_args(self, mock_core_load_tool, toolbox_client):
223260
assert len(record) == 2
224261
messages = sorted([str(r.message) for r in record])
225262
# Warning for auth_headers when auth_token_getters is also present
226-
assert "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used." in messages
263+
assert (
264+
"Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used."
265+
in messages
266+
)
227267
# Warning for auth_tokens when auth_token_getters is also present
228-
assert "Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used." in messages
229-
268+
assert (
269+
"Both `auth_token_getters` and `auth_tokens` are provided. `auth_tokens` is deprecated, and `auth_token_getters` will be used."
270+
in messages
271+
)
272+
230273
assert isinstance(tool, ToolboxTool)
231274
mock_core_load_tool.assert_called_with(
232275
name="test_tool_name",
233276
auth_token_getters=auth_token_getters,
234277
bound_params=bound_params,
235278
)
236279
mock_core_load_tool.reset_mock()
237-
280+
238281
# Scenario 2: auth_tokens and auth_headers provided, auth_token_getters is default (empty initially)
239282
with pytest.warns(DeprecationWarning) as record:
240283
toolbox_client.load_tool(
241284
"test_tool_name_2",
242-
auth_tokens=auth_tokens_deprecated, # This will be used for auth_token_getters
243-
auth_headers=auth_headers_deprecated, # This will warn as auth_token_getters is now populated
285+
auth_tokens=auth_tokens_deprecated, # This will be used for auth_token_getters
286+
auth_headers=auth_headers_deprecated, # This will warn as auth_token_getters is now populated
244287
bound_params=bound_params,
245288
)
246289
assert len(record) == 2
247290
messages = sorted([str(r.message) for r in record])
248-
249-
assert messages[0] == "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead."
250-
assert messages[1] == "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used."
251-
252-
expected_getters_for_call = auth_tokens_deprecated
253-
291+
292+
assert (
293+
messages[0]
294+
== "Argument `auth_tokens` is deprecated. Use `auth_token_getters` instead."
295+
)
296+
assert (
297+
messages[1]
298+
== "Both `auth_token_getters` and `auth_headers` are provided. `auth_headers` is deprecated, and `auth_token_getters` will be used."
299+
)
300+
301+
expected_getters_for_call = auth_tokens_deprecated
302+
254303
mock_core_load_tool.assert_called_with(
255304
name="test_tool_name_2",
256305
auth_token_getters=expected_getters_for_call,
257306
bound_params=bound_params,
258307
)
259308
mock_core_load_tool.reset_mock()
260-
261-
with pytest.warns(DeprecationWarning, match="Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.") as record:
309+
310+
with pytest.warns(
311+
DeprecationWarning,
312+
match="Argument `auth_headers` is deprecated. Use `auth_token_getters` instead.",
313+
) as record:
262314
toolbox_client.load_tool(
263315
"test_tool_name_3",
264316
auth_headers=auth_headers_deprecated,
265317
bound_params=bound_params,
266318
)
267319
assert len(record) == 1
268-
320+
269321
mock_core_load_tool.assert_called_with(
270322
name="test_tool_name_3",
271323
auth_token_getters=auth_headers_deprecated,
@@ -306,7 +358,9 @@ def test_load_toolset_with_args(self, mock_core_load_toolset, toolbox_client):
306358
@pytest.mark.asyncio
307359
@patch("toolbox_core.sync_client.ToolboxSyncClient.load_tool")
308360
async def test_aload_tool_with_args(self, mock_sync_core_load_tool, toolbox_client):
309-
mock_core_tool_instance = create_mock_core_sync_tool(model_name="MyAsyncToolModel")
361+
mock_core_tool_instance = create_mock_core_sync_tool(
362+
model_name="MyAsyncToolModel"
363+
)
310364
mock_sync_core_load_tool.return_value = mock_core_tool_instance
311365

312366
auth_token_getters = {"token_getter1": lambda: "value1"}
@@ -336,7 +390,9 @@ async def test_aload_tool_with_args(self, mock_sync_core_load_tool, toolbox_clie
336390
async def test_aload_toolset_with_args(
337391
self, mock_sync_core_load_toolset, toolbox_client
338392
):
339-
mock_core_tool_instance = create_mock_core_sync_tool(model_name="MyAsyncSetModel")
393+
mock_core_tool_instance = create_mock_core_sync_tool(
394+
model_name="MyAsyncSetModel"
395+
)
340396
mock_sync_core_load_toolset.return_value = [mock_core_tool_instance]
341397

342398
auth_token_getters = {"token_getter1": lambda: "value1"}

0 commit comments

Comments
 (0)