Skip to content

Commit 744ade9

Browse files
committed
feat: add support for bound parameters
1 parent 61d32aa commit 744ade9

File tree

2 files changed

+73
-11
lines changed

2 files changed

+73
-11
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,22 @@ def __parse_tool(
5858
name: str,
5959
schema: ToolSchema,
6060
auth_token_getters: dict[str, Callable[[], str]],
61+
all_bound_params: dict[str, Callable[[], str]],
6162
) -> ToolboxTool:
6263
"""Internal helper to create a callable tool from its schema."""
63-
# sort into authenticated and reg params
64+
# sort into reg, authn, and bound params
6465
params = []
6566
authn_params: dict[str, list[str]] = {}
67+
bound_params: dict[str, Callable[[], str]] = {}
6668
auth_sources: set[str] = set()
6769
for p in schema.parameters:
68-
if not p.authSources:
69-
params.append(p)
70-
else:
70+
if p.authSources: # authn parameter
7171
authn_params[p.name] = p.authSources
7272
auth_sources.update(p.authSources)
73+
elif p.name in all_bound_params: # bound parameter
74+
bound_params[p.name] = all_bound_params[p.name]
75+
else: # regular parameter
76+
params.append(p)
7377

7478
authn_params = filter_required_authn_params(authn_params, auth_sources)
7579

@@ -80,7 +84,8 @@ def __parse_tool(
8084
desc=schema.description,
8185
params=[p.to_param() for p in params],
8286
required_authn_params=types.MappingProxyType(authn_params),
83-
auth_service_token_getters=auth_token_getters,
87+
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
88+
bound_params=types.MappingProxyType(bound_params),
8489
)
8590
return tool
8691

