Skip to content

Commit d1d3aaa

Browse files
committed
chore: Add test cases for sync and static bound parameter.
1 parent 2991eca commit d1d3aaa

File tree

1 file changed

+103
-3
lines changed

1 file changed

+103
-3
lines changed

packages/toolbox-core/tests/test_client.py

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import inspect
1717
import json
18-
from unittest.mock import AsyncMock
18+
from unittest.mock import AsyncMock, Mock
1919

2020
import pytest
2121
import pytest_asyncio
@@ -284,6 +284,22 @@ async def test_bind_param_success(self, tool_name, client):
284284
assert len(tool.__signature__.parameters) == 2
285285
assert "argA" in tool.__signature__.parameters
286286

287+
tool = tool.bind_parameters({"argA": 5})
288+
289+
assert len(tool.__signature__.parameters) == 1
290+
assert "argA" not in tool.__signature__.parameters
291+
292+
res = await tool(True)
293+
assert "argA" in res
294+
295+
@pytest.mark.asyncio
296+
async def test_bind_callable_param_success(self, tool_name, client):
297+
"""Tests 'bind_param' with a bound parameter specified."""
298+
tool = await client.load_tool(tool_name)
299+
300+
assert len(tool.__signature__.parameters) == 2
301+
assert "argA" in tool.__signature__.parameters
302+
287303
tool = tool.bind_parameters({"argA": lambda: 5})
288304

289305
assert len(tool.__signature__.parameters) == 1
@@ -305,7 +321,7 @@ async def test_bind_param_fail(self, tool_name, client):
305321

306322

307323
@pytest.mark.asyncio
308-
async def test_new_invoke_tool_server_error(aioresponses, test_tool_str):
324+
async def test_invoke_tool_server_error(aioresponses, test_tool_str):
309325
"""Tests that invoking a tool raises an Exception when the server returns an
310326
error status."""
311327
TOOL_NAME = "server_error_tool"
@@ -330,6 +346,90 @@ async def test_new_invoke_tool_server_error(aioresponses, test_tool_str):
330346
await loaded_tool(param1="some input")
331347

332348

349+
@pytest.mark.asyncio
350+
async def test_bind_param_static_value_success(aioresponses, test_tool_int_bool):
351+
"""
352+
Tests bind_parameters method with a static value.
353+
"""
354+
TOOL_NAME = "async_bind_tool"
355+
manifest = ManifestSchema(
356+
serverVersion="0.0.0", tools={TOOL_NAME: test_tool_int_bool}
357+
)
358+
359+
aioresponses.get(
360+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
361+
payload=manifest.model_dump(),
362+
status=200,
363+
)
364+
365+
def reflect_parameters(url, **kwargs):
366+
received_params = kwargs.get("json", {})
367+
return CallbackResult(status=200, payload={"result": received_params})
368+
369+
aioresponses.post(
370+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke",
371+
callback=reflect_parameters,
372+
)
373+
374+
bound_value = "Test value"
375+
376+
async with ToolboxClient(TEST_BASE_URL) as client:
377+
tool = await client.load_tool(TOOL_NAME)
378+
bound_tool = tool.bind_parameters({"argB": bound_value})
379+
380+
assert bound_tool is not tool
381+
assert "argB" not in bound_tool.__signature__.parameters
382+
assert "argA" in bound_tool.__signature__.parameters
383+
384+
passed_value_a = 42
385+
res_payload = await bound_tool(argA=passed_value_a)
386+
387+
assert res_payload == {"argA": passed_value_a, "argB": bound_value}
388+
389+
390+
@pytest.mark.asyncio
391+
async def test_bind_param_sync_callable_value_success(aioresponses, test_tool_int_bool):
392+
"""
393+
Tests bind_parameters method with a sync callable value.
394+
"""
395+
TOOL_NAME = "async_bind_tool"
396+
manifest = ManifestSchema(
397+
serverVersion="0.0.0", tools={TOOL_NAME: test_tool_int_bool}
398+
)
399+
400+
aioresponses.get(
401+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
402+
payload=manifest.model_dump(),
403+
status=200,
404+
)
405+
406+
def reflect_parameters(url, **kwargs):
407+
received_params = kwargs.get("json", {})
408+
return CallbackResult(status=200, payload={"result": received_params})
409+
410+
aioresponses.post(
411+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke",
412+
callback=reflect_parameters,
413+
)
414+
415+
bound_value_result = True
416+
bound_sync_callable = Mock(return_value=bound_value_result)
417+
418+
async with ToolboxClient(TEST_BASE_URL) as client:
419+
tool = await client.load_tool(TOOL_NAME)
420+
bound_tool = tool.bind_parameters({"argB": bound_sync_callable})
421+
422+
assert bound_tool is not tool
423+
assert "argB" not in bound_tool.__signature__.parameters
424+
assert "argA" in bound_tool.__signature__.parameters
425+
426+
passed_value_a = 42
427+
res_payload = await bound_tool(argA=passed_value_a)
428+
429+
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
430+
bound_sync_callable.assert_called_once()
431+
432+
333433
@pytest.mark.asyncio
334434
async def test_bind_param_async_callable_value_success(
335435
aioresponses, test_tool_int_bool
@@ -376,7 +476,7 @@ def reflect_parameters(url, **kwargs):
376476

377477

378478
@pytest.mark.asyncio
379-
async def test_new_add_auth_token_getters_duplicate_fail(aioresponses, test_tool_auth):
479+
async def test_add_auth_token_getters_duplicate_fail(aioresponses, test_tool_auth):
380480
"""
381481
Tests that adding a duplicate auth token getter raises ValueError.
382482
"""

0 commit comments

Comments
 (0)