Skip to content

feat: Add support for async token getters to ToolboxTool #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from inspect import Signature
from typing import (
Any,
Awaitable,
Callable,
Iterable,
Mapping,
Expand Down Expand Up @@ -181,16 +182,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:

# apply bounded parameters
for param, value in self.__bound_parameters.items():
if asyncio.iscoroutinefunction(value):
value = await value()
elif callable(value):
value = value()
payload[param] = value
payload[param] = await resolve_value(value)

# create headers for auth services
headers = {}
for auth_service, token_getter in self.__auth_service_token_getters.items():
headers[f"{auth_service}_token"] = token_getter()
headers[f"{auth_service}_token"] = await resolve_value(token_getter)

async with self.__session.post(
self.__url,
Expand Down Expand Up @@ -330,3 +327,28 @@ def params_to_pydantic_model(
),
)
return create_model(tool_name, **field_definitions)


async def resolve_value(
source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any],
) -> Any:
"""
Asynchronously or synchronously resolves a given source to its value.

If the `source` is a coroutine function, it will be awaited.
If the `source` is a regular callable, it will be called.
Otherwise (if it's not a callable), the `source` itself is returned directly.

Args:
source: The value, a callable returning a value, or a callable
returning an awaitable value.

Returns:
The resolved value.
"""

if asyncio.iscoroutinefunction(source):
return await source()
elif callable(source):
return source()
return source
14 changes: 14 additions & 0 deletions packages/toolbox-core/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,20 @@ async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str):
response = await auth_tool(id="2")
assert "row2" in response

@pytest.mark.asyncio
async def test_run_tool_async_auth(self, toolbox: ToolboxClient, auth_token1: str):
"""Tests running a tool with correct auth using an async token getter."""
tool = await toolbox.load_tool("get-row-by-id-auth")

async def get_token_asynchronously():
return auth_token1

auth_tool = tool.add_auth_token_getters(
{"my-test-auth": get_token_asynchronously}
)
response = await auth_tool(id="2")
assert "row2" in response

async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient):
"""Tests running a tool with a param requiring auth, without auth."""
tool = await toolbox.load_tool("get-row-by-email-auth")
Expand Down
64 changes: 63 additions & 1 deletion packages/toolbox-core/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


from typing import AsyncGenerator
from unittest.mock import AsyncMock, Mock

import pytest
import pytest_asyncio
Expand All @@ -22,7 +23,7 @@
from pydantic import ValidationError

from toolbox_core.protocol import ParameterSchema
from toolbox_core.tool import ToolboxTool, create_docstring
from toolbox_core.tool import ToolboxTool, create_docstring, resolve_value

TEST_BASE_URL = "http://toolbox.example.com"
TEST_TOOL_NAME = "sample_tool"
Expand Down Expand Up @@ -223,3 +224,64 @@ async def test_tool_run_with_pydantic_validation_error(
in str(exc_info.value)
)
m.assert_not_called()


@pytest.mark.asyncio
@pytest.mark.parametrize(
"non_callable_source",
[
"a simple string",
12345,
True,
False,
None,
[1, "two", 3.0],
{"key": "value", "number": 100},
object(),
],
ids=[
"string",
"integer",
"bool_true",
"bool_false",
"none",
"list",
"dict",
"object",
],
)
async def test_resolve_value_non_callable(non_callable_source):
"""
Tests resolve_value when the source is not callable.
"""
resolved = await resolve_value(non_callable_source)

assert resolved is non_callable_source


@pytest.mark.asyncio
async def test_resolve_value_sync_callable():
"""
Tests resolve_value with a synchronous callable.
"""
expected_value = "sync result"
sync_callable = Mock(return_value=expected_value)

resolved = await resolve_value(sync_callable)

sync_callable.assert_called_once()
assert resolved == expected_value


@pytest.mark.asyncio
async def test_resolve_value_async_callable():
"""
Tests resolve_value with an asynchronous callable (coroutine function).
"""
expected_value = "async result"
async_callable = AsyncMock(return_value=expected_value)

resolved = await resolve_value(async_callable)

async_callable.assert_awaited_once()
assert resolved == expected_value