Skip to content

Commit d5356b0

Browse files
committed
chore: Update unittests
1 parent 5ed0c9c commit d5356b0

File tree

4 files changed

+692
-483
lines changed

4 files changed

+692
-483
lines changed

packages/toolbox-langchain/tests/test_async_client.py

Lines changed: 127 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717

1818
import pytest
1919
from aiohttp import ClientSession
20+
from toolbox_core.client import ToolboxClient as ToolboxCoreClient
2021
from toolbox_core.protocol import ManifestSchema
22+
from toolbox_core.protocol import ParameterSchema as CoreParameterSchema
23+
from toolbox_core.tool import ToolboxTool as ToolboxCoreTool
24+
from toolbox_core.utils import params_to_pydantic_model
2125

2226
from toolbox_langchain.async_client import AsyncToolboxClient
2327
from toolbox_langchain.async_tools import AsyncToolboxTool
@@ -60,123 +64,200 @@ def manifest_schema(self):
6064
def mock_session(self):
6165
return AsyncMock(spec=ClientSession)
6266

67+
@pytest.fixture
68+
def mock_core_client_instance(self, manifest_schema, mock_session):
69+
mock = AsyncMock(spec=ToolboxCoreClient)
70+
71+
async def mock_load_tool_impl(name, auth_token_getters, bound_params):
72+
tool_schema_dict = MANIFEST_JSON["tools"].get(name)
73+
if not tool_schema_dict:
74+
raise ValueError(f"Tool '{name}' not in mock manifest_dict")
75+
76+
core_params = [
77+
CoreParameterSchema(**p) for p in tool_schema_dict["parameters"]
78+
]
79+
# Return a mock that looks like toolbox_core.tool.ToolboxTool
80+
core_tool_mock = AsyncMock(spec=ToolboxCoreTool)
81+
core_tool_mock.__name__ = name
82+
core_tool_mock.__doc__ = tool_schema_dict["description"]
83+
core_tool_mock._pydantic_model = params_to_pydantic_model(name, core_params)
84+
# Add other necessary attributes or method mocks if AsyncToolboxTool uses them
85+
return core_tool_mock
86+
87+
mock.load_tool = AsyncMock(side_effect=mock_load_tool_impl)
88+
89+
async def mock_load_toolset_impl(
90+
name, auth_token_getters, bound_params, strict
91+
):
92+
core_tools_list = []
93+
for tool_name_iter, tool_schema_dict in MANIFEST_JSON["tools"].items():
94+
core_params = [
95+
CoreParameterSchema(**p) for p in tool_schema_dict["parameters"]
96+
]
97+
core_tool_mock = AsyncMock(spec=ToolboxCoreTool)
98+
core_tool_mock.__name__ = tool_name_iter
99+
core_tool_mock.__doc__ = tool_schema_dict["description"]
100+
core_tool_mock._pydantic_model = params_to_pydantic_model(
101+
tool_name_iter, core_params
102+
)
103+
core_tools_list.append(core_tool_mock)
104+
return core_tools_list
105+
106+
mock.load_toolset = AsyncMock(side_effect=mock_load_toolset_impl)
107+
# Mock the session attribute if it's directly accessed by AsyncToolboxClient tests
108+
mock._ToolboxClient__session = mock_session
109+
return mock
110+
63111
@pytest.fixture()
64-
def mock_client(self, mock_session):
65-
return AsyncToolboxClient(URL, session=mock_session)
112+
def mock_client(self, mock_session, mock_core_client_instance):
113+
# Patch the ToolboxCoreClient constructor used by AsyncToolboxClient
114+
with patch(
115+
"toolbox_langchain.async_client.ToolboxCoreClient",
116+
return_value=mock_core_client_instance,
117+
):
118+
client = AsyncToolboxClient(URL, session=mock_session)
119+
# Ensure the mocked core client is used
120+
client._AsyncToolboxClient__core_client = mock_core_client_instance
121+
return client
66122

67123
async def test_create_with_existing_session(self, mock_client, mock_session):
68-
assert mock_client._AsyncToolboxClient__session == mock_session
124+
# AsyncToolboxClient stores the core_client, which stores the session
125+
assert (
126+
mock_client._AsyncToolboxClient__core_client._ToolboxClient__session
127+
== mock_session
128+
)
69129

