Skip to content

Commit 4adda1b

Browse files
committed
remove toolschema usage
1 parent 57843e9 commit 4adda1b

File tree

2 files changed

+33
-40
lines changed

2 files changed

+33
-40
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ def __parse_tool(
7878
authn_params = identify_required_authn_params(
7979
authn_params, auth_token_getters.keys()
8080
)
81-
schema.parameters = params
8281

8382
tool = ToolboxTool(
8483
session=self.__session,
8584
base_url=self.__base_url,
8685
name=name,
87-
schema=schema,
86+
description=schema.description,
87+
params=params,
8888
# create a read-only values for the maps to prevent mutation
8989
required_authn_params=types.MappingProxyType(authn_params),
9090
auth_service_token_getters=types.MappingProxyType(auth_token_getters),

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515

1616
import asyncio
17-
import copy
1817
import types
1918
from inspect import Signature
2019
from typing import (
@@ -23,12 +22,13 @@
2322
Iterable,
2423
Mapping,
2524
Optional,
25+
Sequence,
2626
Union,
2727
)
2828

2929
from aiohttp import ClientSession
3030

31-
from toolbox_core.protocol import ToolSchema
31+
from toolbox_core.protocol import ParameterSchema
3232

3333

3434
class ToolboxTool:
@@ -49,7 +49,8 @@ def __init__(
4949
session: ClientSession,
5050
base_url: str,
5151
name: str,
52-
schema: ToolSchema,
52+
description: str,
53+
params: Sequence[ParameterSchema],
5354
required_authn_params: Mapping[str, list[str]],
5455
auth_service_token_getters: Mapping[str, Callable[[], str]],
5556
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
@@ -62,29 +63,28 @@ def __init__(
6263
session: The `aiohttp.ClientSession` used for making API requests.
6364
base_url: The base URL of the Toolbox server API.
6465
name: The name of the remote tool.
65-
schema: The schema of the tool.
66+
description: The description of the remote tool.
67+
params: The args of the tool.
6668
required_authn_params: A dict of required authenticated parameters to a list
6769
of services that provide values for them.
6870
auth_service_token_getters: A dict of authService -> token (or callables that
6971
produce a token)
7072
bound_params: A mapping of parameter names to bind to specific values or
7173
callables that are called to produce values as needed.
72-
7374
"""
74-
7575
# used to invoke the toolbox API
7676
self.__session: ClientSession = session
7777
self.__base_url: str = base_url
7878
self.__url = f"{base_url}/api/tool/{name}/invoke"
79-
80-
self.__params = [param.to_param() for param in schema.parameters]
81-
self.__schema = schema
79+
self.__description = description
80+
self.__params = params
81+
inspect_type_params = [param.to_param() for param in self.__params]
8282

8383
# the following properties are set to help anyone that might inspect it determine usage
8484
self.__name__ = name
85-
self.__doc__ = self._schema_to_docstring(self.__schema)
86-
self.__signature__ = Signature(parameters=self.__params, return_annotation=str)
87-
self.__annotations__ = {p.name: p.annotation for p in self.__params}
85+
self.__doc__ = self._create_docstring()
86+
self.__signature__ = Signature(parameters=inspect_type_params, return_annotation=str)
87+
self.__annotations__ = {p.name: p.annotation for p in inspect_type_params}
8888
# TODO: self.__qualname__ ??
8989

9090
# map of parameter name to auth service required by it
@@ -94,14 +94,13 @@ def __init__(
9494
# map of parameter name to value (or callable that produces that value)
9595
self.__bound_parameters = bound_params
9696

97-
@staticmethod
98-
def _schema_to_docstring(schema: ToolSchema) -> str:
97+
def _create_docstring(self) -> str:
9998
"""Convert a tool schema into its function docstring"""
100-
docstring = schema.description
101-
if not schema.parameters:
99+
docstring = self.__description
100+
if not self.__params:
102101
return docstring
103102
docstring += "\n\nArgs:"
104-
for p in schema.parameters:
103+
for p in self.__params:
105104
docstring += (
106105
f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}"
107106
)
@@ -112,7 +111,8 @@ def __copy(
112111
session: Optional[ClientSession] = None,
113112
base_url: Optional[str] = None,
114113
name: Optional[str] = None,
115-
schema: Optional[ToolSchema] = None,
114+
description: Optional[str] = None,
115+
params: Optional[Sequence[ParameterSchema]] = None,
116116
required_authn_params: Optional[Mapping[str, list[str]]] = None,
117117
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
118118
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
@@ -124,7 +124,8 @@ def __copy(
124124
session: The `aiohttp.ClientSession` used for making API requests.
125125
base_url: The base URL of the Toolbox server API.
126126
name: The name of the remote tool.
127-
schema: The schema of the tool.
127+
description: The description of the remote tool.
128+
params: The args of the tool.
128129
required_authn_params: A dict of required authenticated parameters that need
129130
a auth_service_token_getter set for them yet.
130131
auth_service_token_getters: A dict of authService -> token (or callables
@@ -138,7 +139,8 @@ def __copy(
138139
session=check(session, self.__session),
139140
base_url=check(base_url, self.__base_url),
140141
name=check(name, self.__name__),
141-
schema=check(schema, self.__schema),
142+
description=check(description, self.__description),
143+
params=check(params, self.__params),
142144
required_authn_params=check(
143145
required_authn_params, self.__required_authn_params
144146
),
@@ -239,14 +241,7 @@ def add_auth_token_getters(
239241
)
240242
)
241243

242-
# Update tool params in schema
243-
new_schema = copy.deepcopy(self.__schema)
244-
for param in new_schema.parameters:
245-
if param.name in auth_token_getters.keys():
246-
new_schema.parameters.remove(param)
247-
248244
return self.__copy(
249-
schema=new_schema,
250245
auth_service_token_getters=new_getters,
251246
required_authn_params=new_req_authn_params,
252247
)
@@ -264,25 +259,23 @@ def bind_parameters(
264259
Returns:
265260
A new ToolboxTool instance with the specified parameters bound.
266261
"""
267-
param_names = set(p.name for p in self.__schema.parameters)
262+
param_names = set(p.name for p in self.__params)
268263
for name in bound_params.keys():
269264
if name not in param_names:
270265
raise Exception(f"unable to bind parameters: no parameter named {name}")
271266

272-
# Update tool params in schema
273-
new_schema = copy.deepcopy(self.__schema)
274-
for param in new_schema.parameters:
275-
if param.name in bound_params:
276-
new_schema.parameters.remove(param)
267+
new_params = []
268+
for p in self.__params:
269+
if p.name not in bound_params:
270+
new_params.append(p)
271+
all_bound_params = dict(self.__bound_parameters)
272+
all_bound_params.update(bound_params)
277273

278274
return self.__copy(
279-
schema=new_schema,
280-
bound_params=types.MappingProxyType(
281-
dict(self.__bound_parameters, **bound_params)
282-
),
275+
params=new_params,
276+
bound_params=types.MappingProxyType(all_bound_params),
283277
)
284278

285-
286279
def identify_required_authn_params(
287280
req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str]
288281
) -> dict[str, list[str]]:

0 commit comments

Comments
 (0)