Skip to content

Commit c26c453

Browse files
committed
chore: add tests for bound parameters
1 parent c8491a9 commit c26c453

File tree

1 file changed

+113
-20
lines changed

1 file changed

+113
-20
lines changed

packages/toolbox-core/tests/test_client.py

Lines changed: 113 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import inspect
17+
import json
1718

1819
import pytest
1920
import pytest_asyncio
@@ -100,6 +101,31 @@ async def test_load_tool_success(aioresponses, test_tool_str):
100101
assert await loaded_tool("some value") == "ok"
101102

102103

104+
@pytest.mark.asyncio
105+
async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool):
106+
"""Tests successfully loading a toolset with multiple tools."""
107+
TOOLSET_NAME = "my_toolset"
108+
TOOL1 = "tool1"
109+
TOOL2 = "tool2"
110+
manifest = ManifestSchema(
111+
serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool}
112+
)
113+
aioresponses.get(
114+
f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}",
115+
payload=manifest.model_dump(),
116+
status=200,
117+
)
118+
119+
async with ToolboxClient(TEST_BASE_URL) as client:
120+
tools = await client.load_toolset(TOOLSET_NAME)
121+
122+
assert isinstance(tools, list)
123+
assert len(tools) == len(manifest.tools)
124+
125+
# Check if tools were created correctly
126+
assert {t.__name__ for t in tools} == manifest.tools.keys()
127+
128+
103129
class TestAuth:
104130

105131
@pytest.fixture
@@ -181,26 +207,93 @@ def token_handler():
181207
res = await tool(5)
182208

183209

184-
@pytest.mark.asyncio
185-
async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_bool):
186-
"""Tests successfully loading a toolset with multiple tools."""
187-
TOOLSET_NAME = "my_toolset"
188-
TOOL1 = "tool1"
189-
TOOL2 = "tool2"
190-
manifest = ManifestSchema(
191-
serverVersion="0.0.0", tools={TOOL1: test_tool_str, TOOL2: test_tool_int_bool}
192-
)
193-
aioresponses.get(
194-
f"{TEST_BASE_URL}/api/toolset/{TOOLSET_NAME}",
195-
payload=manifest.model_dump(),
196-
status=200,
197-
)
210+
class TestBoundParameter:
198211

199-
async with ToolboxClient(TEST_BASE_URL) as client:
200-
tools = await client.load_toolset(TOOLSET_NAME)
212+
@pytest.fixture
213+
def tool_name(self):
214+
return "tool1"
201215

202-
assert isinstance(tools, list)
203-
assert len(tools) == len(manifest.tools)
216+
@pytest_asyncio.fixture
217+
async def client(self, aioresponses, test_tool_int_bool, tool_name):
218+
manifest = ManifestSchema(
219+
serverVersion="0.0.0", tools={tool_name: test_tool_int_bool}
220+
)
204221

205-
# Check if tools were created correctly
206-
assert {t.__name__ for t in tools} == manifest.tools.keys()
222+
# mock toolset GET call
223+
aioresponses.get(
224+
f"{TEST_BASE_URL}/api/toolset/",
225+
payload=manifest.model_dump(),
226+
status=200,
227+
)
228+
229+
# mock tool GET call
230+
aioresponses.get(
231+
f"{TEST_BASE_URL}/api/tool/{tool_name}",
232+
payload=manifest.model_dump(),
233+
status=200,
234+
)
235+
236+
# mock tool INVOKE call
237+
def reflect_parameters(url, **kwargs):
238+
body = {"result": kwargs["json"]}
239+
return CallbackResult(status=200, body=json.dumps(body))
240+
241+
aioresponses.post(
242+
f"{TEST_BASE_URL}/api/tool/{tool_name}/invoke",
243+
payload=manifest.model_dump(),
244+
callback=reflect_parameters,
245+
status=200,
246+
)
247+
248+
async with ToolboxClient(TEST_BASE_URL) as client:
249+
yield client
250+
251+
@pytest.mark.asyncio
252+
async def test_load_tool_success(self, tool_name, client):
253+
"""Tests 'load_tool' with a bound parameter specified."""
254+
tool = await client.load_tool(tool_name, bound_params={"argA": lambda: 5})
255+
256+
assert len(tool.__signature__.parameters) == 1
257+
assert "argA" not in tool.__signature__.parameters
258+
259+
res = await tool(True)
260+
assert "argA" in res
261+
262+
@pytest.mark.asyncio
263+
async def test_load_toolset_success(self, tool_name, client):
264+
"""Tests 'load_toolset' with a bound parameter specified."""
265+
tools = await client.load_toolset("", bound_params={"argB": lambda: "hello"})
266+
tool = tools[0]
267+
268+
assert len(tool.__signature__.parameters) == 1
269+
assert "argB" not in tool.__signature__.parameters
270+
271+
res = await tool(True)
272+
assert "argB" in res
273+
274+
@pytest.mark.asyncio
275+
async def test_bind_param_success(self, tool_name, client):
276+
"""Tests 'bind_param' with a bound parameter specified."""
277+
tool = await client.load_tool(tool_name)
278+
279+
assert len(tool.__signature__.parameters) == 2
280+
assert "argA" in tool.__signature__.parameters
281+
282+
tool = tool.bind_parameters({"argA": lambda: 5})
283+
284+
assert len(tool.__signature__.parameters) == 1
285+
assert "argA" not in tool.__signature__.parameters
286+
287+
res = await tool(True)
288+
assert "argA" in res
289+
290+
@pytest.mark.asyncio
291+
async def test_bind_param_fail(self, tool_name, client):
292+
"""Tests 'bind_param' with a bound parameter that doesn't exist."""
293+
tool = await client.load_tool(tool_name)
294+
295+
assert len(tool.__signature__.parameters) == 2
296+
assert "argA" in tool.__signature__.parameters
297+
298+
with pytest.raises(Exception):
299+
tool = tool.bind_parameters({"argC": lambda: 5})

0 commit comments

Comments
 (0)