Skip to content

Commit c9f8d20

Browse files
committed
fix: Correctly determine remaining required authz tokens
This PR fixes an issue where the system could inaccurately identify the authorization tokens still needed for tool invocation. The `identify_required_authn_params` helper has been updated to leverage its new capability of recognizing all alternatives of required authorization tokens. A new `ToolboxTool` member variable, `__required_authz_tokens`, now stores these alternatives. The tool invocation logic correctly uses this to check if any matching token has been provided. This new member variable is also updated correctly by the remaining authz tokens while adding auth token getters, and validated right before tool invocation.
1 parent 5d49936 commit c9f8d20

File tree

4 files changed

+46
-23
lines changed

4 files changed

+46
-23
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,9 @@ def __parse_tool(
7979
else: # regular parameter
8080
params.append(p)
8181

82-
authn_params, _, used_auth_keys = identify_required_authn_params(
83-
# TODO: Add schema.authRequired as second arg
82+
authn_params, authz_tokens, used_auth_keys = identify_required_authn_params(
8483
authn_params,
85-
[],
84+
schema.authRequired,
8685
auth_token_getters.keys(),
8786
)
8887

@@ -94,6 +93,7 @@ def __parse_tool(
9493
# create a read-only values to prevent mutation
9594
params=tuple(params),
9695
required_authn_params=types.MappingProxyType(authn_params),
96+
required_authz_tokens=authz_tokens,
9797
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
9898
bound_params=types.MappingProxyType(bound_params),
9999
client_headers=types.MappingProxyType(client_headers),

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
description: str,
5050
params: Sequence[ParameterSchema],
5151
required_authn_params: Mapping[str, list[str]],
52+
required_authz_tokens: Sequence[str],
5253
auth_service_token_getters: Mapping[str, Callable[[], str]],
5354
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
5455
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
@@ -63,12 +64,14 @@ def __init__(
6364
name: The name of the remote tool.
6465
description: The description of the remote tool.
6566
params: The args of the tool.
66-
required_authn_params: A map of required authenticated parameters to a list
67-
of alternative services that can provide values for them.
68-
auth_service_token_getters: A dict of authService -> token (or callables that
69-
produce a token)
70-
bound_params: A mapping of parameter names to bind to specific values or
71-
callables that are called to produce values as needed.
67+
required_authn_params: A map of required authenticated parameters to
68+
a list of alternative services that can provide values for them.
69+
required_authz_tokens: A sequence of alternative services for
70+
providing authorization token for the tool invocation.
71+
auth_service_token_getters: A dict of authService -> token (or
72+
callables that produce a token)
73+
bound_params: A mapping of parameter names to bind to specific
74+
values or callables that are called to produce values as needed.
7275
client_headers: Client specific headers bound to the tool.
7376
"""
7477
# used to invoke the toolbox API
@@ -106,6 +109,8 @@ def __init__(
106109

107110
# map of parameter name to auth service required by it
108111
self.__required_authn_params = required_authn_params
112+
# sequence of authorization tokens required by it
113+
self.__required_authz_tokens = required_authz_tokens
109114
# map of authService -> token_getter
110115
self.__auth_service_token_getters = auth_service_token_getters
111116
# map of parameter name to value (or callable that produces that value)
@@ -149,6 +154,7 @@ def __copy(
149154
description: Optional[str] = None,
150155
params: Optional[Sequence[ParameterSchema]] = None,
151156
required_authn_params: Optional[Mapping[str, list[str]]] = None,
157+
required_authz_tokens: Optional[Sequence[str]] = None,
152158
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
153159
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
154160
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
@@ -162,12 +168,14 @@ def __copy(
162168
name: The name of the remote tool.
163169
description: The description of the remote tool.
164170
params: The args of the tool.
165-
required_authn_params: A map of required authenticated parameters to a list
166-
of alternative services that can provide values for them.
167-
auth_service_token_getters: A dict of authService -> token (or callables
168-
that produce a token)
169-
bound_params: A mapping of parameter names to bind to specific values or
170-
callables that are called to produce values as needed.
171+
required_authn_params: A map of required authenticated parameters to
172+
a list of alternative services that can provide values for them.
173+
required_authz_tokens: A sequence of alternative services for
174+
providing authorization token for the tool invocation.
175+
auth_service_token_getters: A dict of authService -> token (or
176+
callables that produce a token)
177+
bound_params: A mapping of parameter names to bind to specific
178+
values or callables that are called to produce values as needed.
171179
client_headers: Client specific headers bound to the tool.
172180
"""
173181
check = lambda val, default: val if val is not None else default
@@ -180,6 +188,9 @@ def __copy(
180188
required_authn_params=check(
181189
required_authn_params, self.__required_authn_params
182190
),
191+
required_authz_tokens=check(
192+
required_authz_tokens, self.__required_authz_tokens
193+
),
183194
auth_service_token_getters=check(
184195
auth_service_token_getters, self.__auth_service_token_getters
185196
),
@@ -207,11 +218,15 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
207218
"""
208219

209220
# check if any auth services need to be specified yet
210-
if len(self.__required_authn_params) > 0:
221+
if (
222+
len(self.__required_authn_params) > 0
223+
or len(self.__required_authz_tokens) > 0
224+
):
211225
# Gather all the required auth services into a set
212226
req_auth_services = set()
213227
for s in self.__required_authn_params.values():
214228
req_auth_services.update(s)
229+
req_auth_services.update(self.__required_authz_tokens)
215230
raise ValueError(
216231
f"One or more of the following authn services are required to invoke this tool"
217232
f": {','.join(req_auth_services)}"
@@ -297,18 +312,20 @@ def add_auth_token_getters(
297312
dict(self.__auth_service_token_getters, **auth_token_getters)
298313
)
299314
# create a read-only updated for params that are still required
300-
new_req_authn_params = MappingProxyType(
315+
new_req_authn_params, new_req_authz_tokens, used_auth_token_getters = (
301316
identify_required_authn_params(
302-
# TODO: Add authRequired
303317
self.__required_authn_params,
304-
[],
318+
self.__required_authz_tokens,
305319
auth_token_getters.keys(),
306-
)[0]
320+
)
307321
)
308322

323+
# TODO: Add validation for used_auth_token_getters
324+
309325
return self.__copy(
310326
auth_service_token_getters=new_getters,
311-
required_authn_params=new_req_authn_params,
327+
required_authn_params=MappingProxyType(new_req_authn_params),
328+
required_authz_tokens=new_req_authz_tokens,
312329
)
313330

314331
def bind_params(

packages/toolbox-core/src/toolbox_core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) -
4646

4747
def identify_required_authn_params(
4848
req_authn_params: Mapping[str, list[str]],
49-
req_authz_tokens: list[str],
49+
req_authz_tokens: Sequence[str],
5050
auth_service_names: Iterable[str],
5151
) -> tuple[dict[str, list[str]], list[str], set[str]]:
5252
"""
@@ -100,7 +100,7 @@ def identify_required_authn_params(
100100
if matched_authz_services:
101101
used_services.update(matched_authz_services)
102102
else:
103-
required_authz_tokens = req_authz_tokens
103+
required_authz_tokens = list(req_authz_tokens)
104104

105105
return required_authn_params, required_authz_tokens, used_services
106106

packages/toolbox-core/tests/test_tool.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ async def test_tool_creation_callable_and_run(
205205
description=sample_tool_description,
206206
params=sample_tool_params,
207207
required_authn_params={},
208+
required_authz_tokens=[],
208209
auth_service_token_getters={},
209210
bound_params={},
210211
client_headers={},
@@ -250,6 +251,7 @@ async def test_tool_run_with_pydantic_validation_error(
250251
description=sample_tool_description,
251252
params=sample_tool_params,
252253
required_authn_params={},
254+
required_authz_tokens=[],
253255
auth_service_token_getters={},
254256
bound_params={},
255257
client_headers={},
@@ -337,6 +339,7 @@ def test_tool_init_basic(http_session, sample_tool_params, sample_tool_descripti
337339
description=sample_tool_description,
338340
params=sample_tool_params,
339341
required_authn_params={},
342+
required_authz_tokens=[],
340343
auth_service_token_getters={},
341344
bound_params={},
342345
client_headers={},
@@ -361,6 +364,7 @@ def test_tool_init_with_client_headers(
361364
description=sample_tool_description,
362365
params=sample_tool_params,
363366
required_authn_params={},
367+
required_authz_tokens=[],
364368
auth_service_token_getters={},
365369
bound_params={},
366370
client_headers=static_client_header,
@@ -388,6 +392,7 @@ def test_tool_init_header_auth_conflict(
388392
description=sample_tool_description,
389393
params=sample_tool_auth_params,
390394
required_authn_params={},
395+
required_authz_tokens=[],
391396
auth_service_token_getters=auth_getters,
392397
bound_params={},
393398
client_headers=conflicting_client_header,
@@ -410,6 +415,7 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
410415
description=sample_tool_description,
411416
params=sample_tool_params,
412417
required_authn_params={},
418+
required_authz_tokens=[],
413419
auth_service_token_getters={},
414420
bound_params={},
415421
client_headers={

0 commit comments

Comments
 (0)