70-
@patch("toolbox_langchain.async_client._load_manifest")
71130
async def test_aload_tool(
72-
self, mock_load_manifest, mock_client, mock_session, manifest_schema
131+
self,
132+
mock_client,
133+
manifest_schema, # mock_session removed as it's part of mock_core_client_instance
73134
):
74135
tool_name = "test_tool_1"
75-
mock_load_manifest.return_value = manifest_schema
136+
# manifest_schema is used by mock_core_client_instance fixture to provide tool details
76137

77138
tool = await mock_client.aload_tool(tool_name)
78139

79-
mock_load_manifest.assert_called_once_with(
80-
f"{URL}/api/tool/{tool_name}", mock_session
140+
# Assert that the core client's load_tool was called correctly
141+
mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with(
142+
name=tool_name, auth_token_getters={}, bound_params={}
81143
)
82144
assert isinstance(tool, AsyncToolboxTool)
83-
assert tool.name == tool_name
145+
assert (
146+
tool.name == tool_name
147+
) # AsyncToolboxTool gets its name from the core_tool
84148

85-
@patch("toolbox_langchain.async_client._load_manifest")
86149
async def test_aload_tool_auth_headers_deprecated(
87-
self, mock_load_manifest, mock_client, manifest_schema
150+
self, mock_client, manifest_schema
88151
):
89152
tool_name = "test_tool_1"
90-
mock_manifest = manifest_schema
91-
mock_load_manifest.return_value = mock_manifest
153+
auth_lambda = lambda: "Bearer token" # Define lambda once
92154
with catch_warnings(record=True) as w:
93155
simplefilter("always")
94156
await mock_client.aload_tool(
95-
tool_name, auth_headers={"Authorization": lambda: "Bearer token"}
157+
tool_name,
158+
auth_headers={"Authorization": auth_lambda}, # Use the defined lambda
96159
)
97160
assert len(w) == 1
98161
assert issubclass(w[-1].category, DeprecationWarning)
99162
assert "auth_headers" in str(w[-1].message)
100163

101-
@patch("toolbox_langchain.async_client._load_manifest")
164+
mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with(
165+
name=tool_name,
166+
auth_token_getters={"Authorization": auth_lambda},
167+
bound_params={},
168+
)
169+
102170
async def test_aload_tool_auth_headers_and_tokens(
103-
self, mock_load_manifest, mock_client, manifest_schema
171+
self, mock_client, manifest_schema
104172
):
105173
tool_name = "test_tool_1"
106-
mock_manifest = manifest_schema
107-
mock_load_manifest.return_value = mock_manifest
174+
auth_getters = {"test": lambda: "token"}
175+
auth_headers_lambda = lambda: "Bearer token" # Define lambda once
176+
108177
with catch_warnings(record=True) as w:
109178
simplefilter("always")
110179
await mock_client.aload_tool(
111180
tool_name,
112-
auth_headers={"Authorization": lambda: "Bearer token"},
113-
auth_token_getters={"test": lambda: "token"},
181+
auth_headers={
182+
"Authorization": auth_headers_lambda
183+
}, # Use defined lambda
184+
auth_token_getters=auth_getters,
114185
)
115-
assert len(w) == 1
186+
assert (
187+
len(w) == 1
188+
) # Only one warning because auth_token_getters takes precedence
116189
assert issubclass(w[-1].category, DeprecationWarning)
117-
assert "auth_headers" in str(w[-1].message)
190+
assert "auth_headers" in str(w[-1].message) # Warning for auth_headers
191+
192+
mock_client._AsyncToolboxClient__core_client.load_tool.assert_called_once_with(
193+
name=tool_name, auth_token_getters=auth_getters, bound_params={}
194+
)
118195

119-
@patch("toolbox_langchain.async_client._load_manifest")
120196
async def test_aload_toolset(
121-
self, mock_load_manifest, mock_client, mock_session, manifest_schema
197+
self, mock_client, manifest_schema # mock_session removed
122198
):
123-
mock_manifest = manifest_schema
124-
mock_load_manifest.return_value = mock_manifest
125199
tools = await mock_client.aload_toolset()
126200

127-
mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session)
128-
assert len(tools) == 2
201+
mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with(
202+
name=None, auth_token_getters={}, bound_params={}, strict=False
203+
)
204+
assert len(tools) == 2 # Based on MANIFEST_JSON
129205
for tool in tools:
130206
assert isinstance(tool, AsyncToolboxTool)
131207
assert tool.name in ["test_tool_1", "test_tool_2"]
132208

