Skip to content

Commit 527369c

Browse files
authored
feat: Track utilized auth services (#203)
* fix: Add the no parameter check back again. We will remove this once we actually implement the `strict` flag and centralize this functionality by moving this check to the tool's constructor in a future PR. * fix: Reverse the error conditions to avoid masking of the second error. * feat: Track utilized auth services This PR expands the `identify_required_authn_params` helper function. Previously, it only identified missing auth params based on requirements and available token getters. Now, the helper also returns a list of the auth token getters that were actually used during the identification process. > [!NOTE] > This enhancement is a preparatory step for implementing `strict` flag validation in an upcoming PR, allowing us to determine if all provided authentication methods were necessary. * chore: Delint * chore: Unpack tuple instead of index it This makes it more readable.
1 parent 906d0e0 commit 527369c

File tree

4 files changed

+46
-23
lines changed

4 files changed

+46
-23
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __parse_tool(
7979
else: # regular parameter
8080
params.append(p)
8181

82-
authn_params = identify_required_authn_params(
82+
authn_params, _ = identify_required_authn_params(
8383
authn_params, auth_token_getters.keys()
8484
)
8585

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def add_auth_token_getters(
293293
new_req_authn_params = MappingProxyType(
294294
identify_required_authn_params(
295295
self.__required_authn_params, auth_token_getters.keys()
296-
)
296+
)[0]
297297
)
298298

299299
return self.__copy(

packages/toolbox-core/src/toolbox_core/utils.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,29 +46,40 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) -
4646

4747
def identify_required_authn_params(
4848
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
49-
) -> dict[str, list[str]]:
49+
) -> tuple[dict[str, list[str]], set[str]]:
5050
"""
5151
Identifies authentication parameters that are still required; because they
52-
are not covered by the provided `auth_service_names`.
52+
are not covered by the provided `auth_service_names`, and also returns a
53+
set of all authentication services that were found to be matching.
5354
5455
Args:
55-
req_authn_params: A mapping of parameter names to sets of required
56+
req_authn_params: A mapping of parameter names to lists of required
5657
authentication services.
5758
auth_service_names: An iterable of authentication service names for which
5859
token getters are available.
5960
6061
Returns:
61-
A new dictionary representing the subset of required authentication parameters
62-
that are not covered by the provided `auth_service_names`.
62+
A tuple containing:
63+
- A new dictionary representing the subset of required
64+
authentication parameters that are not covered by the provided
65+
`auth_service_names`.
66+
- A list of authentication service names from `auth_service_names`
67+
that were found to satisfy at least one parameter's requirements.
6368
"""
64-
required_params = {} # params that are still required with provided auth_services
69+
required_params: dict[str, list[str]] = {}
70+
used_services: set[str] = set()
71+
6572
for param, services in req_authn_params.items():
6673
# if we don't have a token_getter for any of the services required by the param,
6774
# the param is still required
68-
required = not any(s in services for s in auth_service_names)
69-
if required:
75+
matched_services = [s for s in services if s in auth_service_names]
76+
77+
if matched_services:
78+
used_services.update(matched_services)
79+
else:
7080
required_params[param] = services
71-
return required_params
81+
82+
return required_params, used_services
7283

7384

7485
def params_to_pydantic_model(

packages/toolbox-core/tests/test_utils.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ def test_identify_required_authn_params_none_required():
8787
req_authn_params = {}
8888
auth_service_names = ["service_a", "service_b"]
8989
expected = {}
90-
assert (
91-
identify_required_authn_params(req_authn_params, auth_service_names) == expected
90+
expected_used = set()
91+
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
92+
expected,
93+
expected_used,
9294
)
9395

9496

@@ -100,8 +102,10 @@ def test_identify_required_authn_params_all_covered():
100102
}
101103
auth_service_names = ["service_a", "service_b"]
102104
expected = {}
103-
assert (
104-
identify_required_authn_params(req_authn_params, auth_service_names) == expected
105+
expected_used = set(auth_service_names)
106+
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
107+
expected,
108+
expected_used,
105109
)
106110

107111

@@ -118,8 +122,10 @@ def test_identify_required_authn_params_some_covered():
118122
"token_d": ["service_d"],
119123
"token_e": ["service_e", "service_f"],
120124
}
121-
assert (
122-
identify_required_authn_params(req_authn_params, auth_service_names) == expected
125+
expected_used = set(auth_service_names)
126+
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
127+
expected,
128+
expected_used,
123129
)
124130

125131

@@ -134,8 +140,10 @@ def test_identify_required_authn_params_none_covered():
134140
"token_d": ["service_d"],
135141
"token_e": ["service_e", "service_f"],
136142
}
137-
assert (
138-
identify_required_authn_params(req_authn_params, auth_service_names) == expected
143+
expected_used = set()
144+
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
145+
expected,
146+
expected_used,
139147
)
140148

141149

@@ -150,8 +158,10 @@ def test_identify_required_authn_params_no_available_services():
150158
"token_a": ["service_a"],
151159
"token_b": ["service_b", "service_c"],
152160
}
153-
assert (
154-
identify_required_authn_params(req_authn_params, auth_service_names) == expected
161+
expected_used = set()
162+
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
163+
expected,
164+
expected_used,
155165
)
156166

157167

@@ -164,8 +174,10 @@ def test_identify_required_authn_params_empty_services_for_param():
164174
expected = {
165175
"token_x": [],
166176
}
167-
assert (
168-
identify_required_authn_params(req_authn_params, auth_service_names) == expected
177+
expected_used = set()
178+
assert identify_required_authn_params(req_authn_params, auth_service_names) == (
179+
expected,
180+
expected_used,
169181
)
170182

171183

0 commit comments

Comments
 (0)