@@ -120,6 +125,7 @@ async def load_tool(
120125
self,
121126
name: str,
122127
auth_token_getters: dict[str, Callable[[], str]] = {},
128+
bound_params: dict[str, Callable[[], str]] = {},
123129
) -> ToolboxTool:
124130
"""
125131
Asynchronously loads a tool from the server.
@@ -150,14 +156,17 @@ async def load_tool(
150156
if name not in manifest.tools:
151157
# TODO: Better exception
152158
raise Exception(f"Tool '{name}' not found!")
153-
tool = self.__parse_tool(name, manifest.tools[name], auth_token_getters)
159+
tool = self.__parse_tool(
160+
name, manifest.tools[name], auth_token_getters, bound_params
161+
)
154162

155163
return tool
156164

157165
async def load_toolset(
158166
self,
159167
name: str,
160168
auth_token_getters: dict[str, Callable[[], str]] = {},
169+
bound_params: dict[str, Callable[[], str]] = {},
161170
) -> list[ToolboxTool]:
162171
"""
163172
Asynchronously fetches a toolset and loads all tools defined within it.
@@ -168,6 +177,7 @@ async def load_toolset(
168177
callables that return the corresponding authentication token.
169178
170179
180+
171181
Returns:
172182
list[ToolboxTool]: A list of callables, one for each tool defined
173183
in the toolset.
@@ -180,7 +190,7 @@ async def load_toolset(
180190

181191
# parse each tools name and schema into a list of ToolboxTools
182192
tools = [
183-
self.__parse_tool(n, s, auth_token_getters)
193+
self.__parse_tool(n, s, auth_token_getters, bound_params)
184194
for n, s in manifest.tools.items()
185195
]
186196
return tools

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,20 @@
1313
# limitations under the License.
1414

1515

16+
import asyncio
1617
import types
1718
from collections import defaultdict
1819
from inspect import Parameter, Signature
19-
from typing import Any, Callable, DefaultDict, Iterable, Mapping, Optional, Sequence
20+
from typing import (
21+
Any,
22+
Callable,
23+
DefaultDict,
24+
Iterable,
25+
Mapping,
26+
Optional,
27+
Sequence,
28+
Union,
29+
)
2030

2131
from aiohttp import ClientSession
2232
from pytest import Session
@@ -44,6 +54,7 @@ def __init__(
4454
params: Sequence[Parameter],
4555
required_authn_params: Mapping[str, list[str]],
4656
auth_service_token_getters: Mapping[str, Callable[[], str]],
57+
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
4758
):
4859
"""
4960
Initializes a callable that will trigger the tool invocation through the
@@ -81,6 +92,8 @@ def __init__(
8192
self.__required_authn_params = required_authn_params
8293
# map of authService -> token_getter
8394
self.__auth_service_token_getters = auth_service_token_getters
95+
# map of parameter name to value or Callable
96+
self.__bound_parameters = bound_params
8497

8598
def __copy(
8699
self,
@@ -91,6 +104,7 @@ def __copy(
91104
params: Optional[list[Parameter]] = None,
92105
required_authn_params: Optional[Mapping[str, list[str]]] = None,
93106
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
107+
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
94108
) -> "ToolboxTool":
95109
"""
96110
Creates a copy of the ToolboxTool, overriding specific fields.
@@ -117,6 +131,7 @@ def __copy(
117131
required_authn_params=required_authn_params or self.__required_authn_params,
118132
auth_service_token_getters=auth_service_token_getters
119133
or self.__auth_service_token_getters,
134+
bound_params=bound_params or self.__bound_parameters,
120135
)
121136

122137
async def __call__(self, *args: Any, **kwargs: Any) -> str:
@@ -146,6 +161,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
146161
all_args.apply_defaults() # Include default values if not provided
147162
payload = all_args.arguments
148163

164+
# apply bounded parameters
165+
for param, value in self.__bound_parameters.items():
166+
if asyncio.iscoroutinefunction(value):
167+
value = await value()
168+
elif callable(value):
169+
value = value()
170+
payload[param] = value
171+
149172
# create headers for auth services
150173
headers = {}
151174
for auth_service, token_getter in self.__auth_service_token_getters.items():
@@ -202,12 +225,41 @@ def add_auth_token_getters(
202225
required_authn_params=new_req_authn_params,
203226
)
204227

228+
def bind_parameters(
229+
self, bound_params: Mapping[str, Callable[[], str]]
230+
) -> "ToolboxTool":
231+
"""
232+
Binds parameters to values or callables that produce values.
233+
234+
Args:
235+
bound_params: A mapping of parameter names to values or callables that
236+
produce values.
237+
238+
Returns:
239+
A new ToolboxTool instance with the specified parameters bound.
240+
"""
241+
all_params = set(p.name for p in self.__params)
242+
for name in bound_params.keys():
243+
if name not in all_params:
244+
raise Exception(f"unable to bind parameters: no parameter named {name}")
245+
246+
new_params = []
247+
for p in self.__params:
248+
if p.name not in bound_params:
249+
new_params.append(p)
250+
251+
return self.__copy(
252+
params=new_params,
253+
bound_params=bound_params,
254+
)
255+
205256

206257
def filter_required_authn_params(
207258
req_authn_params: Mapping[str, list[str]], auth_services: Iterable[str]
208259
) -> dict[str, list[str]]:
209260
"""
210-
Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services.
261+
Utility function for reducing 'req_authn_params' to a subset of parameters that
262+
aren't supplied by a least one service in auth_services.
211263
212264
Args:
213265
req_authn_params: A mapping of parameter names to sets of required
@@ -216,8 +268,8 @@ def filter_required_authn_params(
216268
token getters are available.
217269
218270
Returns:
219-
A new dictionary representing the subset of required authentication
220-
parameters that are not covered by the provided `auth_services`.
271+
A new dictionary representing the subset of required authentication parameters
272+
that are not covered by the provided `auth_services`.
221273
"""
222274
req_params = {}
223275
for param, services in req_authn_params.items():

0 commit comments

Comments
 (0)