Skip to content

Commit c8491a9

Browse files
committed
feat: add support for bound parameters
1 parent bcba462 commit c8491a9

File tree

2 files changed

+72
-11
lines changed

2 files changed

+72
-11
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,22 @@ def __parse_tool(
5959
name: str,
6060
schema: ToolSchema,
6161
auth_token_getters: dict[str, Callable[[], str]],
62+
all_bound_params: dict[str, Callable[[], str]],
6263
) -> ToolboxTool:
6364
"""Internal helper to create a callable tool from its schema."""
64-
# sort into authenticated and reg params
65+
# sort into reg, authn, and bound params
6566
params = []
6667
authn_params: dict[str, list[str]] = {}
68+
bound_params: dict[str, Callable[[], str]] = {}
6769
auth_sources: set[str] = set()
6870
for p in schema.parameters:
69-
if not p.authSources:
70-
params.append(p)
71-
else:
71+
if p.authSources: # authn parameter
7272
authn_params[p.name] = p.authSources
7373
auth_sources.update(p.authSources)
74+
elif p.name in all_bound_params: # bound parameter
75+
bound_params[p.name] = all_bound_params[p.name]
76+
else: # regular parameter
77+
params.append(p)
7478

7579
authn_params = identify_required_authn_params(
7680
authn_params, auth_token_getters.keys()
@@ -85,6 +89,7 @@ def __parse_tool(
8589
# create a read-only values for the maps to prevent mutation
8690
required_authn_params=types.MappingProxyType(authn_params),
8791
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
92+
bound_params=types.MappingProxyType(bound_params),
8893
)
8994
return tool
9095

@@ -124,6 +129,7 @@ async def load_tool(
124129
self,
125130
name: str,
126131
auth_token_getters: dict[str, Callable[[], str]] = {},
132+
bound_params: dict[str, Callable[[], str]] = {},
127133
) -> ToolboxTool:
128134
"""
129135
Asynchronously loads a tool from the server.
@@ -154,14 +160,17 @@ async def load_tool(
154160
if name not in manifest.tools:
155161
# TODO: Better exception
156162
raise Exception(f"Tool '{name}' not found!")
157-
tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters)
163+
tool = self.__parse_tool(
164+
name, manifest.tools[name], auth_token_getters, bound_params
165+
)
158166

159167
return tool
160168

161169
async def load_toolset(
162170
self,
163171
name: str,
164172
auth_token_getters: dict[str, Callable[[], str]] = {},
173+
bound_params: dict[str, Callable[[], str]] = {},
165174
) -> list[ToolboxTool]:
166175
"""
167176
Asynchronously fetches a toolset and loads all tools defined within it.
@@ -172,6 +181,7 @@ async def load_toolset(
172181
callables that return the corresponding authentication token.
173182
174183
184+
175185
Returns:
176186
list[ToolboxTool]: A list of callables, one for each tool defined
177187
in the toolset.
@@ -184,7 +194,7 @@ async def load_toolset(
184194

185195
# parse each tools name and schema into a list of ToolboxTools
186196
tools = [
187-
self.__parse_tool(n, s, auth_token_getters)
197+
self.__parse_tool(n, s, auth_token_getters, bound_params)
188198
for n, s in manifest.tools.items()
189199
]
190200
return tools

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,20 @@
1313
# limitations under the License.
1414

1515

16+
import asyncio
1617
import types
1718
from collections import defaultdict
1819
from inspect import Parameter, Signature
19-
from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence
20+
from typing import (
21+
Any,
22+
Callable,
23+
DefaultDict,
24+
Iterable,
25+
Mapping,
26+
Optional,
27+
Sequence,
28+
Union,
29+
)
2030

2131
from aiohttp import ClientSession
2232
from pytest import Session
@@ -44,6 +54,7 @@ def __init__(
4454
params: Sequence[Parameter],
4555
required_authn_params: Mapping[str, list[str]],
4656
auth_service_token_getters: Mapping[str, Callable[[], str]],
57+
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
4758
):
4859
"""
4960
Initializes a callable that will trigger the tool invocation through the
@@ -81,6 +92,8 @@ def __init__(
8192
self.__required_authn_params = required_authn_params
8293
# map of authService -> token_getter
8394
self.__auth_service_token_getters = auth_service_token_getters
95+
# map of parameter name to value or Callable
96+
self.__bound_parameters = bound_params
8497

8598
def __copy(
8699
self,
@@ -91,6 +104,7 @@ def __copy(
91104
params: Optional[list[Parameter]] = None,
92105
required_authn_params: Optional[Mapping[str, list[str]]] = None,
93106
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
107+
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
94108
) -> "ToolboxTool":
95109
"""
96110
Creates a copy of the ToolboxTool, overriding specific fields.
@@ -121,6 +135,7 @@ def __copy(
121135
auth_service_token_getters=check(
122136
auth_service_token_getters, self.__auth_service_token_getters
123137
),
138+
bound_params=check(bound_params, self.__bound_parameters),
124139
)
125140

126141
async def __call__(self, *args: Any, **kwargs: Any) -> str:
@@ -153,6 +168,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
153168
all_args.apply_defaults() # Include default values if not provided
154169
payload = all_args.arguments
155170

171+
# apply bounded parameters
172+
for param, value in self.__bound_parameters.items():
173+
if asyncio.iscoroutinefunction(value):
174+
value = await value()
175+
elif callable(value):
176+
value = value()
177+
payload[param] = value
178+
156179
# create headers for auth services
157180
headers = {}
158181
for auth_service, token_getter in self.__auth_service_token_getters.items():
@@ -211,13 +234,41 @@ def add_auth_token_getters(
211234
required_authn_params=new_req_authn_params,
212235
)
213236

237+
def bind_parameters(
238+
self, bound_params: Mapping[str, Callable[[], str]]
239+
) -> "ToolboxTool":
240+
"""
241+
Binds parameters to values or callables that produce values.
242+
243+
Args:
244+
bound_params: A mapping of parameter names to values or callables that
245+
produce values.
246+
247+
Returns:
248+
A new ToolboxTool instance with the specified parameters bound.
249+
"""
250+
all_params = set(p.name for p in self.__params)
251+
for name in bound_params.keys():
252+
if name not in all_params:
253+
raise Exception(f"unable to bind parameters: no parameter named {name}")
254+
255+
new_params = []
256+
for p in self.__params:
257+
if p.name not in bound_params:
258+
new_params.append(p)
259+
260+
return self.__copy(
261+
params=new_params,
262+
bound_params=bound_params,
263+
)
264+
214265

215266
def identify_required_authn_params(
216267
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
217268
) -> dict[str, list[str]]:
218269
"""
219-
Identifies authentication parameters that are still required; or not covered by
220-
the provided `auth_service_names`.
270+
Identifies authentication parameters that are still required; because they
271+
not covered by the provided `auth_service_names`.
221272
222273
Args:
223274
req_authn_params: A mapping of parameter names to sets of required
@@ -226,8 +277,8 @@ def identify_required_authn_params(
226277
token getters are available.
227278
228279
Returns:
229-
A new dictionary representing the subset of required authentication
230-
parameters that are not covered by the provided `auth_service_names`.
280+
A new dictionary representing the subset of required authentication parameters
281+
that are not covered by the provided `auth_services`.
231282
"""
232283
required_params = {} # params that are still required with provided auth_services
233284
for param, services in req_authn_params.items():

0 commit comments

Comments
 (0)