diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 04e803bc..3436580a 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -18,6 +18,7 @@ from inspect import Signature from typing import ( Any, + Awaitable, Callable, Iterable, Mapping, @@ -181,16 +182,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: # apply bounded parameters for param, value in self.__bound_parameters.items(): - if asyncio.iscoroutinefunction(value): - value = await value() - elif callable(value): - value = value() - payload[param] = value + payload[param] = await resolve_value(value) # create headers for auth services headers = {} for auth_service, token_getter in self.__auth_service_token_getters.items(): - headers[f"{auth_service}_token"] = token_getter() + headers[f"{auth_service}_token"] = await resolve_value(token_getter) async with self.__session.post( self.__url, @@ -330,3 +327,28 @@ def params_to_pydantic_model( ), ) return create_model(tool_name, **field_definitions) + + +async def resolve_value( + source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any], +) -> Any: + """ + Asynchronously or synchronously resolves a given source to its value. + + If the `source` is a coroutine function, it will be awaited. + If the `source` is a regular callable, it will be called. + Otherwise (if it's not a callable), the `source` itself is returned directly. + + Args: + source: The value, a callable returning a value, or a callable + returning an awaitable value. + + Returns: + The resolved value. + """ + + if asyncio.iscoroutinefunction(source): + return await source() + elif callable(source): + return source() + return source diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 68fffa75..cf7b21d1 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -166,6 +166,20 @@ async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str): response = await auth_tool(id="2") assert "row2" in response + @pytest.mark.asyncio + async def test_run_tool_async_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with correct auth using an async token getter.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + + async def get_token_asynchronously(): + return auth_token1 + + auth_tool = tool.add_auth_token_getters( + {"my-test-auth": get_token_asynchronously} + ) + response = await auth_tool(id="2") + assert "row2" in response + async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient): """Tests running a tool with a param requiring auth, without auth.""" tool = await toolbox.load_tool("get-row-by-email-auth") diff --git a/packages/toolbox-core/tests/test_tools.py b/packages/toolbox-core/tests/test_tools.py index e8fc7e99..505aa7f7 100644 --- a/packages/toolbox-core/tests/test_tools.py +++ b/packages/toolbox-core/tests/test_tools.py @@ -14,6 +14,7 @@ from typing import AsyncGenerator +from unittest.mock import AsyncMock, Mock import pytest import pytest_asyncio @@ -22,7 +23,7 @@ from pydantic import ValidationError from toolbox_core.protocol import ParameterSchema -from toolbox_core.tool import ToolboxTool, create_docstring +from toolbox_core.tool import ToolboxTool, create_docstring, resolve_value TEST_BASE_URL = "http://toolbox.example.com" TEST_TOOL_NAME = "sample_tool" @@ -223,3 +224,64 @@ async def test_tool_run_with_pydantic_validation_error( in str(exc_info.value) ) m.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "non_callable_source", + [ + "a simple string", + 12345, + True, + False, + None, + [1, "two", 3.0], + {"key": "value", "number": 100}, + object(), + ], + ids=[ + "string", + "integer", + "bool_true", + "bool_false", + "none", + "list", + "dict", + "object", + ], +) +async def test_resolve_value_non_callable(non_callable_source): + """ + Tests resolve_value when the source is not callable. + """ + resolved = await resolve_value(non_callable_source) + + assert resolved is non_callable_source + + +@pytest.mark.asyncio +async def test_resolve_value_sync_callable(): + """ + Tests resolve_value with a synchronous callable. + """ + expected_value = "sync result" + sync_callable = Mock(return_value=expected_value) + + resolved = await resolve_value(sync_callable) + + sync_callable.assert_called_once() + assert resolved == expected_value + + +@pytest.mark.asyncio +async def test_resolve_value_async_callable(): + """ + Tests resolve_value with an asynchronous callable (coroutine function). + """ + expected_value = "async result" + async_callable = AsyncMock(return_value=expected_value) + + resolved = await resolve_value(async_callable) + + async_callable.assert_awaited_once() + assert resolved == expected_value