Skip to content

Commit f3c384f

Browse files
committed
fix: Add validation to ensure added auth token getters are used by the tool
Previously, the `ToolboxTool.add_auth_token_getters` method only validated against existing registered getters or conflicts with client headers. It did not verify if *all* the auth token getters provided were actually used or required by the specific tool instance they were being added to. This PR enhances the validation in `add_auth_token_getters`. It now leverages the `used_auth_token_getters` information returned by the existing `identify_required_authn_params` call. This allows the method to confirm that every getter passed in is genuinely required by the tool, raising an error if any are unused. This ensures that only relevant auth token getters are attempted to be registered for a tool, preventing misconfigurations and human errors. > [!NOTE] > This validation aligns with the existing validation logic already present in the `ToolboxClient.load_tool` method, promoting a consistent approach to handling auth token getter requirements across the codebase.
1 parent 2dad7c8 commit f3c384f

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def add_auth_token_getters(
309309

310310
new_getters = dict(self.__auth_service_token_getters, **auth_token_getters)
311311

312-
# find the updated requirements
312+
# find the updated auth requirements
313313
new_req_authn_params, new_req_authz_tokens, used_auth_token_getters = (
314314
identify_auth_requirements(
315315
self.__required_authn_params,
@@ -318,7 +318,12 @@ def add_auth_token_getters(
318318
)
319319
)
320320

321-
# TODO: Add validation for used_auth_token_getters
321+
# ensure no auth token getter provided remains unused
322+
unused_auth = set(incoming_services) - used_auth_token_getters
323+
if unused_auth:
324+
raise ValueError(
325+
f"Authentication source(s) `{', '.join(unused_auth)}` unused by tool `{self.__name__}`."
326+
)
322327

323328
return self.__copy(
324329
# create a read-only map for updated getters, params and tokens that are still required

packages/toolbox-core/tests/test_tool.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
1416
import inspect
15-
from typing import AsyncGenerator, Callable
17+
from typing import AsyncGenerator, Callable, Mapping
1618
from unittest.mock import AsyncMock, Mock
1719

1820
import pytest
@@ -92,6 +94,12 @@ def auth_header_key() -> str:
9294
return "test-auth_token"
9395

9496

97+
@pytest.fixture
98+
def unused_auth_getters() -> dict[str, Callable[[], str]]:
99+
"""Provides an auth getter for a service not required by sample_tool."""
100+
return {"unused-auth-service": lambda: "unused-token-value"}
101+
102+
95103
def test_create_func_docstring_one_param_real_schema():
96104
"""
97105
Tests create_func_docstring with one real ParameterSchema instance.
@@ -432,3 +440,32 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
432440

433441
with pytest.raises(ValueError, match=expected_error_message):
434442
tool_instance.add_auth_token_getters(new_auth_getters_causing_conflict)
443+
444+
445+
def test_add_auth_token_getters_unused_token(
446+
http_session: ClientSession,
447+
sample_tool_params: list[ParameterSchema],
448+
sample_tool_description: str,
449+
unused_auth_getters: Mapping[str, Callable[[], str]],
450+
):
451+
"""
452+
Tests ValueError when add_auth_token_getters is called with a getter for
453+
an unused authentication service.
454+
"""
455+
tool_instance = ToolboxTool(
456+
session=http_session,
457+
base_url=TEST_BASE_URL,
458+
name=TEST_TOOL_NAME,
459+
description=sample_tool_description,
460+
params=sample_tool_params,
461+
required_authn_params={},
462+
required_authz_tokens=[],
463+
auth_service_token_getters={},
464+
bound_params={},
465+
client_headers={},
466+
)
467+
468+
expected_error_message = "Authentication source\(s\) \`unused-auth-service\` unused by tool \`sample_tool\`."
469+
470+
with pytest.raises(ValueError, match=expected_error_message):
471+
tool_instance.add_auth_token_getters(unused_auth_getters)

0 commit comments

Comments
 (0)