Skip to content

Commit 854c9df

Browse files
committed
chore: Add unit tests previously deleted
1 parent 7a55c25 commit 854c9df

File tree

4 files changed

+964
-0
lines changed

4 files changed

+964
-0
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import AsyncMock, patch
16+
from warnings import catch_warnings, simplefilter
17+
18+
import pytest
19+
from aiohttp import ClientSession
20+
21+
from toolbox_langchain.async_client import AsyncToolboxClient
22+
from toolbox_langchain.async_tools import AsyncToolboxTool
23+
from toolbox_core.protocol import ManifestSchema
24+
25+
URL = "http://test_url"
26+
MANIFEST_JSON = {
27+
"serverVersion": "1.0.0",
28+
"tools": {
29+
"test_tool_1": {
30+
"description": "Test Tool 1 Description",
31+
"parameters": [
32+
{
33+
"name": "param1",
34+
"type": "string",
35+
"description": "Param 1",
36+
}
37+
],
38+
},
39+
"test_tool_2": {
40+
"description": "Test Tool 2 Description",
41+
"parameters": [
42+
{
43+
"name": "param2",
44+
"type": "integer",
45+
"description": "Param 2",
46+
}
47+
],
48+
},
49+
},
50+
}
51+
52+
53+
@pytest.mark.asyncio
54+
class TestAsyncToolboxClient:
55+
@pytest.fixture()
56+
def manifest_schema(self):
57+
return ManifestSchema(**MANIFEST_JSON)
58+
59+
@pytest.fixture()
60+
def mock_session(self):
61+
return AsyncMock(spec=ClientSession)
62+
63+
@pytest.fixture()
64+
def mock_client(self, mock_session):
65+
return AsyncToolboxClient(URL, session=mock_session)
66+
67+
async def test_create_with_existing_session(self, mock_client, mock_session):
68+
assert mock_client._AsyncToolboxClient__session == mock_session
69+
70+
@patch("toolbox_langchain.async_client._load_manifest")
71+
async def test_aload_tool(
72+
self, mock_load_manifest, mock_client, mock_session, manifest_schema
73+
):
74+
tool_name = "test_tool_1"
75+
mock_load_manifest.return_value = manifest_schema
76+
77+
tool = await mock_client.aload_tool(tool_name)
78+
79+
mock_load_manifest.assert_called_once_with(
80+
f"{URL}/api/tool/{tool_name}", mock_session
81+
)
82+
assert isinstance(tool, AsyncToolboxTool)
83+
assert tool.name == tool_name
84+
85+
@patch("toolbox_langchain.async_client._load_manifest")
86+
async def test_aload_tool_auth_headers_deprecated(
87+
self, mock_load_manifest, mock_client, manifest_schema
88+
):
89+
tool_name = "test_tool_1"
90+
mock_manifest = manifest_schema
91+
mock_load_manifest.return_value = mock_manifest
92+
with catch_warnings(record=True) as w:
93+
simplefilter("always")
94+
await mock_client.aload_tool(
95+
tool_name, auth_headers={"Authorization": lambda: "Bearer token"}
96+
)
97+
assert len(w) == 1
98+
assert issubclass(w[-1].category, DeprecationWarning)
99+
assert "auth_headers" in str(w[-1].message)
100+
101+
@patch("toolbox_langchain.async_client._load_manifest")
102+
async def test_aload_tool_auth_headers_and_tokens(
103+
self, mock_load_manifest, mock_client, manifest_schema
104+
):
105+
tool_name = "test_tool_1"
106+
mock_manifest = manifest_schema
107+
mock_load_manifest.return_value = mock_manifest
108+
with catch_warnings(record=True) as w:
109+
simplefilter("always")
110+
await mock_client.aload_tool(
111+
tool_name,
112+
auth_headers={"Authorization": lambda: "Bearer token"},
113+
auth_token_getters={"test": lambda: "token"},
114+
)
115+
assert len(w) == 1
116+
assert issubclass(w[-1].category, DeprecationWarning)
117+
assert "auth_headers" in str(w[-1].message)
118+
119+
@patch("toolbox_langchain.async_client._load_manifest")
120+
async def test_aload_toolset(
121+
self, mock_load_manifest, mock_client, mock_session, manifest_schema
122+
):
123+
mock_manifest = manifest_schema
124+
mock_load_manifest.return_value = mock_manifest
125+
tools = await mock_client.aload_toolset()
126+
127+
mock_load_manifest.assert_called_once_with(f"{URL}/api/toolset/", mock_session)
128+
assert len(tools) == 2
129+
for tool in tools:
130+
assert isinstance(tool, AsyncToolboxTool)
131+
assert tool.name in ["test_tool_1", "test_tool_2"]
132+
133+
@patch("toolbox_langchain.async_client._load_manifest")
134+
async def test_aload_toolset_with_toolset_name(
135+
self, mock_load_manifest, mock_client, mock_session, manifest_schema
136+
):
137+
toolset_name = "test_toolset_1"
138+
mock_manifest = manifest_schema
139+
mock_load_manifest.return_value = mock_manifest
140+
tools = await mock_client.aload_toolset(toolset_name=toolset_name)
141+
142+
mock_load_manifest.assert_called_once_with(
143+
f"{URL}/api/toolset/{toolset_name}", mock_session
144+
)
145+
assert len(tools) == 2
146+
for tool in tools:
147+
assert isinstance(tool, AsyncToolboxTool)
148+
assert tool.name in ["test_tool_1", "test_tool_2"]
149+
150+
@patch("toolbox_langchain.async_client._load_manifest")
151+
async def test_aload_toolset_auth_headers_deprecated(
152+
self, mock_load_manifest, mock_client, manifest_schema
153+
):
154+
mock_manifest = manifest_schema
155+
mock_load_manifest.return_value = mock_manifest
156+
with catch_warnings(record=True) as w:
157+
simplefilter("always")
158+
await mock_client.aload_toolset(
159+
auth_headers={"Authorization": lambda: "Bearer token"}
160+
)
161+
assert len(w) == 1
162+
assert issubclass(w[-1].category, DeprecationWarning)
163+
assert "auth_headers" in str(w[-1].message)
164+
165+
@patch("toolbox_langchain.async_client._load_manifest")
166+
async def test_aload_toolset_auth_headers_and_tokens(
167+
self, mock_load_manifest, mock_client, manifest_schema
168+
):
169+
mock_manifest = manifest_schema
170+
mock_load_manifest.return_value = mock_manifest
171+
with catch_warnings(record=True) as w:
172+
simplefilter("always")
173+
await mock_client.aload_toolset(
174+
auth_headers={"Authorization": lambda: "Bearer token"},
175+
auth_token_getters={"test": lambda: "token"},
176+
)
177+
assert len(w) == 1
178+
assert issubclass(w[-1].category, DeprecationWarning)
179+
assert "auth_headers" in str(w[-1].message)
180+
181+
async def test_load_tool_not_implemented(self, mock_client):
182+
with pytest.raises(NotImplementedError) as excinfo:
183+
mock_client.load_tool("test_tool")
184+
assert "Synchronous methods not supported by async client." in str(
185+
excinfo.value
186+
)
187+
188+
async def test_load_toolset_not_implemented(self, mock_client):
189+
with pytest.raises(NotImplementedError) as excinfo:
190+
mock_client.load_toolset()
191+
assert "Synchronous methods not supported by async client." in str(
192+
excinfo.value
193+
)

0 commit comments

Comments
 (0)