Skip to content

Commit 79782ee

Browse files
committed
chore: address feedback
1 parent de4f55c commit 79782ee

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import re
1515
import types
16-
from typing import Any, Callable, Optional
16+
from typing import Any, Callable, Mapping, Optional, Union
1717

1818
from aiohttp import ClientSession
1919

@@ -59,7 +59,7 @@ def __parse_tool(
5959
name: str,
6060
schema: ToolSchema,
6161
auth_token_getters: dict[str, Callable[[], str]],
62-
all_bound_params: dict[str, Callable[[], str]],
62+
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
6363
) -> ToolboxTool:
6464
"""Internal helper to create a callable tool from its schema."""
6565
# sort into reg, authn, and bound params
@@ -129,7 +129,7 @@ async def load_tool(
129129
self,
130130
name: str,
131131
auth_token_getters: dict[str, Callable[[], str]] = {},
132-
bound_params: dict[str, Callable[[], str]] = {},
132+
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
133133
) -> ToolboxTool:
134134
"""
135135
Asynchronously loads a tool from the server.
@@ -142,6 +142,10 @@ async def load_tool(
142142
name: The unique name or identifier of the tool to load.
143143
auth_token_getters: A mapping of authentication service names to
144144
callables that return the corresponding authentication token.
145+
bound_params: A mapping of parameter names to bind to specific values or
146+
callables that are called to produce values as needed.
147+
148+
145149
146150
Returns:
147151
ToolboxTool: A callable object representing the loaded tool, ready
@@ -170,7 +174,7 @@ async def load_toolset(
170174
self,
171175
name: str,
172176
auth_token_getters: dict[str, Callable[[], str]] = {},
173-
bound_params: dict[str, Callable[[], str]] = {},
177+
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
174178
) -> list[ToolboxTool]:
175179
"""
176180
Asynchronously fetches a toolset and loads all tools defined within it.
@@ -179,6 +183,8 @@ async def load_toolset(
179183
name: Name of the toolset to load tools.
180184
auth_token_getters: A mapping of authentication service names to
181185
callables that return the corresponding authentication token.
186+
bound_params: A mapping of parameter names to bind to specific values or
187+
callables that are called to produce values as needed.
182188
183189
184190

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
params: Sequence[Parameter],
5555
required_authn_params: Mapping[str, list[str]],
5656
auth_service_token_getters: Mapping[str, Callable[[], str]],
57-
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
57+
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
5858
):
5959
"""
6060
Initializes a callable that will trigger the tool invocation through the
@@ -71,6 +71,9 @@ def __init__(
7171
of services that provide values for them.
7272
auth_service_token_getters: A dict of authService -> token (or callables that
7373
produce a token)
74+
bound_params: A mapping of parameter names to bind to specific values or
75+
callables that are called to produce values as needed.
76+
7477
"""
7578

7679
# used to invoke the toolbox API
@@ -92,7 +95,7 @@ def __init__(
9295
self.__required_authn_params = required_authn_params
9396
# map of authService -> token_getter
9497
self.__auth_service_token_getters = auth_service_token_getters
95-
# map of parameter name to value or Callable
98+
# map of parameter name to value (or callable that produces that value)
9699
self.__bound_parameters = bound_params
97100

98101
def __copy(
@@ -120,6 +123,8 @@ def __copy(
120123
a auth_service_token_getter set for them yet.
121124
auth_service_token_getters: A dict of authService -> token (or callables
122125
that produce a token)
126+
bound_params: A mapping of parameter names to bind to specific values or
127+
callables that are called to produce values as needed.
123128
124129
"""
125130
check = lambda val, default: val if val is not None else default
@@ -235,7 +240,7 @@ def add_auth_token_getters(
235240
)
236241

237242
def bind_parameters(
238-
self, bound_params: Mapping[str, Callable[[], str]]
243+
self, bound_params: Mapping[str, Union[Callable[[], Any], Any]]
239244
) -> "ToolboxTool":
240245
"""
241246
Binds parameters to values or callables that produce values.
@@ -247,9 +252,9 @@ def bind_parameters(
247252
Returns:
248253
A new ToolboxTool instance with the specified parameters bound.
249254
"""
250-
all_params = set(p.name for p in self.__params)
255+
param_names = set(p.name for p in self.__params)
251256
for name in bound_params.keys():
252-
if name not in all_params:
257+
if name not in param_names:
253258
raise Exception(f"unable to bind parameters: no parameter named {name}")
254259

255260
new_params = []
@@ -270,11 +275,11 @@ def identify_required_authn_params(
270275
Identifies authentication parameters that are still required; because they
271276
not covered by the provided `auth_service_names`.
272277
273-
Args:
274-
req_authn_params: A mapping of parameter names to sets of required
275-
authentication services.
276-
auth_service_names: An iterable of authentication service names for which
277-
token getters are available.
278+
Args:
279+
req_authn_params: A mapping of parameter names to sets of required
280+
authentication services.
281+
auth_service_names: An iterable of authentication service names for which
282+
token getters are available.
278283
279284
Returns:
280285
A new dictionary representing the subset of required authentication parameters

0 commit comments

Comments
 (0)