Skip to content

Commit 6626f69

Browse files
authored
Merge branch 'main' into twisha-auth-methods
2 parents 308764a + 5d49936 commit 6626f69

File tree

9 files changed

+1183
-114
lines changed

9 files changed

+1183
-114
lines changed

packages/toolbox-core/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ specific tool instance.
369369
toolbox = ToolboxClient("http://127.0.0.1:5000")
370370
tool = await toolbox.load_tool("my-tool")
371371

372-
bound_tool = tool.bind_parameters({"param": "value"})
372+
bound_tool = tool.bind_params({"param": "value"})
373373
```
374374

375375
### Option B: Binding Parameters While Loading Tools

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,11 @@ def __parse_tool(
7979
else: # regular parameter
8080
params.append(p)
8181

82-
authn_params = identify_required_authn_params(
83-
authn_params, auth_token_getters.keys()
82+
authn_params, _, used_auth_keys = identify_required_authn_params(
83+
# TODO: Add schema.authRequired as second arg
84+
authn_params,
85+
[],
86+
auth_token_getters.keys(),
8487
)
8588

8689
tool = ToolboxTool(
@@ -97,9 +100,6 @@ def __parse_tool(
97100
)
98101

99102
used_bound_keys = set(bound_params.keys())
100-
used_auth_keys: set[str] = set()
101-
for required_sources in authn_params.values():
102-
used_auth_keys.update(required_sources)
103103

104104
return tool, used_auth_keys, used_bound_keys
105105

@@ -160,6 +160,9 @@ async def load_tool(
160160
for execution. The specific arguments and behavior of the callable
161161
depend on the tool itself.
162162
163+
Raises:
164+
ValueError: If the loaded tool instance fails to utilize at least
165+
one provided parameter or auth token (if any provided).
163166
"""
164167
# Resolve client headers
165168
resolved_headers = {
@@ -176,56 +179,131 @@ async def load_tool(
176179
# parse the provided definition to a tool
177180
if name not in manifest.tools:
178181
# TODO: Better exception
179-
raise Exception(f"Tool '{name}' not found!")
180-
tool, _, _ = self.__parse_tool(
182+
raise ValueError(f"Tool '{name}' not found!")
183+
tool, used_auth_keys, used_bound_keys = self.__parse_tool(
181184
name,
182185
manifest.tools[name],
183186
auth_token_getters,
184187
bound_params,
185188
self.__client_headers,
186189
)
187190

191+
provided_auth_keys = set(auth_token_getters.keys())
192+
provided_bound_keys = set(bound_params.keys())
193+
194+
unused_auth = provided_auth_keys - used_auth_keys
195+
unused_bound = provided_bound_keys - used_bound_keys
196+
197+
if unused_auth or unused_bound:
198+
error_messages = []
199+
if unused_auth:
200+
error_messages.append(f"unused auth tokens: {', '.join(unused_auth)}")
201+
if unused_bound:
202+
error_messages.append(
203+
f"unused bound parameters: {', '.join(unused_bound)}"
204+
)
205+
raise ValueError(
206+
f"Validation failed for tool '{name}': { '; '.join(error_messages) }."
207+
)
208+
188209
return tool
189210

190211
async def load_toolset(
191212
self,
192213
name: Optional[str] = None,
193214
auth_token_getters: dict[str, Callable[[], str]] = {},
194215
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
216+
strict: bool = False,
195217
) -> list[ToolboxTool]:
196218
"""
197219
Asynchronously fetches a toolset and loads all tools defined within it.
198220
199221
Args:
200-
name: Name of the toolset to load tools.
222+
name: Name of the toolset to load. If None, loads the default toolset.
201223
auth_token_getters: A mapping of authentication service names to
202224
callables that return the corresponding authentication token.
203225
bound_params: A mapping of parameter names to bind to specific values or
204226
callables that are called to produce values as needed.
227+
strict: If True, raises an error if *any* loaded tool instance fails
228+
to utilize at least one provided parameter or auth token (if any
229+
provided). If False (default), raises an error only if a
230+
user-provided parameter or auth token cannot be applied to *any*
231+
loaded tool across the set.
205232
206233
Returns:
207234
list[ToolboxTool]: A list of callables, one for each tool defined
208235
in the toolset.
236+
237+
Raises:
238+
ValueError: If validation fails based on the `strict` flag.
209239
"""
240+
210241
# Resolve client headers
211242
original_headers = self.__client_headers
212243
resolved_headers = {
213244
header_name: await resolve_value(original_headers[header_name])
214245
for header_name in original_headers
215246
}
216-
# Request the definition of the tool from the server
247+
# Request the definition of the toolset from the server
217248
url = f"{self.__base_url}/api/toolset/{name or ''}"
218249
async with self.__session.get(url, headers=resolved_headers) as response:
219250
json = await response.json()
220251
manifest: ManifestSchema = ManifestSchema(**json)
221252

222-
# parse each tools name and schema into a list of ToolboxTools
223-
tools = [
224-
self.__parse_tool(
225-
n, s, auth_token_getters, bound_params, self.__client_headers
226-
)[0]
227-
for n, s in manifest.tools.items()
228-
]
253+
tools: list[ToolboxTool] = []
254+
overall_used_auth_keys: set[str] = set()
255+
overall_used_bound_params: set[str] = set()
256+
provided_auth_keys = set(auth_token_getters.keys())
257+
provided_bound_keys = set(bound_params.keys())
258+
259+
# parse each tool's name and schema into a list of ToolboxTools
260+
for tool_name, schema in manifest.tools.items():
261+
tool, used_auth_keys, used_bound_keys = self.__parse_tool(
262+
tool_name,
263+
schema,
264+
auth_token_getters,
265+
bound_params,
266+
self.__client_headers,
267+
)
268+
tools.append(tool)
269+
270+
if strict:
271+
unused_auth = provided_auth_keys - used_auth_keys
272+
unused_bound = provided_bound_keys - used_bound_keys
273+
if unused_auth or unused_bound:
274+
error_messages = []
275+
if unused_auth:
276+
error_messages.append(
277+
f"unused auth tokens: {', '.join(unused_auth)}"
278+
)
279+
if unused_bound:
280+
error_messages.append(
281+
f"unused bound parameters: {', '.join(unused_bound)}"
282+
)
283+
raise ValueError(
284+
f"Validation failed for tool '{tool_name}': { '; '.join(error_messages) }."
285+
)
286+
else:
287+
overall_used_auth_keys.update(used_auth_keys)
288+
overall_used_bound_params.update(used_bound_keys)
289+
290+
unused_auth = provided_auth_keys - overall_used_auth_keys
291+
unused_bound = provided_bound_keys - overall_used_bound_params
292+
293+
if unused_auth or unused_bound:
294+
error_messages = []
295+
if unused_auth:
296+
error_messages.append(
297+
f"unused auth tokens could not be applied to any tool: {', '.join(unused_auth)}"
298+
)
299+
if unused_bound:
300+
error_messages.append(
301+
f"unused bound parameters could not be applied to any tool: {', '.join(unused_bound)}"
302+
)
303+
raise ValueError(
304+
f"Validation failed for toolset '{name or 'default'}': { '; '.join(error_messages) }."
305+
)
306+
229307
return tools
230308

231309
async def add_headers(

packages/toolbox-core/src/toolbox_core/sync_tool.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,19 +153,37 @@ def add_auth_token_getters(
153153
new_async_tool = self.__async_tool.add_auth_token_getters(auth_token_getters)
154154
return ToolboxSyncTool(new_async_tool, self.__loop, self.__thread)
155155

156-
def bind_parameters(
156+
def bind_params(
157157
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
158158
) -> "ToolboxSyncTool":
159159
"""
160160
Binds parameters to values or callables that produce values.
161161
162-
Args:
163-
bound_params: A mapping of parameter names to values or callables that
164-
produce values.
162+
Args:
163+
bound_params: A mapping of parameter names to values or callables that
164+
produce values.
165165
166-
Returns:
167-
A new ToolboxSyncTool instance with the specified parameters bound.
166+
Returns:
167+
A new ToolboxSyncTool instance with the specified parameters bound.
168168
"""
169169

170-
new_async_tool = self.__async_tool.bind_parameters(bound_params)
170+
new_async_tool = self.__async_tool.bind_params(bound_params)
171171
return ToolboxSyncTool(new_async_tool, self.__loop, self.__thread)
172+
173+
def bind_param(
174+
self,
175+
param_name: str,
176+
param_value: Union[Callable[[], Any], Any],
177+
) -> "ToolboxSyncTool":
178+
"""
179+
Binds a parameter to the value or callables that produce it.
180+
181+
Args:
182+
param_name: The name of the bound parameter.
183+
param_value: The value of the bound parameter, or a callable that
184+
returns the value.
185+
186+
Returns:
187+
A new ToolboxSyncTool instance with the specified parameter bound.
188+
"""
189+
return self.bind_params({param_name: param_value})

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
# Validate conflicting Headers/Auth Tokens
9595
request_header_names = client_headers.keys()
9696
auth_token_names = [
97-
auth_token_name + "_token"
97+
self.__get_auth_header(auth_token_name)
9898
for auth_token_name in auth_service_token_getters.keys()
9999
]
100100
duplicates = request_header_names & auth_token_names
@@ -187,6 +187,10 @@ def __copy(
187187
client_headers=check(client_headers, self.__client_headers),
188188
)
189189

190+
def __get_auth_header(self, auth_token_name: str) -> str:
191+
"""Returns the formatted auth token header name."""
192+
return f"{auth_token_name}_token"
193+
190194
async def __call__(self, *args: Any, **kwargs: Any) -> str:
191195
"""
192196
Asynchronously calls the remote tool with the provided arguments.
@@ -208,7 +212,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
208212
req_auth_services = set()
209213
for s in self.__required_authn_params.values():
210214
req_auth_services.update(s)
211-
raise Exception(
215+
raise ValueError(
212216
f"One or more of the following authn services are required to invoke this tool"
213217
f": {','.join(req_auth_services)}"
214218
)
@@ -228,7 +232,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
228232
# create headers for auth services
229233
headers = {}
230234
for auth_service, token_getter in self.__auth_service_token_getters.items():
231-
headers[f"{auth_service}_token"] = await resolve_value(token_getter)
235+
headers[self.__get_auth_header(auth_service)] = await resolve_value(
236+
token_getter
237+
)
232238
for client_header_name, client_header_val in self.__client_headers.items():
233239
headers[client_header_name] = await resolve_value(client_header_val)
234240

@@ -276,7 +282,8 @@ def add_auth_token_getters(
276282
# Validate duplicates with client headers
277283
request_header_names = self.__client_headers.keys()
278284
auth_token_names = [
279-
auth_token_name + "_token" for auth_token_name in incoming_services
285+
self.__get_auth_header(auth_token_name)
286+
for auth_token_name in incoming_services
280287
]
281288
duplicates = request_header_names & auth_token_names
282289
if duplicates:
@@ -292,32 +299,42 @@ def add_auth_token_getters(
292299
# create a read-only updated for params that are still required
293300
new_req_authn_params = MappingProxyType(
294301
identify_required_authn_params(
295-
self.__required_authn_params, auth_token_getters.keys()
296-
)
302+
# TODO: Add authRequired
303+
self.__required_authn_params,
304+
[],
305+
auth_token_getters.keys(),
306+
)[0]
297307
)
298308

299309
return self.__copy(
300310
auth_service_token_getters=new_getters,
301311
required_authn_params=new_req_authn_params,
302312
)
303313

304-
def bind_parameters(
314+
def bind_params(
305315
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
306316
) -> "ToolboxTool":
307317
"""
308318
Binds parameters to values or callables that produce values.
309319
310-
Args:
311-
bound_params: A mapping of parameter names to values or callables that
312-
produce values.
320+
Args:
321+
bound_params: A mapping of parameter names to values or callables that
322+
produce values.
313323
314-
Returns:
315-
A new ToolboxTool instance with the specified parameters bound.
324+
Returns:
325+
A new ToolboxTool instance with the specified parameters bound.
316326
"""
317327
param_names = set(p.name for p in self.__params)
318328
for name in bound_params.keys():
329+
if name in self.__bound_parameters:
330+
raise ValueError(
331+
f"cannot re-bind parameter: parameter '{name}' is already bound"
332+
)
333+
319334
if name not in param_names:
320-
raise Exception(f"unable to bind parameters: no parameter named {name}")
335+
raise ValueError(
336+
f"unable to bind parameters: no parameter named {name}"
337+
)
321338

322339
new_params = []
323340
for p in self.__params:
@@ -330,3 +347,21 @@ def bind_parameters(
330347
params=new_params,
331348
bound_params=MappingProxyType(all_bound_params),
332349
)
350+
351+
def bind_param(
352+
self,
353+
param_name: str,
354+
param_value: Union[Callable[[], Any], Any],
355+
) -> "ToolboxTool":
356+
"""
357+
Binds a parameter to the value or callables that produce it.
358+
359+
Args:
360+
param_name: The name of the bound parameter.
361+
param_value: The value of the bound parameter, or a callable that
362+
returns the value.
363+
364+
Returns:
365+
A new ToolboxTool instance with the specified parameters bound.
366+
"""
367+
return self.bind_params({param_name: param_value})

0 commit comments

Comments
 (0)