Skip to content

Commit 84b45b6

Browse files
committed
chore: Add unit tests for the tool and client classes
1 parent 986b0f7 commit 84b45b6

File tree

3 files changed

+209
-0
lines changed

3 files changed

+209
-0
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
1416
import types
1517
from typing import Any, Callable, Mapping, Optional, Union
1618

packages/toolbox-core/tests/test_client.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020
import pytest_asyncio
2121
from aioresponses import CallbackResult
22+
from unittest.mock import AsyncMock
2223

2324
from toolbox_core import ToolboxClient
2425
from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema
@@ -301,3 +302,99 @@ async def test_bind_param_fail(self, tool_name, client):
301302

302303
with pytest.raises(Exception):
303304
tool = tool.bind_parameters({"argC": lambda: 5})
305+
306+
307+
@pytest.mark.asyncio
308+
async def test_new_invoke_tool_server_error(aioresponses, test_tool_str):
309+
"""Tests that invoking a tool raises an Exception when the server returns an
310+
error status."""
311+
TOOL_NAME = "server_error_tool"
312+
ERROR_MESSAGE = "Simulated Server Error"
313+
manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str})
314+
315+
aioresponses.get(
316+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
317+
payload=manifest.model_dump(),
318+
status=200,
319+
)
320+
aioresponses.post(
321+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke",
322+
payload={"error": ERROR_MESSAGE},
323+
status=500,
324+
)
325+
326+
async with ToolboxClient(TEST_BASE_URL) as client:
327+
loaded_tool = await client.load_tool(TOOL_NAME)
328+
329+
with pytest.raises(Exception, match=ERROR_MESSAGE):
330+
await loaded_tool(param1="some input")
331+
332+
333+
@pytest.mark.asyncio
334+
async def test_bind_param_async_callable_value_success(aioresponses, test_tool_int_bool):
335+
"""
336+
Tests bind_parameters method with an async callable value.
337+
"""
338+
TOOL_NAME = "async_bind_tool"
339+
manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_int_bool})
340+
341+
aioresponses.get(
342+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
343+
payload=manifest.model_dump(), status=200
344+
)
345+
346+
def reflect_parameters(url, **kwargs):
347+
received_params = kwargs.get("json", {})
348+
return CallbackResult(status=200, payload={"result": received_params})
349+
350+
aioresponses.post(
351+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke",
352+
callback=reflect_parameters,
353+
)
354+
355+
bound_value_result = True
356+
bound_async_callable = AsyncMock(return_value=bound_value_result)
357+
358+
async with ToolboxClient(TEST_BASE_URL) as client:
359+
tool = await client.load_tool(TOOL_NAME)
360+
bound_tool = tool.bind_parameters({"argB": bound_async_callable})
361+
362+
assert bound_tool is not tool
363+
assert "argB" not in bound_tool.__signature__.parameters
364+
assert "argA" in bound_tool.__signature__.parameters
365+
366+
passed_value_a = 42
367+
res_payload = await bound_tool(argA=passed_value_a)
368+
369+
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
370+
bound_async_callable.assert_awaited_once()
371+
372+
373+
@pytest.mark.asyncio
374+
async def test_new_add_auth_token_getters_duplicate_fail(aioresponses, test_tool_auth):
375+
"""
376+
Tests that adding a duplicate auth token getter raises ValueError.
377+
"""
378+
TOOL_NAME = "duplicate_auth_tool"
379+
AUTH_SERVICE = "my-auth-service"
380+
manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_auth})
381+
382+
aioresponses.get(
383+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
384+
payload=manifest.model_dump(), status=200
385+
)
386+
387+
def token_handler_1():
388+
return "token1"
389+
390+
def token_handler_2():
391+
return "token2"
392+
393+
async with ToolboxClient(TEST_BASE_URL) as client:
394+
tool = await client.load_tool(TOOL_NAME)
395+
396+
authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: token_handler_1})
397+
assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters
398+
399+
with pytest.raises(ValueError, match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{TOOL_NAME}`."):
400+
authed_tool.add_auth_token_getters({AUTH_SERVICE: token_handler_2})
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from inspect import Parameter
17+
18+
import pytest
19+
20+
from toolbox_core.protocol import ParameterSchema
21+
22+
23+
def test_parameter_schema_float():
24+
"""Tests ParameterSchema with type 'float'."""
25+
schema = ParameterSchema(name="price", type="float", description="The item price")
26+
expected_type = float
27+
assert schema._ParameterSchema__get_type() == expected_type
28+
29+
param = schema.to_param()
30+
assert isinstance(param, Parameter)
31+
assert param.name == "price"
32+
assert param.annotation == expected_type
33+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
34+
assert param.default == Parameter.empty
35+
36+
37+
def test_parameter_schema_boolean():
38+
"""Tests ParameterSchema with type 'boolean'."""
39+
schema = ParameterSchema(
40+
name="is_active", type="boolean", description="Activity status"
41+
)
42+
expected_type = bool
43+
assert schema._ParameterSchema__get_type() == expected_type
44+
45+
param = schema.to_param()
46+
assert isinstance(param, Parameter)
47+
assert param.name == "is_active"
48+
assert param.annotation == expected_type
49+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
50+
51+
52+
def test_parameter_schema_array_string():
53+
"""Tests ParameterSchema with type 'array' containing strings."""
54+
item_schema = ParameterSchema(
55+
name="", type="string", description=""
56+
)
57+
schema = ParameterSchema(
58+
name="tags", type="array", description="List of tags", items=item_schema
59+
)
60+
61+
assert schema._ParameterSchema__get_type() == list[str]
62+
63+
param = schema.to_param()
64+
assert isinstance(param, Parameter)
65+
assert param.name == "tags"
66+
assert param.annotation == list[str]
67+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
68+
69+
70+
def test_parameter_schema_array_integer():
71+
"""Tests ParameterSchema with type 'array' containing integers."""
72+
item_schema = ParameterSchema(name="", type="integer", description="")
73+
schema = ParameterSchema(
74+
name="scores", type="array", description="List of scores", items=item_schema
75+
)
76+
77+
param = schema.to_param()
78+
assert isinstance(param, Parameter)
79+
assert param.name == "scores"
80+
assert param.annotation == list[int]
81+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
82+
83+
84+
def test_parameter_schema_array_no_items_error():
85+
"""Tests that 'array' type raises error if 'items' is None."""
86+
schema = ParameterSchema(
87+
name="bad_list", type="array", description="List without item type"
88+
)
89+
90+
expected_error_msg = "Unexpected value: type is 'list' but items is None"
91+
with pytest.raises(Exception, match=expected_error_msg):
92+
schema._ParameterSchema__get_type()
93+
94+
with pytest.raises(Exception, match=expected_error_msg):
95+
schema.to_param()
96+
97+
98+
def test_parameter_schema_unsupported_type_error():
99+
"""Tests that an unsupported type raises ValueError."""
100+
unsupported_type = "datetime"
101+
schema = ParameterSchema(
102+
name="event_time", type=unsupported_type, description="When it happened"
103+
)
104+
105+
expected_error_msg = f"Unsupported schema type: {unsupported_type}"
106+
with pytest.raises(ValueError, match=expected_error_msg):
107+
schema._ParameterSchema__get_type()
108+
109+
with pytest.raises(ValueError, match=expected_error_msg):
110+
schema.to_param()

0 commit comments

Comments
 (0)