Skip to content

Commit 2dad7c8

Browse files
authored
feat: Enhance authorization token validation with authRequired (#222)
* fix: Correctly determine remaining required authz tokens This PR fixes an issue where the system could inaccurately identify the authorization tokens still needed for tool invocation. The `identify_required_authn_params` helper has been updated to leverage its new capability of recognizing all alternatives of required authorization tokens. A new `ToolboxTool` member variable, `__required_authz_tokens`, now stores these alternatives. The tool invocation logic correctly uses this to check if any matching token has been provided. This new member variable is also updated correctly by the remaining authz tokens while adding auth token getters, and validated right before tool invocation. * chore: Fix integration tests * chore: Update comments and convert required_authz_tokens to read-only This enhances readability and clarity. * chore: Delint * chore: Rename identify_required_authn_params to better reflect its updated functionality
1 parent 5d49936 commit 2dad7c8

File tree

7 files changed

+81
-59
lines changed

7 files changed

+81
-59
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from .protocol import ManifestSchema, ToolSchema
2020
from .tool import ToolboxTool
21-
from .utils import identify_required_authn_params, resolve_value
21+
from .utils import identify_auth_requirements, resolve_value
2222

2323

2424
class ToolboxClient:
@@ -79,10 +79,9 @@ def __parse_tool(
7979
else: # regular parameter
8080
params.append(p)
8181

82-
authn_params, _, used_auth_keys = identify_required_authn_params(
83-
# TODO: Add schema.authRequired as second arg
82+
authn_params, authz_tokens, used_auth_keys = identify_auth_requirements(
8483
authn_params,
85-
[],
84+
schema.authRequired,
8685
auth_token_getters.keys(),
8786
)
8887

@@ -94,6 +93,7 @@ def __parse_tool(
9493
# create a read-only values to prevent mutation
9594
params=tuple(params),
9695
required_authn_params=types.MappingProxyType(authn_params),
96+
required_authz_tokens=authz_tokens,
9797
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
9898
bound_params=types.MappingProxyType(bound_params),
9999
client_headers=types.MappingProxyType(client_headers),

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .protocol import ParameterSchema
2323
from .utils import (
2424
create_func_docstring,
25-
identify_required_authn_params,
25+
identify_auth_requirements,
2626
params_to_pydantic_model,
2727
resolve_value,
2828
)
@@ -49,6 +49,7 @@ def __init__(
4949
description: str,
5050
params: Sequence[ParameterSchema],
5151
required_authn_params: Mapping[str, list[str]],
52+
required_authz_tokens: Sequence[str],
5253
auth_service_token_getters: Mapping[str, Callable[[], str]],
5354
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
5455
client_headers: Mapping[str, Union[Callable, Coroutine, str]],
@@ -63,12 +64,14 @@ def __init__(
6364
name: The name of the remote tool.
6465
description: The description of the remote tool.
6566
params: The args of the tool.
66-
required_authn_params: A map of required authenticated parameters to a list
67-
of alternative services that can provide values for them.
68-
auth_service_token_getters: A dict of authService -> token (or callables that
69-
produce a token)
70-
bound_params: A mapping of parameter names to bind to specific values or
71-
callables that are called to produce values as needed.
67+
required_authn_params: A map of required authenticated parameters to
68+
a list of alternative services that can provide values for them.
69+
required_authz_tokens: A sequence of alternative services for
70+
providing authorization token for the tool invocation.
71+
auth_service_token_getters: A dict of authService -> token (or
72+
callables that produce a token)
73+
bound_params: A mapping of parameter names to bind to specific
74+
values or callables that are called to produce values as needed.
7275
client_headers: Client specific headers bound to the tool.
7376
"""
7477
# used to invoke the toolbox API
@@ -106,6 +109,8 @@ def __init__(
106109

107110
# map of parameter name to auth service required by it
108111
self.__required_authn_params = required_authn_params
112+
# sequence of authorization tokens required by it
113+
self.__required_authz_tokens = required_authz_tokens
109114
# map of authService -> token_getter
110115
self.__auth_service_token_getters = auth_service_token_getters
111116
# map of parameter name to value (or callable that produces that value)
@@ -149,6 +154,7 @@ def __copy(
149154
description: Optional[str] = None,
150155
params: Optional[Sequence[ParameterSchema]] = None,
151156
required_authn_params: Optional[Mapping[str, list[str]]] = None,
157+
required_authz_tokens: Optional[Sequence[str]] = None,
152158
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
153159
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
154160
client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None,
@@ -162,12 +168,14 @@ def __copy(
162168
name: The name of the remote tool.
163169
description: The description of the remote tool.
164170
params: The args of the tool.
165-
required_authn_params: A map of required authenticated parameters to a list
166-
of alternative services that can provide values for them.
167-
auth_service_token_getters: A dict of authService -> token (or callables
168-
that produce a token)
169-
bound_params: A mapping of parameter names to bind to specific values or
170-
callables that are called to produce values as needed.
171+
required_authn_params: A map of required authenticated parameters to
172+
a list of alternative services that can provide values for them.
173+
required_authz_tokens: A sequence of alternative services for
174+
providing authorization token for the tool invocation.
175+
auth_service_token_getters: A dict of authService -> token (or
176+
callables that produce a token)
177+
bound_params: A mapping of parameter names to bind to specific
178+
values or callables that are called to produce values as needed.
171179
client_headers: Client specific headers bound to the tool.
172180
"""
173181
check = lambda val, default: val if val is not None else default
@@ -180,6 +188,9 @@ def __copy(
180188
required_authn_params=check(
181189
required_authn_params, self.__required_authn_params
182190
),
191+
required_authz_tokens=check(
192+
required_authz_tokens, self.__required_authz_tokens
193+
),
183194
auth_service_token_getters=check(
184195
auth_service_token_getters, self.__auth_service_token_getters
185196
),
@@ -207,11 +218,15 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
207218
"""
208219

209220
# check if any auth services need to be specified yet
210-
if len(self.__required_authn_params) > 0:
221+
if (
222+
len(self.__required_authn_params) > 0
223+
or len(self.__required_authz_tokens) > 0
224+
):
211225
# Gather all the required auth services into a set
212226
req_auth_services = set()
213227
for s in self.__required_authn_params.values():
214228
req_auth_services.update(s)
229+
req_auth_services.update(self.__required_authz_tokens)
215230
raise ValueError(
216231
f"One or more of the following authn services are required to invoke this tool"
217232
f": {','.join(req_auth_services)}"
@@ -292,23 +307,24 @@ def add_auth_token_getters(
292307
f"Cannot register client the same headers in the client as well as tool."
293308
)
294309

295-
# create a read-only updated value for new_getters
296-
new_getters = MappingProxyType(
297-
dict(self.__auth_service_token_getters, **auth_token_getters)
298-
)
299-
# create a read-only updated for params that are still required
300-
new_req_authn_params = MappingProxyType(
301-
identify_required_authn_params(
302-
# TODO: Add authRequired
310+
new_getters = dict(self.__auth_service_token_getters, **auth_token_getters)
311+
312+
# find the updated requirements
313+
new_req_authn_params, new_req_authz_tokens, used_auth_token_getters = (
314+
identify_auth_requirements(
303315
self.__required_authn_params,
304-
[],
316+
self.__required_authz_tokens,
305317
auth_token_getters.keys(),
306-
)[0]
318+
)
307319
)
308320

321+
# TODO: Add validation for used_auth_token_getters
322+
309323
return self.__copy(
310-
auth_service_token_getters=new_getters,
311-
required_authn_params=new_req_authn_params,
324+
# create a read-only map for updated getters, params and tokens that are still required
325+
auth_service_token_getters=MappingProxyType(new_getters),
326+
required_authn_params=MappingProxyType(new_req_authn_params),
327+
required_authz_tokens=tuple(new_req_authz_tokens),
312328
)
313329

314330
def bind_params(

packages/toolbox-core/src/toolbox_core/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) -
4444
return docstring
4545

4646

47-
def identify_required_authn_params(
47+
def identify_auth_requirements(
4848
req_authn_params: Mapping[str, list[str]],
49-
req_authz_tokens: list[str],
49+
req_authz_tokens: Sequence[str],
5050
auth_service_names: Iterable[str],
5151
) -> tuple[dict[str, list[str]], list[str], set[str]]:
5252
"""
@@ -100,7 +100,7 @@ def identify_required_authn_params(
100100
if matched_authz_services:
101101
used_services.update(matched_authz_services)
102102
else:
103-
required_authz_tokens = req_authz_tokens
103+
required_authz_tokens = list(req_authz_tokens)
104104

105105
return required_authn_params, required_authz_tokens, used_services
106106

packages/toolbox-core/tests/test_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ async def test_run_tool_no_auth(self, toolbox: ToolboxClient):
148148
tool = await toolbox.load_tool("get-row-by-id-auth")
149149
with pytest.raises(
150150
Exception,
151-
match="tool invocation not authorized. Please make sure your specify correct auth headers",
151+
match="One or more of the following authn services are required to invoke this tool: my-test-auth",
152152
):
153153
await tool(id="2")
154154

packages/toolbox-core/tests/test_sync_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_run_tool_no_auth(self, toolbox: ToolboxSyncClient):
130130
tool = toolbox.load_tool("get-row-by-id-auth")
131131
with pytest.raises(
132132
Exception,
133-
match="tool invocation not authorized. Please make sure your specify correct auth headers",
133+
match="One or more of the following authn services are required to invoke this tool: my-test-auth",
134134
):
135135
tool(id="2")
136136

packages/toolbox-core/tests/test_tool.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ async def test_tool_creation_callable_and_run(
205205
description=sample_tool_description,
206206
params=sample_tool_params,
207207
required_authn_params={},
208+
required_authz_tokens=[],
208209
auth_service_token_getters={},
209210
bound_params={},
210211
client_headers={},
@@ -250,6 +251,7 @@ async def test_tool_run_with_pydantic_validation_error(
250251
description=sample_tool_description,
251252
params=sample_tool_params,
252253
required_authn_params={},
254+
required_authz_tokens=[],
253255
auth_service_token_getters={},
254256
bound_params={},
255257
client_headers={},
@@ -337,6 +339,7 @@ def test_tool_init_basic(http_session, sample_tool_params, sample_tool_descripti
337339
description=sample_tool_description,
338340
params=sample_tool_params,
339341
required_authn_params={},
342+
required_authz_tokens=[],
340343
auth_service_token_getters={},
341344
bound_params={},
342345
client_headers={},
@@ -361,6 +364,7 @@ def test_tool_init_with_client_headers(
361364
description=sample_tool_description,
362365
params=sample_tool_params,
363366
required_authn_params={},
367+
required_authz_tokens=[],
364368
auth_service_token_getters={},
365369
bound_params={},
366370
client_headers=static_client_header,
@@ -388,6 +392,7 @@ def test_tool_init_header_auth_conflict(
388392
description=sample_tool_description,
389393
params=sample_tool_auth_params,
390394
required_authn_params={},
395+
required_authz_tokens=[],
391396
auth_service_token_getters=auth_getters,
392397
bound_params={},
393398
client_headers=conflicting_client_header,
@@ -410,6 +415,7 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
410415
description=sample_tool_description,
411416
params=sample_tool_params,
412417
required_authn_params={},
418+
required_authz_tokens=[],
413419
auth_service_token_getters={},
414420
bound_params={},
415421
client_headers={

0 commit comments

Comments
 (0)