Skip to content

Commit 13fe2a8

Browse files
authored
feat: Add support for async token getters to ToolboxTool (#147)
* feat: Add support for async token getters to ToolboxTool * chore: Improve variable names and docstring for more clarity * chore: Improve docstring * chore: Add unit test cases * chore: Add e2e test case * chore: Fix e2e test case
1 parent d3e20e5 commit 13fe2a8

File tree

3 files changed

+105
-7
lines changed

3 files changed

+105
-7
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from inspect import Signature
1919
from typing import (
2020
Any,
21+
Awaitable,
2122
Callable,
2223
Iterable,
2324
Mapping,
@@ -181,16 +182,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
181182

182183
# apply bounded parameters
183184
for param, value in self.__bound_parameters.items():
184-
if asyncio.iscoroutinefunction(value):
185-
value = await value()
186-
elif callable(value):
187-
value = value()
188-
payload[param] = value
185+
payload[param] = await resolve_value(value)
189186

190187
# create headers for auth services
191188
headers = {}
192189
for auth_service, token_getter in self.__auth_service_token_getters.items():
193-
headers[f"{auth_service}_token"] = token_getter()
190+
headers[f"{auth_service}_token"] = await resolve_value(token_getter)
194191

195192
async with self.__session.post(
196193
self.__url,
@@ -330,3 +327,28 @@ def params_to_pydantic_model(
330327
),
331328
)
332329
return create_model(tool_name, **field_definitions)
330+
331+
332+
async def resolve_value(
333+
source: Union[Callable[[], Awaitable[Any]], Callable[[], Any], Any],
334+
) -> Any:
335+
"""
336+
Asynchronously or synchronously resolves a given source to its value.
337+
338+
If the `source` is a coroutine function, it will be awaited.
339+
If the `source` is a regular callable, it will be called.
340+
Otherwise (if it's not a callable), the `source` itself is returned directly.
341+
342+
Args:
343+
source: The value, a callable returning a value, or a callable
344+
returning an awaitable value.
345+
346+
Returns:
347+
The resolved value.
348+
"""
349+
350+
if asyncio.iscoroutinefunction(source):
351+
return await source()
352+
elif callable(source):
353+
return source()
354+
return source

packages/toolbox-core/tests/test_e2e.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,20 @@ async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str):
166166
response = await auth_tool(id="2")
167167
assert "row2" in response
168168

169+
@pytest.mark.asyncio
170+
async def test_run_tool_async_auth(self, toolbox: ToolboxClient, auth_token1: str):
171+
"""Tests running a tool with correct auth using an async token getter."""
172+
tool = await toolbox.load_tool("get-row-by-id-auth")
173+
174+
async def get_token_asynchronously():
175+
return auth_token1
176+
177+
auth_tool = tool.add_auth_token_getters(
178+
{"my-test-auth": get_token_asynchronously}
179+
)
180+
response = await auth_tool(id="2")
181+
assert "row2" in response
182+
169183
async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient):
170184
"""Tests running a tool with a param requiring auth, without auth."""
171185
tool = await toolbox.load_tool("get-row-by-email-auth")

packages/toolbox-core/tests/test_tools.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
from typing import AsyncGenerator
17+
from unittest.mock import AsyncMock, Mock
1718

1819
import pytest
1920
import pytest_asyncio
@@ -22,7 +23,7 @@
2223
from pydantic import ValidationError
2324

2425
from toolbox_core.protocol import ParameterSchema
25-
from toolbox_core.tool import ToolboxTool, create_docstring
26+
from toolbox_core.tool import ToolboxTool, create_docstring, resolve_value
2627

2728
TEST_BASE_URL = "http://toolbox.example.com"
2829
TEST_TOOL_NAME = "sample_tool"
@@ -223,3 +224,64 @@ async def test_tool_run_with_pydantic_validation_error(
223224
in str(exc_info.value)
224225
)
225226
m.assert_not_called()
227+
228+
229+
@pytest.mark.asyncio
230+
@pytest.mark.parametrize(
231+
"non_callable_source",
232+
[
233+
"a simple string",
234+
12345,
235+
True,
236+
False,
237+
None,
238+
[1, "two", 3.0],
239+
{"key": "value", "number": 100},
240+
object(),
241+
],
242+
ids=[
243+
"string",
244+
"integer",
245+
"bool_true",
246+
"bool_false",
247+
"none",
248+
"list",
249+
"dict",
250+
"object",
251+
],
252+
)
253+
async def test_resolve_value_non_callable(non_callable_source):
254+
"""
255+
Tests resolve_value when the source is not callable.
256+
"""
257+
resolved = await resolve_value(non_callable_source)
258+
259+
assert resolved is non_callable_source
260+
261+
262+
@pytest.mark.asyncio
263+
async def test_resolve_value_sync_callable():
264+
"""
265+
Tests resolve_value with a synchronous callable.
266+
"""
267+
expected_value = "sync result"
268+
sync_callable = Mock(return_value=expected_value)
269+
270+
resolved = await resolve_value(sync_callable)
271+
272+
sync_callable.assert_called_once()
273+
assert resolved == expected_value
274+
275+
276+
@pytest.mark.asyncio
277+
async def test_resolve_value_async_callable():
278+
"""
279+
Tests resolve_value with an asynchronous callable (coroutine function).
280+
"""
281+
expected_value = "async result"
282+
async_callable = AsyncMock(return_value=expected_value)
283+
284+
resolved = await resolve_value(async_callable)
285+
286+
async_callable.assert_awaited_once()
287+
assert resolved == expected_value

0 commit comments

Comments
 (0)