Skip to content

Commit 44b0697

Browse files
authored
feat: Add convenience methods for adding single auth token getter to tools (#245)
* feat: Add convenience methods for adding single auth token getter to tools * chore: Add unit test coverage * chore: Delint * docs: Add to README * docs: Fix and improve docstrings * docs: Improve docstring * chore: Delint * docs: Improve function docstrings and README (#246) * docs: Improve function docstrings in tool classes * docs: Add singular bind param method to README
1 parent 64aa5a8 commit 44b0697

File tree

5 files changed

+153
-17
lines changed

5 files changed

+153
-17
lines changed

packages/toolbox-core/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ loaded. This modifies the specific tool instance.
381381
toolbox = ToolboxClient("http://127.0.0.1:5000")
382382
tool = await toolbox.load_tool("my-tool")
383383

384-
auth_tool = tool.add_auth_token_getters({"my_auth": get_auth_token}) # Single token
384+
auth_tool = tool.add_auth_token_getter("my_auth", get_auth_token) # Single token
385385

386386
# OR
387387

@@ -459,6 +459,10 @@ specific tool instance.
459459
toolbox = ToolboxClient("http://127.0.0.1:5000")
460460
tool = await toolbox.load_tool("my-tool")
461461

462+
bound_tool = tool.bind_param("param", "value")
463+
464+
# OR
465+
462466
bound_tool = tool.bind_params({"param": "value"})
463467
```
464468

packages/toolbox-core/src/toolbox_core/sync_tool.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,35 +139,65 @@ def add_auth_token_getters(
139139
auth_token_getters: Mapping[str, Callable[[], str]],
140140
) -> "ToolboxSyncTool":
141141
"""
142-
Registers an auth token getter function that is used for AuthServices when tools
143-
are invoked.
142+
Registers auth token getter functions that are used for AuthServices
143+
when tools are invoked.
144144
145145
Args:
146146
auth_token_getters: A mapping of authentication service names to
147147
callables that return the corresponding authentication token.
148148
149149
Returns:
150-
A new ToolboxSyncTool instance with the specified authentication token
151-
getters registered.
152-
"""
150+
A new ToolboxSyncTool instance with the specified authentication
151+
token getters registered.
152+
153+
Raises:
154+
ValueError: If an auth source has already been registered either to
155+
the tool or to the corresponding client.
153156
157+
"""
154158
new_async_tool = self.__async_tool.add_auth_token_getters(auth_token_getters)
155159
return ToolboxSyncTool(new_async_tool, self.__loop, self.__thread)
156160

161+
def add_auth_token_getter(
162+
self, auth_source: str, get_id_token: Callable[[], str]
163+
) -> "ToolboxSyncTool":
164+
"""
165+
Registers an auth token getter function that is used for AuthService
166+
when tools are invoked.
167+
168+
Args:
169+
auth_source: The name of the authentication source.
170+
get_id_token: A function that returns the ID token.
171+
172+
Returns:
173+
A new ToolboxSyncTool instance with the specified authentication
174+
token getter registered.
175+
176+
Raises:
177+
ValueError: If the auth source has already been registered either to
178+
the tool or to the corresponding client.
179+
180+
"""
181+
return self.add_auth_token_getters({auth_source: get_id_token})
182+
157183
def bind_params(
158184
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
159185
) -> "ToolboxSyncTool":
160186
"""
161187
Binds parameters to values or callables that produce values.
162188
163189
Args:
164-
bound_params: A mapping of parameter names to values or callables that
165-
produce values.
190+
bound_params: A mapping of parameter names to values or callables
191+
that produce values.
166192
167193
Returns:
168194
A new ToolboxSyncTool instance with the specified parameters bound.
169-
"""
170195
196+
Raises:
197+
ValueError: If a parameter is already bound or is not defined by the
198+
tool's definition.
199+
200+
"""
171201
new_async_tool = self.__async_tool.bind_params(bound_params)
172202
return ToolboxSyncTool(new_async_tool, self.__loop, self.__thread)
173203

@@ -177,7 +207,7 @@ def bind_param(
177207
param_value: Union[Callable[[], Any], Any],
178208
) -> "ToolboxSyncTool":
179209
"""
180-
Binds a parameter to the value or callables that produce it.
210+
Binds a parameter to the value or callable that produce the value.
181211
182212
Args:
183213
param_name: The name of the bound parameter.
@@ -186,5 +216,10 @@ def bind_param(
186216
187217
Returns:
188218
A new ToolboxSyncTool instance with the specified parameter bound.
219+
220+
Raises:
221+
ValueError: If the parameter is already bound or is not defined by
222+
the tool's definition.
223+
189224
"""
190225
return self.bind_params({param_name: param_value})

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ def add_auth_token_getters(
281281
auth_token_getters: Mapping[str, Callable[[], str]],
282282
) -> "ToolboxTool":
283283
"""
284-
Registers an auth token getter function that is used for AuthServices when tools
285-
are invoked.
284+
Registers auth token getter functions that are used for AuthServices
285+
when tools are invoked.
286286
287287
Args:
288288
auth_token_getters: A mapping of authentication service names to
@@ -292,9 +292,9 @@ def add_auth_token_getters(
292292
A new ToolboxTool instance with the specified authentication token
293293
getters registered.
294294
295-
Raises
296-
ValueError: If the auth source has already been registered either
297-
to the tool or to the corresponding client.
295+
Raises:
296+
ValueError: If an auth source has already been registered either to
297+
the tool or to the corresponding client.
298298
"""
299299

300300
# throw an error if the authentication source is already registered
@@ -346,6 +346,28 @@ def add_auth_token_getters(
346346
required_authz_tokens=tuple(new_req_authz_tokens),
347347
)
348348

349+
def add_auth_token_getter(
350+
self, auth_source: str, get_id_token: Callable[[], str]
351+
) -> "ToolboxTool":
352+
"""
353+
Registers an auth token getter function that is used for AuthService
354+
when tools are invoked.
355+
356+
Args:
357+
auth_source: The name of the authentication source.
358+
get_id_token: A function that returns the ID token.
359+
360+
Returns:
361+
A new ToolboxTool instance with the specified authentication token
362+
getter registered.
363+
364+
Raises:
365+
ValueError: If the auth source has already been registered either to
366+
the tool or to the corresponding client.
367+
368+
"""
369+
return self.add_auth_token_getters({auth_source: get_id_token})
370+
349371
def bind_params(
350372
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
351373
) -> "ToolboxTool":
@@ -358,6 +380,11 @@ def bind_params(
358380
359381
Returns:
360382
A new ToolboxTool instance with the specified parameters bound.
383+
384+
Raises:
385+
ValueError: If a parameter is already bound or is not defined by the
386+
tool's definition.
387+
361388
"""
362389
param_names = set(p.name for p in self.__params)
363390
for name in bound_params.keys():
@@ -389,14 +416,19 @@ def bind_param(
389416
param_value: Union[Callable[[], Any], Any],
390417
) -> "ToolboxTool":
391418
"""
392-
Binds a parameter to the value or callables that produce it.
419+
Binds a parameter to the value or callable that produce the value.
393420
394421
Args:
395422
param_name: The name of the bound parameter.
396423
param_value: The value of the bound parameter, or a callable that
397424
returns the value.
398425
399426
Returns:
400-
A new ToolboxTool instance with the specified parameters bound.
427+
A new ToolboxTool instance with the specified parameter bound.
428+
429+
Raises:
430+
ValueError: If the parameter is already bound or is not defined by
431+
the tool's definition.
432+
401433
"""
402434
return self.bind_params({param_name: param_value})

packages/toolbox-core/tests/test_sync_tool.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,39 @@ def test_toolbox_sync_tool_add_auth_token_getters(
237237
)
238238

239239

240+
def test_toolbox_sync_tool_add_auth_token_getter(
241+
toolbox_sync_tool: ToolboxSyncTool,
242+
mock_async_tool: MagicMock,
243+
event_loop: asyncio.AbstractEventLoop,
244+
mock_thread: MagicMock,
245+
):
246+
"""Tests the add_auth_token_getter method."""
247+
auth_service = "service1"
248+
auth_token_getter = lambda: "token1"
249+
250+
new_mock_async_tool = mock_async_tool.add_auth_token_getters.return_value
251+
new_mock_async_tool.__name__ = "new_async_tool_with_auth"
252+
253+
new_sync_tool = toolbox_sync_tool.add_auth_token_getter(
254+
auth_service, auth_token_getter
255+
)
256+
257+
mock_async_tool.add_auth_token_getters.assert_called_once_with(
258+
{auth_service: auth_token_getter}
259+
)
260+
261+
assert isinstance(new_sync_tool, ToolboxSyncTool)
262+
assert new_sync_tool is not toolbox_sync_tool
263+
assert new_sync_tool._ToolboxSyncTool__async_tool is new_mock_async_tool
264+
assert new_sync_tool._ToolboxSyncTool__loop is event_loop # Should be the same loop
265+
assert (
266+
new_sync_tool._ToolboxSyncTool__thread is mock_thread
267+
) # Should be the same thread
268+
assert (
269+
new_sync_tool.__qualname__ == f"ToolboxSyncTool.{new_mock_async_tool.__name__}"
270+
)
271+
272+
240273
def test_toolbox_sync_tool_bind_params(
241274
toolbox_sync_tool: ToolboxSyncTool,
242275
mock_async_tool: MagicMock,

packages/toolbox-core/tests/test_tool.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,38 @@ def test_add_auth_token_getters_unused_token(
502502
tool_instance.add_auth_token_getters(unused_auth_getters)
503503

504504

505+
def test_add_auth_token_getter_unused_token(
506+
http_session: ClientSession,
507+
sample_tool_params: list[ParameterSchema],
508+
sample_tool_description: str,
509+
unused_auth_getters: Mapping[str, Callable[[], str]],
510+
):
511+
"""
512+
Tests ValueError when add_auth_token_getters is called with a getter for
513+
an unused authentication service.
514+
"""
515+
tool_instance = ToolboxTool(
516+
session=http_session,
517+
base_url=HTTPS_BASE_URL,
518+
name=TEST_TOOL_NAME,
519+
description=sample_tool_description,
520+
params=sample_tool_params,
521+
required_authn_params={},
522+
required_authz_tokens=[],
523+
auth_service_token_getters={},
524+
bound_params={},
525+
client_headers={},
526+
)
527+
528+
expected_error_message = "Authentication source\(s\) \`unused-auth-service\` unused by tool \`sample_tool\`."
529+
530+
with pytest.raises(ValueError, match=expected_error_message):
531+
tool_instance.add_auth_token_getter(
532+
next(iter(unused_auth_getters)),
533+
unused_auth_getters[next(iter(unused_auth_getters))],
534+
)
535+
536+
505537
def test_toolbox_tool_underscore_name_property(toolbox_tool: ToolboxTool):
506538
"""Tests the _name property."""
507539
assert toolbox_tool._name == TEST_TOOL_NAME

0 commit comments

Comments
 (0)