diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index 602b6d1e..f9df761c 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -46,6 +46,7 @@ test = [ "pytest==8.3.5", "pytest-aioresponses==0.3.0", "pytest-asyncio==0.26.0", + "pytest-cov==6.1.0", "google-cloud-secret-manager==2.23.2", "google-cloud-storage==3.1.0", ] diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index bc8ca23c..a534e706 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + import types from typing import Any, Callable, Mapping, Optional, Union diff --git a/packages/toolbox-core/src/toolbox_core/sync_client.py b/packages/toolbox-core/src/toolbox_core/sync_client.py index 36877223..37ca6437 100644 --- a/packages/toolbox-core/src/toolbox_core/sync_client.py +++ b/packages/toolbox-core/src/toolbox_core/sync_client.py @@ -14,9 +14,7 @@ import asyncio from threading import Thread -from typing import Any, Awaitable, Callable, Mapping, Optional, TypeVar, Union - -from aiohttp import ClientSession +from typing import Any, Callable, Mapping, Optional, TypeVar, Union from .client import ToolboxClient from .sync_tool import ToolboxSyncTool diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 04e803bc..3436580a 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -18,6 +18,7 @@ from inspect import Signature from typing import ( Any, + Awaitable, Callable, Iterable, Mapping, @@ -181,16 +182,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # apply bounded parameters for param, value in self.__bound_parameters.items(): - if asyncio.iscoroutinefunction(value): - value = await value() - elif callable(value): - value = value() - payload[param] = value + payload[param] = await resolve_value(value) # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): - headers[f"{auth_service}_token"] = token_getter() + headers[f"{auth_service}_token"] = await resolve_value(token_getter) async with self.__session.post( self.__url, @@ -330,3 +327,28 @@ def params_to_pydantic_model( ), ) return create_model(tool_name, **field_definitions) + + +async def resolve_value( + source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any], +) -> Any: + """ + Asynchronously or synchronously resolves a given source to its value. + + If the `source` is a coroutine function, it will be awaited. + If the `source` is a regular callable, it will be called. + Otherwise (if it's not a callable), the `source` itself is returned directly. + + Args: + source: The value, a callable returning a value, or a callable + returning an awaitable value. + + Returns: + The resolved value. + """ + + if asyncio.iscoroutinefunction(source): + return await source() + elif callable(source): + return source() + return source diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 2ce600c3..a9cb091a 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -15,6 +15,7 @@ import inspect import json +from unittest.mock import AsyncMock, Mock import pytest import pytest_asyncio @@ -130,6 +131,60 @@ async def test_load_toolset_success(aioresponses, test_tool_str, test_tool_int_b assert {t.__name__ for t in tools} == manifest.tools.keys() +@pytest.mark.asyncio +async def test_invoke_tool_server_error(aioresponses, test_tool_str): + """Tests that invoking a tool raises an Exception when the server returns an + error status.""" + TOOL_NAME = "server_error_tool" + ERROR_MESSAGE = "Simulated Server Error" + manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) + + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}", + payload=manifest.model_dump(), + status=200, + ) + aioresponses.post( + f"{TEST_BASE_URL}/api/tool/{TOOL_NAME}/invoke", + payload={"error": ERROR_MESSAGE}, + status=500, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + loaded_tool = await client.load_tool(TOOL_NAME) + + with pytest.raises(Exception, match=ERROR_MESSAGE): + await loaded_tool(param1="some input") + + +@pytest.mark.asyncio +async def test_load_tool_not_found_in_manifest(aioresponses, test_tool_str): + """ + Tests that load_tool raises an Exception when the requested tool name + is not found in the manifest returned by the server, using existing fixtures. + """ + ACTUAL_TOOL_IN_MANIFEST = "actual_tool_abc" + REQUESTED_TOOL_NAME = "non_existent_tool_xyz" + + manifest = ManifestSchema( + serverVersion="0.0.0", tools={ACTUAL_TOOL_IN_MANIFEST: test_tool_str} + ) + + aioresponses.get( + f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", + payload=manifest.model_dump(), + status=200, + ) + + async with ToolboxClient(TEST_BASE_URL) as client: + with pytest.raises(Exception, match=f"Tool '{REQUESTED_TOOL_NAME}' not found!"): + await client.load_tool(REQUESTED_TOOL_NAME) + + aioresponses.assert_called_once_with( + f"{TEST_BASE_URL}/api/tool/{REQUESTED_TOOL_NAME}", method="GET" + ) + + class TestAuth: @pytest.fixture @@ -182,7 +237,7 @@ def token_handler(): tool = await client.load_tool( tool_name, auth_token_getters={"my-auth-service": token_handler} ) - res = await tool(5) + await tool(5) @pytest.mark.asyncio async def test_auth_with_add_token_success( @@ -195,7 +250,7 @@ def token_handler(): tool = await client.load_tool(tool_name) tool = tool.add_auth_token_getters({"my-auth-service": token_handler}) - res = await tool(5) + await tool(5) @pytest.mark.asyncio async def test_auth_with_load_tool_fail_no_token( @@ -203,12 +258,27 @@ async def test_auth_with_load_tool_fail_no_token( ): """Tests 'load_tool' with auth token is specified.""" - def token_handler(): - return expected_header - tool = await client.load_tool(tool_name) with pytest.raises(Exception): - res = await tool(5) + await tool(5) + + @pytest.mark.asyncio + async def test_add_auth_token_getters_duplicate_fail(self, tool_name, client): + """ + Tests that adding a duplicate auth token getter raises ValueError. + """ + AUTH_SERVICE = "my-auth-service" + + tool = await client.load_tool(tool_name) + + authed_tool = tool.add_auth_token_getters({AUTH_SERVICE: {}}) + assert AUTH_SERVICE in authed_tool._ToolboxTool__auth_service_token_getters + + with pytest.raises( + ValueError, + match=f"Authentication source\\(s\\) `{AUTH_SERVICE}` already registered in tool `{tool_name}`.", + ): + authed_tool.add_auth_token_getters({AUTH_SERVICE: {}}) class TestBoundParameter: @@ -283,6 +353,22 @@ async def test_bind_param_success(self, tool_name, client): assert len(tool.__signature__.parameters) == 2 assert "argA" in tool.__signature__.parameters + tool = tool.bind_parameters({"argA": 5}) + + assert len(tool.__signature__.parameters) == 1 + assert "argA" not in tool.__signature__.parameters + + res = await tool(True) + assert "argA" in res + + @pytest.mark.asyncio + async def test_bind_callable_param_success(self, tool_name, client): + """Tests 'bind_param' with a bound parameter specified.""" + tool = await client.load_tool(tool_name) + + assert len(tool.__signature__.parameters) == 2 + assert "argA" in tool.__signature__.parameters + tool = tool.bind_parameters({"argA": lambda: 5}) assert len(tool.__signature__.parameters) == 1 @@ -301,3 +387,67 @@ async def test_bind_param_fail(self, tool_name, client): with pytest.raises(Exception): tool = tool.bind_parameters({"argC": lambda: 5}) + + @pytest.mark.asyncio + async def test_bind_param_static_value_success(self, tool_name, client): + """ + Tests bind_parameters method with a static value. + """ + + bound_value = "Test value" + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_parameters({"argB": bound_value}) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value} + + @pytest.mark.asyncio + async def test_bind_param_sync_callable_value_success(self, tool_name, client): + """ + Tests bind_parameters method with a sync callable value. + """ + + bound_value_result = True + bound_sync_callable = Mock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_parameters({"argB": bound_sync_callable}) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_sync_callable.assert_called_once() + + @pytest.mark.asyncio + async def test_bind_param_async_callable_value_success(self, tool_name, client): + """ + Tests bind_parameters method with an async callable value. + """ + + bound_value_result = True + bound_async_callable = AsyncMock(return_value=bound_value_result) + + tool = await client.load_tool(tool_name) + bound_tool = tool.bind_parameters({"argB": bound_async_callable}) + + assert bound_tool is not tool + assert "argB" not in bound_tool.__signature__.parameters + assert "argA" in bound_tool.__signature__.parameters + + passed_value_a = 42 + res_payload = await bound_tool(argA=passed_value_a) + + assert res_payload == {"argA": passed_value_a, "argB": bound_value_result} + bound_async_callable.assert_awaited_once() diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 68fffa75..cf7b21d1 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -166,6 +166,20 @@ async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str): response = await auth_tool(id="2") assert "row2" in response + @pytest.mark.asyncio + async def test_run_tool_async_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with correct auth using an async token getter.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + + async def get_token_asynchronously(): + return auth_token1 + + auth_tool = tool.add_auth_token_getters( + {"my-test-auth": get_token_asynchronously} + ) + response = await auth_tool(id="2") + assert "row2" in response + async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient): """Tests running a tool with a param requiring auth, without auth.""" tool = await toolbox.load_tool("get-row-by-email-auth") diff --git a/packages/toolbox-core/tests/test_protocol.py b/packages/toolbox-core/tests/test_protocol.py new file mode 100644 index 00000000..a70fa3fe --- /dev/null +++ b/packages/toolbox-core/tests/test_protocol.py @@ -0,0 +1,108 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from inspect import Parameter + +import pytest + +from toolbox_core.protocol import ParameterSchema + + +def test_parameter_schema_float(): + """Tests ParameterSchema with type 'float'.""" + schema = ParameterSchema(name="price", type="float", description="The item price") + expected_type = float + assert schema._ParameterSchema__get_type() == expected_type + + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "price" + assert param.annotation == expected_type + assert param.kind == Parameter.POSITIONAL_OR_KEYWORD + assert param.default == Parameter.empty + + +def test_parameter_schema_boolean(): + """Tests ParameterSchema with type 'boolean'.""" + schema = ParameterSchema( + name="is_active", type="boolean", description="Activity status" + ) + expected_type = bool + assert schema._ParameterSchema__get_type() == expected_type + + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "is_active" + assert param.annotation == expected_type + assert param.kind == Parameter.POSITIONAL_OR_KEYWORD + + +def test_parameter_schema_array_string(): + """Tests ParameterSchema with type 'array' containing strings.""" + item_schema = ParameterSchema(name="", type="string", description="") + schema = ParameterSchema( + name="tags", type="array", description="List of tags", items=item_schema + ) + + assert schema._ParameterSchema__get_type() == list[str] + + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "tags" + assert param.annotation == list[str] + assert param.kind == Parameter.POSITIONAL_OR_KEYWORD + + +def test_parameter_schema_array_integer(): + """Tests ParameterSchema with type 'array' containing integers.""" + item_schema = ParameterSchema(name="", type="integer", description="") + schema = ParameterSchema( + name="scores", type="array", description="List of scores", items=item_schema + ) + + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "scores" + assert param.annotation == list[int] + assert param.kind == Parameter.POSITIONAL_OR_KEYWORD + + +def test_parameter_schema_array_no_items_error(): + """Tests that 'array' type raises error if 'items' is None.""" + schema = ParameterSchema( + name="bad_list", type="array", description="List without item type" + ) + + expected_error_msg = "Unexpected value: type is 'list' but items is None" + with pytest.raises(Exception, match=expected_error_msg): + schema._ParameterSchema__get_type() + + with pytest.raises(Exception, match=expected_error_msg): + schema.to_param() + + +def test_parameter_schema_unsupported_type_error(): + """Tests that an unsupported type raises ValueError.""" + unsupported_type = "datetime" + schema = ParameterSchema( + name="event_time", type=unsupported_type, description="When it happened" + ) + + expected_error_msg = f"Unsupported schema type: {unsupported_type}" + with pytest.raises(ValueError, match=expected_error_msg): + schema._ParameterSchema__get_type() + + with pytest.raises(ValueError, match=expected_error_msg): + schema.to_param() diff --git a/packages/toolbox-core/tests/test_tools.py b/packages/toolbox-core/tests/test_tools.py new file mode 100644 index 00000000..505aa7f7 --- /dev/null +++ b/packages/toolbox-core/tests/test_tools.py @@ -0,0 +1,287 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import AsyncGenerator +from unittest.mock import AsyncMock, Mock + +import pytest +import pytest_asyncio +from aiohttp import ClientSession +from aioresponses import aioresponses +from pydantic import ValidationError + +from toolbox_core.protocol import ParameterSchema +from toolbox_core.tool import ToolboxTool, create_docstring, resolve_value + +TEST_BASE_URL = "http://toolbox.example.com" +TEST_TOOL_NAME = "sample_tool" + + +@pytest.fixture +def sample_tool_params() -> list[ParameterSchema]: + """Parameters for the sample tool.""" + return [ + ParameterSchema( + name="message", type="string", description="A message to process" + ), + ParameterSchema(name="count", type="integer", description="A number"), + ] + + +@pytest.fixture +def sample_tool_description() -> str: + """Description for the sample tool.""" + return "A sample tool that processes a message and a count." + + +@pytest_asyncio.fixture +async def http_session() -> AsyncGenerator[ClientSession, None]: + """Provides an aiohttp ClientSession that is closed after the test.""" + async with ClientSession() as session: + yield session + + +def test_create_docstring_one_param_real_schema(): + """ + Tests create_docstring with one real ParameterSchema instance. + """ + description = "This tool does one thing." + params = [ + ParameterSchema( + name="input_file", type="string", description="Path to the input file." + ) + ] + + result_docstring = create_docstring(description, params) + + expected_docstring = ( + "This tool does one thing.\n\n" + "Args:\n" + " input_file (str): Path to the input file." + ) + + assert result_docstring == expected_docstring + + +def test_create_docstring_multiple_params_real_schema(): + """ + Tests create_docstring with multiple real ParameterSchema instances. + """ + description = "This tool does multiple things." + params = [ + ParameterSchema(name="query", type="string", description="The search query."), + ParameterSchema( + name="max_results", type="integer", description="Maximum results to return." + ), + ParameterSchema( + name="verbose", type="boolean", description="Enable verbose output." + ), + ] + + result_docstring = create_docstring(description, params) + + expected_docstring = ( + "This tool does multiple things.\n\n" + "Args:\n" + " query (str): The search query.\n" + " max_results (int): Maximum results to return.\n" + " verbose (bool): Enable verbose output." + ) + + assert result_docstring == expected_docstring + + +def test_create_docstring_no_description_real_schema(): + """ + Tests create_docstring with empty description and one real ParameterSchema. + """ + description = "" + params = [ + ParameterSchema( + name="config_id", type="string", description="The ID of the configuration." + ) + ] + + result_docstring = create_docstring(description, params) + + expected_docstring = ( + "\n\nArgs:\n" " config_id (str): The ID of the configuration." + ) + + assert result_docstring == expected_docstring + assert result_docstring.startswith("\n\nArgs:") + assert "config_id (str): The ID of the configuration." in result_docstring + + +def test_create_docstring_no_params(): + """ + Tests create_docstring when the params list is empty. + """ + description = "This is a tool description." + params = [] + + result_docstring = create_docstring(description, params) + + assert result_docstring == description + assert "\n\nArgs:" not in result_docstring + + +@pytest.mark.asyncio +async def test_tool_creation_callable_and_run( + http_session: ClientSession, + sample_tool_params: list[ParameterSchema], + sample_tool_description: str, +): + """ + Tests creating a ToolboxTool, checks callability, and simulates a run. + """ + tool_name = TEST_TOOL_NAME + base_url = TEST_BASE_URL + invoke_url = f"{base_url}/api/tool/{tool_name}/invoke" + + input_args = {"message": "hello world", "count": 5} + expected_payload = input_args.copy() + mock_server_response_body = {"result": "Processed: hello world (5 times)"} + expected_tool_result = mock_server_response_body["result"] + + with aioresponses() as m: + m.post(invoke_url, status=200, payload=mock_server_response_body) + + tool_instance = ToolboxTool( + session=http_session, + base_url=base_url, + name=tool_name, + description=sample_tool_description, + params=sample_tool_params, + required_authn_params={}, + auth_service_token_getters={}, + bound_params={}, + ) + + assert callable(tool_instance), "ToolboxTool instance should be callable" + + assert "message" in tool_instance.__signature__.parameters + assert "count" in tool_instance.__signature__.parameters + assert tool_instance.__signature__.parameters["message"].annotation == str + assert tool_instance.__signature__.parameters["count"].annotation == int + + actual_result = await tool_instance("hello world", 5) + + assert actual_result == expected_tool_result + + m.assert_called_once_with( + invoke_url, method="POST", json=expected_payload, headers={} + ) + + +@pytest.mark.asyncio +async def test_tool_run_with_pydantic_validation_error( + http_session: ClientSession, + sample_tool_params: list[ParameterSchema], + sample_tool_description: str, +): + """ + Tests that calling the tool with incorrect argument types raises an error + due to Pydantic validation *before* making an HTTP request. + """ + tool_name = TEST_TOOL_NAME + base_url = TEST_BASE_URL + invoke_url = f"{base_url}/api/tool/{tool_name}/invoke" + + with aioresponses() as m: + m.post(invoke_url, status=200, payload={"result": "Should not be called"}) + + tool_instance = ToolboxTool( + session=http_session, + base_url=base_url, + name=tool_name, + description=sample_tool_description, + params=sample_tool_params, + required_authn_params={}, + auth_service_token_getters={}, + bound_params={}, + ) + + assert callable(tool_instance) + + with pytest.raises(ValidationError) as exc_info: + await tool_instance(message="hello", count="not-a-number") + + assert ( + "1 validation error for sample_tool\ncount\n Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='not-a-number', input_type=str]\n For further information visit https://errors.pydantic.dev/2.11/v/int_parsing" + in str(exc_info.value) + ) + m.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "non_callable_source", + [ + "a simple string", + 12345, + True, + False, + None, + [1, "two", 3.0], + {"key": "value", "number": 100}, + object(), + ], + ids=[ + "string", + "integer", + "bool_true", + "bool_false", + "none", + "list", + "dict", + "object", + ], +) +async def test_resolve_value_non_callable(non_callable_source): + """ + Tests resolve_value when the source is not callable. + """ + resolved = await resolve_value(non_callable_source) + + assert resolved is non_callable_source + + +@pytest.mark.asyncio +async def test_resolve_value_sync_callable(): + """ + Tests resolve_value with a synchronous callable. + """ + expected_value = "sync result" + sync_callable = Mock(return_value=expected_value) + + resolved = await resolve_value(sync_callable) + + sync_callable.assert_called_once() + assert resolved == expected_value + + +@pytest.mark.asyncio +async def test_resolve_value_async_callable(): + """ + Tests resolve_value with an asynchronous callable (coroutine function). + """ + expected_value = "async result" + async_callable = AsyncMock(return_value=expected_value) + + resolved = await resolve_value(async_callable) + + async_callable.assert_awaited_once() + assert resolved == expected_value