Skip to content

Commit adc8bc2

Browse files
committed
fix: Correctly determine remaining required authz tokens
This PR fixes an issue where the system could inaccurately identify the authorization tokens still needed for tool invocation. The `identify_required_authn_params` helper has been updated to leverage its new capability of recognizing all alternatives of required authorization tokens. A new `ToolboxTool` member variable, `__required_authz_tokens`, now stores these alternatives. The tool invocation logic correctly uses this to check if any matching token has been provided. This new member variable is also updated correctly by the remaining authz tokens while adding auth token getters, and validated right before tool invocation.
1 parent c7dbc3c commit adc8bc2

File tree

4 files changed

+46
-23
lines changed

4 files changed

+46
-23
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,9 @@ def __parse_tool(
7979
else: # regular parameter
8080
params.append(p)
8181

82-
authn_params, _, used_auth_keys = identify_required_authn_params(
83-
# TODO: Add schema.authRequired as second arg
82+
authn_params, authz_tokens, used_auth_keys = identify_required_authn_params(
8483
authn_params,
85-
[],
84+
schema.authRequired,
8685
auth_token_getters.keys(),
8786
)
8887

@@ -94,6 +93,7 @@ def __parse_tool(
9493
params=params,
9594
# create a read-only values for the maps to prevent mutation
9695
required_authn_params=types.MappingProxyType(authn_params),
96+
required_authz_tokens=authz_tokens,
9797
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
9898
bound_params=types.MappingProxyType(bound_params),
9999
client_headers=types.MappingProxyType(client_headers),

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
description: str,
5050
params: Sequence[ParameterSchema],
5151
required_authn_params: Mapping[str, list[str]],
52+
required_authz_tokens: Sequence[str],
5253
auth_service_token_getters: Mapping[str, Callable[[], str]],
5354
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
5455
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
@@ -63,12 +64,14 @@ def __init__(
6364
name: The name of the remote tool.
6465
description: The description of the remote tool.
6566
params: The args of the tool.
66-
required_authn_params: A map of required authenticated parameters to a list
67-
of alternative services that can provide values for them.
68-
auth_service_token_getters: A dict of authService -> token (or callables that
69-
produce a token)
70-
bound_params: A mapping of parameter names to bind to specific values or
71-
callables that are called to produce values as needed.
67+
required_authn_params: A map of required authenticated parameters to
68+
a list of alternative services that can provide values for them.
69+
required_authz_tokens: A sequence of alternative services for
70+
providing authorization token for the tool invocation.
71+
auth_service_token_getters: A dict of authService -> token (or
72+
callables that produce a token)
73+
bound_params: A mapping of parameter names to bind to specific
74+
values or callables that are called to produce values as needed.
7275
client_headers: Client specific headers bound to the tool.
7376
"""
7477
# used to invoke the toolbox API
@@ -106,6 +109,8 @@ def __init__(
106109

107110
# map of parameter name to auth service required by it
108111
self.__required_authn_params = required_authn_params
112+
# sequence of authorization tokens required by it
113+
self.__required_authz_tokens = required_authz_tokens
109114
# map of authService -> token_getter
110115
self.__auth_service_token_getters = auth_service_token_getters
111116
# map of parameter name to value (or callable that produces that value)
@@ -121,6 +126,7 @@ def __copy(
121126
description: Optional[str] = None,
122127
params: Optional[Sequence[ParameterSchema]] = None,
123128
required_authn_params: Optional[Mapping[str, list[str]]] = None,
129+
required_authz_tokens: Optional[Sequence[str]] = None,
124130
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
125131
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
126132
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
@@ -134,12 +140,14 @@ def __copy(
134140
name: The name of the remote tool.
135141
description: The description of the remote tool.
136142
params: The args of the tool.
137-
required_authn_params: A map of required authenticated parameters to a list
138-
of alternative services that can provide values for them.
139-
auth_service_token_getters: A dict of authService -> token (or callables
140-
that produce a token)
141-
bound_params: A mapping of parameter names to bind to specific values or
142-
callables that are called to produce values as needed.
143+
required_authn_params: A map of required authenticated parameters to
144+
a list of alternative services that can provide values for them.
145+
required_authz_tokens: A sequence of alternative services for
146+
providing authorization token for the tool invocation.
147+
auth_service_token_getters: A dict of authService -> token (or
148+
callables that produce a token)
149+
bound_params: A mapping of parameter names to bind to specific
150+
values or callables that are called to produce values as needed.
143151
client_headers: Client specific headers bound to the tool.
144152
"""
145153
check = lambda val, default: val if val is not None else default
@@ -152,6 +160,9 @@ def __copy(
152160
required_authn_params=check(
153161
required_authn_params, self.__required_authn_params
154162
),
163+
required_authz_tokens=check(
164+
required_authz_tokens, self.__required_authz_tokens
165+
),
155166
auth_service_token_getters=check(
156167
auth_service_token_getters, self.__auth_service_token_getters
157168
),
@@ -179,11 +190,15 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
179190
"""
180191

181192
# check if any auth services need to be specified yet
182-
if len(self.__required_authn_params) > 0:
193+
if (
194+
len(self.__required_authn_params) > 0
195+
or len(self.__required_authz_tokens) > 0
196+
):
183197
# Gather all the required auth services into a set
184198
req_auth_services = set()
185199
for s in self.__required_authn_params.values():
186200
req_auth_services.update(s)
201+
req_auth_services.update(self.__required_authz_tokens)
187202
raise ValueError(
188203
f"One or more of the following authn services are required to invoke this tool"
189204
f": {','.join(req_auth_services)}"
@@ -269,18 +284,20 @@ def add_auth_token_getters(
269284
dict(self.__auth_service_token_getters, **auth_token_getters)
270285
)
271286
# create a read-only updated for params that are still required
272-
new_req_authn_params = types.MappingProxyType(
287+
new_req_authn_params, new_req_authz_tokens, used_auth_token_getters = (
273288
identify_required_authn_params(
274-
# TODO: Add authRequired
275289
self.__required_authn_params,
276-
[],
290+
self.__required_authz_tokens,
277291
auth_token_getters.keys(),
278-
)[0]
292+
)
279293
)
280294

295+
# TODO: Add validation for used_auth_token_getters
296+
281297
return self.__copy(
282298
auth_service_token_getters=new_getters,
283-
required_authn_params=new_req_authn_params,
299+
required_authn_params=types.MappingProxyType(new_req_authn_params),
300+
required_authz_tokens=new_req_authz_tokens,
284301
)
285302

286303
def bind_params(

packages/toolbox-core/src/toolbox_core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) -
4646

4747
def identify_required_authn_params(
4848
req_authn_params: Mapping[str, list[str]],
49-
req_authz_tokens: list[str],
49+
req_authz_tokens: Sequence[str],
5050
auth_service_names: Iterable[str],
5151
) -> tuple[dict[str, list[str]], list[str], set[str]]:
5252
"""
@@ -100,7 +100,7 @@ def identify_required_authn_params(
100100
if matched_authz_services:
101101
used_services.update(matched_authz_services)
102102
else:
103-
required_authz_tokens = req_authz_tokens
103+
required_authz_tokens = list(req_authz_tokens)
104104

105105
return required_authn_params, required_authz_tokens, used_services
106106

packages/toolbox-core/tests/test_tool.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ async def test_tool_creation_callable_and_run(
205205
description=sample_tool_description,
206206
params=sample_tool_params,
207207
required_authn_params={},
208+
required_authz_tokens=[],
208209
auth_service_token_getters={},
209210
bound_params={},
210211
client_headers={},
@@ -250,6 +251,7 @@ async def test_tool_run_with_pydantic_validation_error(
250251
description=sample_tool_description,
251252
params=sample_tool_params,
252253
required_authn_params={},
254+
required_authz_tokens=[],
253255
auth_service_token_getters={},
254256
bound_params={},
255257
client_headers={},
@@ -337,6 +339,7 @@ def test_tool_init_basic(http_session, sample_tool_params, sample_tool_descripti
337339
description=sample_tool_description,
338340
params=sample_tool_params,
339341
required_authn_params={},
342+
required_authz_tokens=[],
340343
auth_service_token_getters={},
341344
bound_params={},
342345
client_headers={},
@@ -361,6 +364,7 @@ def test_tool_init_with_client_headers(
361364
description=sample_tool_description,
362365
params=sample_tool_params,
363366
required_authn_params={},
367+
required_authz_tokens=[],
364368
auth_service_token_getters={},
365369
bound_params={},
366370
client_headers=static_client_header,
@@ -388,6 +392,7 @@ def test_tool_init_header_auth_conflict(
388392
description=sample_tool_description,
389393
params=sample_tool_auth_params,
390394
required_authn_params={},
395+
required_authz_tokens=[],
391396
auth_service_token_getters=auth_getters,
392397
bound_params={},
393398
client_headers=conflicting_client_header,
@@ -410,6 +415,7 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
410415
description=sample_tool_description,
411416
params=sample_tool_params,
412417
required_authn_params={},
418+
required_authz_tokens=[],
413419
auth_service_token_getters={},
414420
bound_params={},
415421
client_headers={

0 commit comments

Comments
 (0)