Skip to content

Commit d9ec753

Browse files
committed
feat: Wrap toolbox-langchain's AsyncToolboxTool over toolbox-core's ToolboxTool.
This simplifies code and increases maintainability while removing duplicate code.
1 parent afebaa6 commit d9ec753

File tree

4 files changed

+428
-501
lines changed

4 files changed

+428
-501
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
# limitations under the License.
1414

1515

16-
import types
17-
from typing import Any, Callable, Mapping, Optional, Union
16+
from typing import Any, Callable, Mapping, Optional, Sequence, Union
1817

1918
from aiohttp import ClientSession
2019

21-
from .protocol import ManifestSchema, ToolSchema
22-
from .tool import ToolboxTool, identify_required_authn_params
20+
from .protocol import ManifestSchema, ParameterSchema, ToolSchema
21+
from .tool import ToolboxTool
2322

2423

2524
class ToolboxClient:
@@ -59,24 +58,26 @@ def __parse_tool(
5958
self,
6059
name: str,
6160
schema: ToolSchema,
62-
auth_token_getters: dict[str, Callable[[], str]],
61+
auth_token_getters: Mapping[str, Callable[[], str]],
6362
all_bound_params: Mapping[str, Union[Callable[[], Any], Any]],
63+
strict: bool,
6464
) -> ToolboxTool:
65-
"""Internal helper to create a callable tool from its schema."""
66-
# sort into reg, authn, and bound params
67-
params = []
68-
authn_params: dict[str, list[str]] = {}
69-
bound_params: dict[str, Callable[[], str]] = {}
70-
for p in schema.parameters:
71-
if p.authSources: # authn parameter
72-
authn_params[p.name] = p.authSources
73-
elif p.name in all_bound_params: # bound parameter
74-
bound_params[p.name] = all_bound_params[p.name]
75-
else: # regular parameter
76-
params.append(p)
77-
78-
authn_params = identify_required_authn_params(
79-
authn_params, auth_token_getters.keys()
65+
"""
66+
Internal helper to create a callable ToolboxTool from its schema.
67+
68+
Args:
69+
name: The name of the tool.
70+
schema: The ToolSchema defining the tool.
71+
auth_token_getters: Mapping of auth service names to token getters.
72+
all_bound_params: Mapping of all initially bound parameter names to values/callables.
73+
strict: The strictness setting for the created ToolboxTool instance.
74+
75+
Returns:
76+
An initialized ToolboxTool instance.
77+
"""
78+
79+
params: Sequence[ParameterSchema] = (
80+
schema.parameters if schema.parameters is not None else []
8081
)
8182

8283
tool = ToolboxTool(
@@ -85,10 +86,9 @@ def __parse_tool(
8586
name=name,
8687
description=schema.description,
8788
params=params,
88-
# create a read-only values for the maps to prevent mutation
89-
required_authn_params=types.MappingProxyType(authn_params),
90-
auth_service_token_getters=types.MappingProxyType(auth_token_getters),
91-
bound_params=types.MappingProxyType(bound_params),
89+
auth_service_token_getters=auth_token_getters,
90+
bound_params=all_bound_params,
91+
strict=strict,
9292
)
9393
return tool
9494

@@ -127,8 +127,9 @@ async def close(self):
127127
async def load_tool(
128128
self,
129129
name: str,
130-
auth_token_getters: dict[str, Callable[[], str]] = {},
130+
auth_token_getters: Mapping[str, Callable[[], str]] = {},
131131
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
132+
strict: bool = True,
132133
) -> ToolboxTool:
133134
"""
134135
Asynchronously loads a tool from the server.
@@ -143,8 +144,8 @@ async def load_tool(
143144
callables that return the corresponding authentication token.
144145
bound_params: A mapping of parameter names to bind to specific values or
145146
callables that are called to produce values as needed.
146-
147-
147+
strict: If True (default), the loaded tool instance will operate in
148+
strict validation mode. If False, it will be non-strict.
148149
149150
Returns:
150151
ToolboxTool: A callable object representing the loaded tool, ready
@@ -161,35 +162,38 @@ async def load_tool(
161162

162163
# parse the provided definition to a tool
163164
if name not in manifest.tools:
164-
# TODO: Better exception
165-
raise Exception(f"Tool '{name}' not found!")
165+
raise Exception(
166+
f"Tool '{name}' not found in the manifest received from {url}"
167+
)
166168
tool = self.__parse_tool(
167-
name, manifest.tools[name], auth_token_getters, bound_params
169+
name, manifest.tools[name], auth_token_getters, bound_params, strict
168170
)
169171

170172
return tool
171173

172174
async def load_toolset(
173175
self,
174176
name: Optional[str] = None,
175-
auth_token_getters: dict[str, Callable[[], str]] = {},
177+
auth_token_getters: Mapping[str, Callable[[], str]] = {},
176178
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
179+
strict: bool = True,
177180
) -> list[ToolboxTool]:
178181
"""
179182
Asynchronously fetches a toolset and loads all tools defined within it.
180183
181184
Args:
182-
name: Name of the toolset to load tools.
185+
name: Optional name of the toolset to load. If None, attempts to load
186+
the default toolset.
183187
auth_token_getters: A mapping of authentication service names to
184188
callables that return the corresponding authentication token.
185189
bound_params: A mapping of parameter names to bind to specific values or
186190
callables that are called to produce values as needed.
187-
188-
191+
strict: If True (default), all loaded tool instances will operate in
192+
strict validation mode. If False, they will be non-strict.
189193
190194
Returns:
191-
list[ToolboxTool]: A list of callables, one for each tool defined
192-
in the toolset.
195+
list[ToolboxTool]: A list of callables, one for each tool defined in
196+
the toolset.
193197
"""
194198
# Request the definition of the tool from the server
195199
url = f"{self.__base_url}/api/toolset/{name or ''}"
@@ -199,7 +203,7 @@ async def load_toolset(
199203

200204
# parse each tools name and schema into a list of ToolboxTools
201205
tools = [
202-
self.__parse_tool(n, s, auth_token_getters, bound_params)
206+
self.__parse_tool(n, s, auth_token_getters, bound_params, strict)
203207
for n, s in manifest.tools.items()
204208
]
205209
return tools

0 commit comments

Comments
 (0)