Skip to content

Commit c1ac2cd

Browse files
committed
chore: address feedback
1 parent c1a482a commit c1ac2cd

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
1415
import types
1516
from typing import Any, Callable, Optional
1617

@@ -79,8 +80,9 @@ def __parse_tool(
7980
name=name,
8081
desc=schema.description,
8182
params=[p.to_param() for p in params],
83+
# create a read-only values for the maps to prevent mutation
8284
required_authn_params=types.MappingProxyType(authn_params),
83-
auth_service_token_getters=auth_token_getters,
85+
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
8486
)
8587
return tool
8688

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ def __init__(
5656
desc: The description of the remote tool (used as its docstring).
5757
params: A list of `inspect.Parameter` objects defining the tool's
5858
arguments and their types/defaults.
59-
required_authn_params: A dict of required authenticated parameters that
60-
need a auth_service_token_getter set for them yet.
61-
auth_service_tokens: A dict of authService -> token (or callables that
59+
required_authn_params: A dict of required authenticated parameters to a list
60+
of services that provide values for them.
61+
auth_service_token_getters: A dict of authService -> token (or callables that
6262
produce a token)
6363
"""
6464

@@ -108,15 +108,19 @@ def __copy(
108108
that produce a token)
109109
110110
"""
111+
check = lambda val, default: val if val is not None else default
111112
return ToolboxTool(
112-
session=session or self.__session,
113-
base_url=base_url or self.__base_url,
114-
name=name or self.__name__,
115-
desc=desc or self.__desc,
116-
params=params or self.__params,
117-
required_authn_params=required_authn_params or self.__required_authn_params,
118-
auth_service_token_getters=auth_service_token_getters
119-
or self.__auth_service_token_getters,
113+
session=check(session, self.__session),
114+
base_url=check(base_url, self.__base_url),
115+
name=check(name, self.__name__),
116+
desc=check(desc, self.__desc),
117+
params=check(params, self.__params),
118+
required_authn_params=check(
119+
required_authn_params, self.__required_authn_params
120+
),
121+
auth_service_token_getters=check(
122+
auth_service_token_getters, self.__auth_service_token_getters
123+
),
120124
)
121125

122126
async def __call__(self, *args: Any, **kwargs: Any) -> str:
@@ -138,7 +142,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
138142
if len(self.__required_authn_params) > 0:
139143
req_auth_services = set(l for l in self.__required_authn_params.keys())
140144
raise Exception(
141-
f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
145+
f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
142146
)
143147

144148
# validate inputs to this call using the signature
@@ -167,7 +171,7 @@ def add_auth_token_getters(
167171
auth_token_getters: Mapping[str, Callable[[], str]],
168172
) -> "ToolboxTool":
169173
"""
170-
Registers a auth token getter function that is used for AuthServices when tools
174+
Registers an auth token getter function that is used for AuthServices when tools
171175
are invoked.
172176
173177
Args:
@@ -204,25 +208,27 @@ def add_auth_token_getters(
204208

205209

206210
def filter_required_authn_params(
207-
req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str]
211+
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
208212
) -> dict[str, list[str]]:
209213
"""
210-
Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services.
214+
Utility function for reducing 'req_authn_params' to a subset of parameters that
215+
aren't supplied by at least one service in auth_services.
211216
212217
Args:
213218
req_authn_params: A mapping of parameter names to sets of required
214219
authentication services.
215-
auth_services: An iterable of authentication service names for which
220+
auth_service_names: An iterable of authentication service names for which
216221
token getters are available.
217222
218223
Returns:
219224
A new dictionary representing the subset of required authentication
220-
parameters that are not covered by the provided `auth_services`.
225+
parameters that are not covered by the provided `auth_service_names`.
221226
"""
222-
req_params = {}
227+
required_params = {} # params that are still required with provided auth_services
223228
for param, services in req_authn_params.items():
224-
# if we don't have a token_getter for any of the services required by the param, the param is still required
225-
required = not any(s in services for s in auth_services)
229+
# if we don't have a token_getter for any of the services required by the param,
230+
# the param is still required
231+
required = not any(s in services for s in auth_service_names)
226232
if required:
227-
req_params[param] = services
228-
return req_params
233+
required_params[param] = services
234+
return required_params

0 commit comments

Comments
 (0)