Skip to content

Commit 6d49fde

Browse files
committed
fix: Add validation to ensure added auth token getters are used by the tool
Previously, the `ToolboxTool.add_auth_token_getters` method only validated against existing registered getters or conflicts with client headers. It did not verify if *all* the auth token getters provided were actually used or required by the specific tool instance they were being added to. This PR enhances the validation in `add_auth_token_getters`. It now leverages the `used_auth_token_getters` information returned by the existing `identify_required_authn_params` call. This allows the method to confirm that every getter passed in is genuinely required by the tool, raising an error if any are unused. This ensures that only relevant auth token getters are attempted to be registered for a tool, preventing misconfigurations and human errors. > [!NOTE] > This validation aligns with the existing validation logic already present in the `ToolboxClient.load_tool` method, promoting a consistent approach to handling auth token getter requirements across the codebase.
1 parent e664e96 commit 6d49fde

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def add_auth_token_getters(
311311
new_getters = MappingProxyType(
312312
dict(self.__auth_service_token_getters, **auth_token_getters)
313313
)
314-
# create a read-only updated for params that are still required
314+
# find the updated required authn params and the auth token getters used
315315
new_req_authn_params, new_req_authz_tokens, used_auth_token_getters = (
316316
identify_required_authn_params(
317317
self.__required_authn_params,
@@ -320,10 +320,16 @@ def add_auth_token_getters(
320320
)
321321
)
322322

323-
# TODO: Add validation for used_auth_token_getters
323+
# ensure no auth token getter provided remains unused
324+
unused_auth = set(incoming_services) - used_auth_token_getters
325+
if unused_auth:
326+
raise ValueError(
327+
f"Authentication source(s) `{', '.join(unused_auth)}` unused by tool `{self.__name__}`."
328+
)
324329

325330
return self.__copy(
326331
auth_service_token_getters=new_getters,
332+
# create a read-only map for params that are still required
327333
required_authn_params=MappingProxyType(new_req_authn_params),
328334
required_authz_tokens=new_req_authz_tokens,
329335
)

packages/toolbox-core/tests/test_tool.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
1416
import inspect
15-
from typing import AsyncGenerator, Callable
17+
from typing import AsyncGenerator, Callable, Mapping
1618
from unittest.mock import AsyncMock, Mock
1719

1820
import pytest
@@ -92,6 +94,12 @@ def auth_header_key() -> str:
9294
return "test-auth_token"
9395

9496

97+
@pytest.fixture
98+
def unused_auth_getters() -> dict[str, Callable[[], str]]:
99+
"""Provides an auth getter for a service not required by sample_tool."""
100+
return {"unused-auth-service": lambda: "unused-token-value"}
101+
102+
95103
def test_create_func_docstring_one_param_real_schema():
96104
"""
97105
Tests create_func_docstring with one real ParameterSchema instance.
@@ -432,3 +440,32 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
432440

433441
with pytest.raises(ValueError, match=expected_error_message):
434442
tool_instance.add_auth_token_getters(new_auth_getters_causing_conflict)
443+
444+
445+
def test_add_auth_token_getters_unused_token(
446+
http_session: ClientSession,
447+
sample_tool_params: list[ParameterSchema],
448+
sample_tool_description: str,
449+
unused_auth_getters: Mapping[str, Callable[[], str]],
450+
):
451+
"""
452+
Tests ValueError when add_auth_token_getters is called with a getter for
453+
an unused authentication service.
454+
"""
455+
tool_instance = ToolboxTool(
456+
session=http_session,
457+
base_url=TEST_BASE_URL,
458+
name=TEST_TOOL_NAME,
459+
description=sample_tool_description,
460+
params=sample_tool_params,
461+
required_authn_params={},
462+
required_authz_tokens=[],
463+
auth_service_token_getters={},
464+
bound_params={},
465+
client_headers={},
466+
)
467+
468+
expected_error_message = "Authentication source\(s\) \`unused-auth-service\` unused by tool \`sample_tool\`."
469+
470+
with pytest.raises(ValueError, match=expected_error_message):
471+
tool_instance.add_auth_token_getters(unused_auth_getters)

0 commit comments

Comments
 (0)