133-
@patch("toolbox_langchain.async_client._load_manifest")
134209
async def test_aload_toolset_with_toolset_name(
135-
self, mock_load_manifest, mock_client, mock_session, manifest_schema
210+
self, mock_client, manifest_schema # mock_session removed
136211
):
137-
toolset_name = "test_toolset_1"
138-
mock_manifest = manifest_schema
139-
mock_load_manifest.return_value = mock_manifest
212+
toolset_name = "test_toolset_1" # This name isn't in MANIFEST_JSON, but load_toolset mock doesn't filter by it
140213
tools = await mock_client.aload_toolset(toolset_name=toolset_name)
141214

142-
mock_load_manifest.assert_called_once_with(
143-
f"{URL}/api/toolset/{toolset_name}", mock_session
215+
mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with(
216+
name=toolset_name, auth_token_getters={}, bound_params={}, strict=False
144217
)
145218
assert len(tools) == 2
146219
for tool in tools:
147220
assert isinstance(tool, AsyncToolboxTool)
148221
assert tool.name in ["test_tool_1", "test_tool_2"]
149222

150-
@patch("toolbox_langchain.async_client._load_manifest")
151223
async def test_aload_toolset_auth_headers_deprecated(
152-
self, mock_load_manifest, mock_client, manifest_schema
224+
self, mock_client, manifest_schema
153225
):
154-
mock_manifest = manifest_schema
155-
mock_load_manifest.return_value = mock_manifest
226+
auth_lambda = lambda: "Bearer token" # Define lambda once
156227
with catch_warnings(record=True) as w:
157228
simplefilter("always")
158229
await mock_client.aload_toolset(
159-
auth_headers={"Authorization": lambda: "Bearer token"}
230+
auth_headers={"Authorization": auth_lambda} # Use defined lambda
160231
)
161232
assert len(w) == 1
162233
assert issubclass(w[-1].category, DeprecationWarning)
163234
assert "auth_headers" in str(w[-1].message)
235+
mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with(
236+
name=None,
237+
auth_token_getters={"Authorization": auth_lambda},
238+
bound_params={},
239+
strict=False,
240+
)
164241

165-
@patch("toolbox_langchain.async_client._load_manifest")
166242
async def test_aload_toolset_auth_headers_and_tokens(
167-
self, mock_load_manifest, mock_client, manifest_schema
243+
self, mock_client, manifest_schema
168244
):
169-
mock_manifest = manifest_schema
170-
mock_load_manifest.return_value = mock_manifest
245+
auth_getters = {"test": lambda: "token"}
246+
auth_headers_lambda = lambda: "Bearer token" # Define lambda once
171247
with catch_warnings(record=True) as w:
172248
simplefilter("always")
173249
await mock_client.aload_toolset(
174-
auth_headers={"Authorization": lambda: "Bearer token"},
175-
auth_token_getters={"test": lambda: "token"},
250+
auth_headers={
251+
"Authorization": auth_headers_lambda
252+
}, # Use defined lambda
253+
auth_token_getters=auth_getters,
176254
)
177255
assert len(w) == 1
178256
assert issubclass(w[-1].category, DeprecationWarning)
179257
assert "auth_headers" in str(w[-1].message)
258+
mock_client._AsyncToolboxClient__core_client.load_toolset.assert_called_once_with(
259+
name=None, auth_token_getters=auth_getters, bound_params={}, strict=False
260+
)
180261

181262
async def test_load_tool_not_implemented(self, mock_client):
182263
with pytest.raises(NotImplementedError) as excinfo:

0 commit comments

Comments
 (0)