Skip to content

Commit bcba462

Browse files
committed
chore: address more feedback
1 parent c1ac2cd commit bcba462

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from aiohttp import ClientSession
1919

2020
from .protocol import ManifestSchema, ToolSchema
21-
from .tool import ToolboxTool, filter_required_authn_params
21+
from .tool import ToolboxTool, identify_required_authn_params
2222

2323

2424
class ToolboxClient:
@@ -72,7 +72,9 @@ def __parse_tool(
7272
authn_params[p.name] = p.authSources
7373
auth_sources.update(p.authSources)
7474

75-
authn_params = filter_required_authn_params(authn_params, auth_sources)
75+
authn_params = identify_required_authn_params(
76+
authn_params, auth_token_getters.keys()
77+
)
7678

7779
tool = ToolboxTool(
7880
session=self.__session,

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,10 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
140140

141141
# check if any auth services need to be specified yet
142142
if len(self.__required_authn_params) > 0:
143-
req_auth_services = set(l for l in self.__required_authn_params.keys())
143+
# Gather all the required auth services into a set
144+
req_auth_services = set()
145+
for s in self.__required_authn_params.values():
146+
req_auth_services.update(s)
144147
raise Exception(
145148
f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
146149
)
@@ -184,10 +187,12 @@ def add_auth_token_getters(
184187
"""
185188

186189
# throw an error if the authentication source is already registered
187-
dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys()
188-
if dupes:
190+
existing_services = self.__auth_service_token_getters.keys()
191+
incoming_services = auth_token_getters.keys()
192+
duplicates = existing_services & incoming_services
193+
if duplicates:
189194
raise ValueError(
190-
f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`."
195+
f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`."
191196
)
192197

193198
# create a read-only updated value for new_getters
@@ -196,7 +201,7 @@ def add_auth_token_getters(
196201
)
197202
# create a read-only updated for params that are still required
198203
new_req_authn_params = types.MappingProxyType(
199-
filter_required_authn_params(
204+
identify_required_authn_params(
200205
self.__required_authn_params, auth_token_getters.keys()
201206
)
202207
)
@@ -207,12 +212,12 @@ def add_auth_token_getters(
207212
)
208213

209214

210-
def filter_required_authn_params(
215+
def identify_required_authn_params(
211216
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
212217
) -> dict[str, list[str]]:
213218
"""
214-
Utility function for reducing 'req_authn_params' to a subset of parameters that
215-
aren't supplied by at least one service in auth_services.
219+
Identifies authentication parameters that are still required; or not covered by
220+
the provided `auth_service_names`.
216221
217222
Args:
218223
req_authn_params: A mapping of parameter names to sets of required

0 commit comments

Comments
 (0)