Skip to content

Commit 4d562bd

Browse files
committed
revert package files
1 parent 6dca7f9 commit 4d562bd

File tree

2 files changed

+70
-45
lines changed

2 files changed

+70
-45
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
1415
import types
15-
from typing import Any, Callable, Optional
16+
from typing import Any, Callable, Mapping, Optional, Union
1617

1718
from aiohttp import ClientSession
1819

1920
from .protocol import ManifestSchema, ToolSchema
20-
from .tool import ToolboxTool, filter_required_authn_params
21+
from .tool import ToolboxTool, identify_required_authn_params
2122

2223

2324
class ToolboxClient:
@@ -58,7 +59,7 @@ def __parse_tool(
5859
name: str,
5960
schema: ToolSchema,
6061
auth_token_getters: dict[str, Callable[[], str]],
61-
all_bound_params: dict[str, Callable[[], str]],
62+
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
6263
) -> ToolboxTool:
6364
"""Internal helper to create a callable tool from its schema."""
6465
# sort into reg, authn, and bound params
@@ -75,14 +76,17 @@ def __parse_tool(
7576
else: # regular parameter
7677
params.append(p)
7778

78-
authn_params = filter_required_authn_params(authn_params, auth_sources)
79+
authn_params = identify_required_authn_params(
80+
authn_params, auth_token_getters.keys()
81+
)
7982

8083
tool = ToolboxTool(
8184
session=self.__session,
8285
base_url=self.__base_url,
8386
name=name,
8487
desc=schema.description,
8588
params=[p.to_param() for p in params],
89+
# create a read-only values for the maps to prevent mutation
8690
required_authn_params=types.MappingProxyType(authn_params),
8791
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
8892
bound_params=types.MappingProxyType(bound_params),
@@ -125,7 +129,7 @@ async def load_tool(
125129
self,
126130
name: str,
127131
auth_token_getters: dict[str, Callable[[], str]] = {},
128-
bound_params: dict[str, Callable[[], str]] = {},
132+
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
129133
) -> ToolboxTool:
130134
"""
131135
Asynchronously loads a tool from the server.
@@ -138,6 +142,10 @@ async def load_tool(
138142
name: The unique name or identifier of the tool to load.
139143
auth_token_getters: A mapping of authentication service names to
140144
callables that return the corresponding authentication token.
145+
bound_params: A mapping of parameter names to bind to specific values or
146+
callables that are called to produce values as needed.
147+
148+
141149
142150
Returns:
143151
ToolboxTool: A callable object representing the loaded tool, ready
@@ -166,7 +174,7 @@ async def load_toolset(
166174
self,
167175
name: str,
168176
auth_token_getters: dict[str, Callable[[], str]] = {},
169-
bound_params: dict[str, Callable[[], str]] = {},
177+
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
170178
) -> list[ToolboxTool]:
171179
"""
172180
Asynchronously fetches a toolset and loads all tools defined within it.
@@ -175,6 +183,8 @@ async def load_toolset(
175183
name: Name of the toolset to load tools.
176184
auth_token_getters: A mapping of authentication service names to
177185
callables that return the corresponding authentication token.
186+
bound_params: A mapping of parameter names to bind to specific values or
187+
callables that are called to produce values as needed.
178188
179189
180190
@@ -193,4 +203,4 @@ async def load_toolset(
193203
self.__parse_tool(n, s, auth_token_getters, bound_params)
194204
for n, s in manifest.tools.items()
195205
]
196-
return tools
206+
return tools

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
params: Sequence[Parameter],
5555
required_authn_params: Mapping[str, list[str]],
5656
auth_service_token_getters: Mapping[str, Callable[[], str]],
57-
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
57+
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
5858
):
5959
"""
6060
Initializes a callable that will trigger the tool invocation through the
@@ -67,10 +67,13 @@ def __init__(
6767
desc: The description of the remote tool (used as its docstring).
6868
params: A list of `inspect.Parameter` objects defining the tool's
6969
arguments and their types/defaults.
70-
required_authn_params: A dict of required authenticated parameters that
71-
need a auth_service_token_getter set for them yet.
72-
auth_service_tokens: A dict of authService -> token (or callables that
70+
required_authn_params: A dict of required authenticated parameters to a list
71+
of services that provide values for them.
72+
auth_service_token_getters: A dict of authService -> token (or callables that
7373
produce a token)
74+
bound_params: A mapping of parameter names to bind to specific values or
75+
callables that are called to produce values as needed.
76+
7477
"""
7578

7679
# used to invoke the toolbox API
@@ -92,7 +95,7 @@ def __init__(
9295
self.__required_authn_params = required_authn_params
9396
# map of authService -> token_getter
9497
self.__auth_service_token_getters = auth_service_token_getters
95-
# map of parameter name to value or Callable
98+
# map of parameter name to value (or callable that produces that value)
9699
self.__bound_parameters = bound_params
97100

98101
def __copy(
@@ -120,18 +123,24 @@ def __copy(
120123
a auth_service_token_getter set for them yet.
121124
auth_service_token_getters: A dict of authService -> token (or callables
122125
that produce a token)
126+
bound_params: A mapping of parameter names to bind to specific values or
127+
callables that are called to produce values as needed.
123128
124129
"""
130+
check = lambda val, default: val if val is not None else default
125131
return ToolboxTool(
126-
session=session or self.__session,
127-
base_url=base_url or self.__base_url,
128-
name=name or self.__name__,
129-
desc=desc or self.__desc,
130-
params=params or self.__params,
131-
required_authn_params=required_authn_params or self.__required_authn_params,
132-
auth_service_token_getters=auth_service_token_getters
133-
or self.__auth_service_token_getters,
134-
bound_params=bound_params or self.__bound_parameters,
132+
session=check(session, self.__session),
133+
base_url=check(base_url, self.__base_url),
134+
name=check(name, self.__name__),
135+
desc=check(desc, self.__desc),
136+
params=check(params, self.__params),
137+
required_authn_params=check(
138+
required_authn_params, self.__required_authn_params
139+
),
140+
auth_service_token_getters=check(
141+
auth_service_token_getters, self.__auth_service_token_getters
142+
),
143+
bound_params=check(bound_params, self.__bound_parameters),
135144
)
136145

137146
async def __call__(self, *args: Any, **kwargs: Any) -> str:
@@ -151,9 +160,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
151160

152161
# check if any auth services need to be specified yet
153162
if len(self.__required_authn_params) > 0:
154-
req_auth_services = set(l for l in self.__required_authn_params.keys())
163+
# Gather all the required auth services into a set
164+
req_auth_services = set()
165+
for s in self.__required_authn_params.values():
166+
req_auth_services.update(s)
155167
raise Exception(
156-
f"One of more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
168+
f"One or more of the following authn services are required to invoke this tool: {','.join(req_auth_services)}"
157169
)
158170

159171
# validate inputs to this call using the signature
@@ -190,7 +202,7 @@ def add_auth_token_getters(
190202
auth_token_getters: Mapping[str, Callable[[], str]],
191203
) -> "ToolboxTool":
192204
"""
193-
Registers a auth token getter function that is used for AuthServices when tools
205+
Registers an auth token getter function that is used for AuthServices when tools
194206
are invoked.
195207
196208
Args:
@@ -203,10 +215,12 @@ def add_auth_token_getters(
203215
"""
204216

205217
# throw an error if the authentication source is already registered
206-
dupes = auth_token_getters.keys() & self.__auth_service_token_getters.keys()
207-
if dupes:
218+
existing_services = self.__auth_service_token_getters.keys()
219+
incoming_services = auth_token_getters.keys()
220+
duplicates = existing_services & incoming_services
221+
if duplicates:
208222
raise ValueError(
209-
f"Authentication source(s) `{', '.join(dupes)}` already registered in tool `{self.__name__}`."
223+
f"Authentication source(s) `{', '.join(duplicates)}` already registered in tool `{self.__name__}`."
210224
)
211225

212226
# create a read-only updated value for new_getters
@@ -215,7 +229,7 @@ def add_auth_token_getters(
215229
)
216230
# create a read-only updated for params that are still required
217231
new_req_authn_params = types.MappingProxyType(
218-
filter_required_authn_params(
232+
identify_required_authn_params(
219233
self.__required_authn_params, auth_token_getters.keys()
220234
)
221235
)
@@ -226,7 +240,7 @@ def add_auth_token_getters(
226240
)
227241

228242
def bind_parameters(
229-
self, bound_params: Mapping[str, Callable[[], str]]
243+
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
230244
) -> "ToolboxTool":
231245
"""
232246
Binds parameters to values or callables that produce values.
@@ -238,9 +252,9 @@ def bind_parameters(
238252
Returns:
239253
A new ToolboxTool instance with the specified parameters bound.
240254
"""
241-
all_params = set(p.name for p in self.__params)
255+
param_names = set(p.name for p in self.__params)
242256
for name in bound_params.keys():
243-
if name not in all_params:
257+
if name not in param_names:
244258
raise Exception(f"unable to bind parameters: no parameter named {name}")
245259

246260
new_params = []
@@ -254,27 +268,28 @@ def bind_parameters(
254268
)
255269

256270

257-
def filter_required_authn_params(
258-
req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str]
271+
def identify_required_authn_params(
272+
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
259273
) -> dict[str, list[str]]:
260274
"""
261-
Utility function for reducing 'req_authn_params' to a subset of parameters that
262-
aren't supplied by a least one service in auth_services.
275+
Identifies authentication parameters that are still required; because they
276+
not covered by the provided `auth_service_names`.
263277
264-
Args:
265-
req_authn_params: A mapping of parameter names to sets of required
266-
authentication services.
267-
auth_services: An iterable of authentication service names for which
268-
token getters are available.
278+
Args:
279+
req_authn_params: A mapping of parameter names to sets of required
280+
authentication services.
281+
auth_service_names: An iterable of authentication service names for which
282+
token getters are available.
269283
270284
Returns:
271285
A new dictionary representing the subset of required authentication parameters
272286
that are not covered by the provided `auth_services`.
273287
"""
274-
req_params = {}
288+
required_params = {} # params that are still required with provided auth_services
275289
for param, services in req_authn_params.items():
276-
# if we don't have a token_getter for any of the services required by the param, the param is still required
277-
required = not any(s in services for s in auth_services)
290+
# if we don't have a token_getter for any of the services required by the param,
291+
# the param is still required
292+
required = not any(s in services for s in auth_service_names)
278293
if required:
279-
req_params[param] = services
280-
return req_params
294+
required_params[param] = services
295+
return required_params

0 commit comments

Comments
 (0)