Skip to content

Commit b0ea075

Browse files
committed
chore: Add client and protocol unit tests
1 parent 986b0f7 commit b0ea075

File tree

2 files changed

+240
-0
lines changed

2 files changed

+240
-0
lines changed

packages/toolbox-core/tests/test_client.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020
import pytest_asyncio
2121
from aioresponses import CallbackResult
22+
from unittest.mock import AsyncMock
2223

2324
from toolbox_core import ToolboxClient
2425
from toolbox_core.protocol import ManifestSchema, ParameterSchema, ToolSchema
@@ -301,3 +302,118 @@ async def test_bind_param_fail(self, tool_name, client):
301302

302303
with pytest.raises(Exception):
303304
tool = tool.bind_parameters({"argC": lambda: 5})
305+
306+
307+
@pytest.mark.asyncio
308+
async def test_new_invoke_tool_server_error(aioresponses, test_tool_str):
309+
"""Tests that invoking a tool raises an Exception when the server returns an
310+
error status."""
311+
TOOL_NAME = "server_error_tool"
312+
ERROR_MESSAGE = "Simulated Server Error"
313+
manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str})
314+
315+
aioresponses.get(
316+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
317+
payload=manifest.model_dump(),
318+
status=200,
319+
)
320+
# Mock POST invoke call to return a server error status
321+
aioresponses.post(
322+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke",
323+
payload={"error": ERROR_MESSAGE},
324+
status=500, # Simulate server error
325+
)
326+
327+
async with ToolboxClient(TEST_BASE_URL) as client:
328+
loaded_tool = await client.load_tool(TOOL_NAME)
329+
330+
# Assert that calling the tool raises an Exception with the server's error message
331+
with pytest.raises(Exception, match=ERROR_MESSAGE):
332+
# Ensure required parameters are passed for the call attempt
333+
await loaded_tool(param1="some input")
334+
335+
336+
@pytest.mark.asyncio
337+
async def test_bind_param_async_callable_value_success(aioresponses, test_tool_int_bool):
338+
"""
339+
Tests bind_parameters method with an async callable value.
340+
Covers: `if asyncio.iscoroutinefunction(value):`
341+
"""
342+
TOOL_NAME = "async_bind_tool"
343+
manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_int_bool})
344+
345+
# Mock GET tool definition
346+
aioresponses.get(
347+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
348+
payload=manifest.model_dump(), status=200
349+
)
350+
351+
# Mock INVOKE to reflect received parameters
352+
def reflect_parameters(url, **kwargs):
353+
received_params = kwargs.get("json", {})
354+
return CallbackResult(status=200, payload={"result": received_params})
355+
356+
aioresponses.post(
357+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke",
358+
callback=reflect_parameters,
359+
)
360+
361+
bound_value_result = True
362+
# Use AsyncMock for the async callable
363+
bound_async_callable = AsyncMock(return_value=bound_value_result)
364+
365+
async with ToolboxClient(TEST_BASE_URL) as client:
366+
tool = await client.load_tool(TOOL_NAME)
367+
# Bind 'argB' using the method with an async function
368+
bound_tool = tool.bind_parameters({"argB": bound_async_callable})
369+
370+
assert bound_tool is not tool
371+
assert "argB" not in bound_tool.__signature__.parameters
372+
assert "argA" in bound_tool.__signature__.parameters
373+
374+
passed_value_a = 42
375+
# Invoke the tool, the async callable for 'argB' should be awaited internally
376+
res_payload = await bound_tool(argA=passed_value_a)
377+
378+
# Check that the result includes the awaited value from the async callable
379+
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
380+
# Verify the async mock was awaited
381+
bound_async_callable.assert_awaited_once()
382+
383+
384+
@pytest.mark.asyncio
385+
async def test_new_add_auth_token_getters_duplicate_fail(aioresponses, test_tool_auth):
386+
"""
387+
[NEW] Tests that adding a duplicate auth token getter raises ValueError.
388+
Covers: `ValueError` for duplicate auth source registration.
389+
"""
390+
TOOL_NAME = "duplicate_auth_tool"
391+
AUTH_SERVICE = "my-auth-service" # From test_tool_auth fixture
392+
manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_auth})
393+
394+
# Mock GET tool definition
395+
aioresponses.get(
396+
f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}",
397+
payload=manifest.model_dump(), status=200
398+
)
399+
400+
def token_handler_1():
401+
return "token1"
402+
403+
def token_handler_2():
404+
return "token2"
405+
406+
async with ToolboxClient(TEST_BASE_URL) as client:
407+
# Load the tool without initial getters
408+
tool = await client.load_tool(TOOL_NAME)
409+
410+
# Add the getter the first time - should succeed
411+
authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: token_handler_1})
412+
# Check internal state to confirm addition (optional assertion)
413+
assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters # type: ignore
414+
415+
# Attempt to add a getter for the *same* service again - should fail
416+
with pytest.raises(ValueError, match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{TOOL_NAME}`."):
417+
authed_tool.add_auth_token_getters({AUTH_SERVICE: token_handler_2})
418+
419+
# === End of new code block ===
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import inspect
17+
from inspect import Parameter
18+
19+
import pytest
20+
21+
from toolbox_core.protocol import ParameterSchema
22+
23+
24+
def test_parameter_schema_float():
25+
"""Tests ParameterSchema with type 'float'."""
26+
schema = ParameterSchema(name="price", type="float", description="The item price")
27+
expected_type = float
28+
# Use internal method directly for type check, though testing to_param implicitly tests it
29+
assert schema._ParameterSchema__get_type() == expected_type
30+
31+
param = schema.to_param()
32+
assert isinstance(param, Parameter)
33+
assert param.name == "price"
34+
assert param.annotation == expected_type
35+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
36+
assert param.default == Parameter.empty # No default specified
37+
38+
39+
def test_parameter_schema_boolean():
40+
"""Tests ParameterSchema with type 'boolean'."""
41+
schema = ParameterSchema(
42+
name="is_active", type="boolean", description="Activity status"
43+
)
44+
expected_type = bool
45+
assert schema._ParameterSchema__get_type() == expected_type
46+
47+
param = schema.to_param()
48+
assert isinstance(param, Parameter)
49+
assert param.name == "is_active"
50+
assert param.annotation == expected_type
51+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
52+
53+
54+
def test_parameter_schema_array_string():
55+
"""Tests ParameterSchema with type 'array' containing strings."""
56+
item_schema = ParameterSchema(
57+
name="", type="string", description=""
58+
) # Name/desc not relevant for item type
59+
schema = ParameterSchema(
60+
name="tags", type="array", description="List of tags", items=item_schema
61+
)
62+
# Note: Direct comparison with list[str] might differ slightly depending on Python version's typing internals
63+
# We check the Parameter annotation which uses the correct runtime representation
64+
# assert schema._ParameterSchema__get_type() == list[str] # This might fail equality check
65+
66+
param = schema.to_param()
67+
assert isinstance(param, Parameter)
68+
assert param.name == "tags"
69+
# Check annotation for list of strings. How typing represents this can vary slightly.
70+
# Using get_origin and get_args is robust across Python versions >= 3.8
71+
if hasattr(inspect, "get_origin"): # Python 3.8+
72+
from typing import get_args, get_origin
73+
74+
assert get_origin(param.annotation) is list
75+
assert get_args(param.annotation) == (str,)
76+
else: # Fallback for older versions (might need adjustment)
77+
assert param.annotation == list[str] # For older typing
78+
79+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
80+
81+
82+
def test_parameter_schema_array_integer():
83+
"""Tests ParameterSchema with type 'array' containing integers."""
84+
item_schema = ParameterSchema(name="", type="integer", description="")
85+
schema = ParameterSchema(
86+
name="scores", type="array", description="List of scores", items=item_schema
87+
)
88+
89+
param = schema.to_param()
90+
assert isinstance(param, Parameter)
91+
assert param.name == "scores"
92+
assert param.annotation == list[int]
93+
assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
94+
95+
96+
def test_parameter_schema_array_no_items_error():
97+
"""Tests that 'array' type raises error if 'items' is None."""
98+
schema = ParameterSchema(
99+
name="bad_list", type="array", description="List without item type"
100+
)
101+
102+
expected_error_msg = "Unexpected value: type is 'list' but items is None"
103+
with pytest.raises(Exception, match=expected_error_msg):
104+
schema._ParameterSchema__get_type()
105+
106+
# Also test via to_param()
107+
with pytest.raises(Exception, match=expected_error_msg):
108+
schema.to_param()
109+
110+
111+
def test_parameter_schema_unsupported_type_error():
112+
"""Tests that an unsupported type raises ValueError."""
113+
unsupported_type = "datetime"
114+
schema = ParameterSchema(
115+
name="event_time", type=unsupported_type, description="When it happened"
116+
)
117+
118+
expected_error_msg = f"Unsupported schema type: {unsupported_type}"
119+
with pytest.raises(ValueError, match=expected_error_msg):
120+
schema._ParameterSchema__get_type() # Call the method that raises
121+
122+
# Also test via to_param()
123+
with pytest.raises(ValueError, match=expected_error_msg):
124+
schema.to_param()

0 commit comments

Comments
 (0)