Skip to content

Commit 65b729d

Browse files
committed
use and update schema
# Conflicts: # packages/toolbox-core/src/toolbox_core/tool.py
1 parent 10601aa commit 65b729d

File tree

2 files changed

+37
-46
lines changed

2 files changed

+37
-46
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,13 @@ def __parse_tool(
7979
authn_params = identify_required_authn_params(
8080
authn_params, auth_token_getters.keys()
8181
)
82+
schema.parameters = params
8283

8384
tool = ToolboxTool(
8485
session=self.__session,
8586
base_url=self.__base_url,
8687
name=name,
87-
desc=schema.description,
88-
params=[p.to_param() for p in params],
89-
params_metadata=types.MappingProxyType(
90-
{p.name: (p.type, p.description) for p in schema.parameters}
91-
),
88+
schema=schema,
9289
# create a read-only values for the maps to prevent mutation
9390
required_authn_params=types.MappingProxyType(authn_params),
9491
auth_service_token_getters=types.MappingProxyType(auth_token_getters),

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@
1414

1515

1616
import asyncio
17+
import copy
1718
import types
18-
from inspect import Parameter, Signature
19+
from inspect import Signature
1920
from typing import (
2021
Any,
2122
Callable,
2223
Iterable,
2324
Mapping,
2425
Optional,
25-
Sequence,
2626
Union,
2727
)
28-
28+
from toolbox_core.protocol import ToolSchema
2929
from aiohttp import ClientSession
3030

3131

@@ -47,9 +47,7 @@ def __init__(
4747
session: ClientSession,
4848
base_url: str,
4949
name: str,
50-
desc: str,
51-
params: Sequence[Parameter],
52-
params_metadata: Mapping[str, tuple[str, str]],
50+
schema: ToolSchema,
5351
required_authn_params: Mapping[str, list[str]],
5452
auth_service_token_getters: Mapping[str, Callable[[], str]],
5553
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
@@ -62,10 +60,7 @@ def __init__(
6260
session: The `aiohttp.ClientSession` used for making API requests.
6361
base_url: The base URL of the Toolbox server API.
6462
name: The name of the remote tool.
65-
desc: The description of the remote tool (used as its docstring).
66-
params: A list of `inspect.Parameter` objects defining the tool's
67-
arguments and their types/defaults.
68-
params_metadata: A mapping of param names to their types and descriptions.
63+
schema: The schema of the tool.
6964
required_authn_params: A dict of required authenticated parameters to a list
7065
of services that provide values for them.
7166
auth_service_token_getters: A dict of authService -> token (or callables that
@@ -80,15 +75,14 @@ def __init__(
8075
self.__base_url: str = base_url
8176
self.__url = f"{base_url}/api/tool/{name}/invoke"
8277

83-
self.__desc = desc
84-
self.__params = params
85-
self.__params_metadata = params_metadata
78+
self.__params = [param.to_param() for param in schema.parameters]
79+
self.__schema = schema
8680

8781
# the following properties are set to help anyone that might inspect it determine usage
8882
self.__name__ = name
89-
self.__doc__ = self._schema_to_docstring(desc, params, params_metadata)
90-
self.__signature__ = Signature(parameters=params, return_annotation=str)
91-
self.__annotations__ = {p.name: p.annotation for p in params}
83+
self.__doc__ = self._schema_to_docstring(self.__schema)
84+
self.__signature__ = Signature(parameters=self.__params, return_annotation=str)
85+
self.__annotations__ = {p.name: p.annotation for p in self.__params}
9286
# TODO: self.__qualname__ ??
9387

9488
# map of parameter name to auth service required by it
@@ -100,27 +94,23 @@ def __init__(
10094

10195
@staticmethod
10296
def _schema_to_docstring(
103-
tool_description: str,
104-
params: Sequence[Parameter],
105-
params_metadata: Mapping[str, tuple[str, str]],
97+
schema: ToolSchema
10698
) -> str:
107-
"""Creates a python function docstring from a tool and it's params."""
108-
docstring = tool_description
109-
if not params:
99+
"""Convert a tool schema into its function docstring"""
100+
docstring = schema.description
101+
if not schema.parameters:
110102
return docstring
111103
docstring += "\n\nArgs:"
112-
for p in params:
113-
param_metadata = params_metadata[p.name]
114-
docstring += f"\n {p.name} ({param_metadata[0]}): {param_metadata[1]}"
104+
for p in schema.parameters:
105+
docstring += f"\n {p.name} ({p.type}): {p.description}"
115106
return docstring
116107

117108
def __copy(
118109
self,
119110
session: Optional[ClientSession] = None,
120111
base_url: Optional[str] = None,
121112
name: Optional[str] = None,
122-
desc: Optional[str] = None,
123-
params: Optional[list[Parameter]] = None,
113+
schema: ToolSchema = None,
124114
required_authn_params: Optional[Mapping[str, list[str]]] = None,
125115
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
126116
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
@@ -132,9 +122,7 @@ def __copy(
132122
session: The `aiohttp.ClientSession` used for making API requests.
133123
base_url: The base URL of the Toolbox server API.
134124
name: The name of the remote tool.
135-
desc: The description of the remote tool (used as its docstring).
136-
params: A list of `inspect.Parameter` objects defining the tool's
137-
arguments and their types/defaults.
125+
schema: The schema of the tool.
138126
required_authn_params: A dict of required authenticated parameters that need
139127
a auth_service_token_getter set for them yet.
140128
auth_service_token_getters: A dict of authService -> token (or callables
@@ -148,9 +136,7 @@ def __copy(
148136
session=check(session, self.__session),
149137
base_url=check(base_url, self.__base_url),
150138
name=check(name, self.__name__),
151-
desc=check(desc, self.__desc),
152-
params=check(params, self.__params),
153-
params_metadata=self.__params_metadata,
139+
schema=check(schema, self.__schema),
154140
required_authn_params=check(
155141
required_authn_params, self.__required_authn_params
156142
),
@@ -251,7 +237,14 @@ def add_auth_token_getters(
251237
)
252238
)
253239

240+
# Update tool params in schema
241+
new_schema = copy.deepcopy(self.__schema)
242+
for param in new_schema.parameters:
243+
if param.name in auth_token_getters.keys():
244+
new_schema.parameters.remove(param)
245+
254246
return self.__copy(
247+
schema=new_schema,
255248
auth_service_token_getters=new_getters,
256249
required_authn_params=new_req_authn_params,
257250
)
@@ -269,19 +262,20 @@ def bind_parameters(
269262
Returns:
270263
A new ToolboxTool instance with the specified parameters bound.
271264
"""
272-
param_names = set(p.name for p in self.__params)
265+
param_names = set(p.name for p in self.__schema.parameters)
273266
for name in bound_params.keys():
274267
if name not in param_names:
275268
raise Exception(f"unable to bind parameters: no parameter named {name}")
276269

277-
new_params = []
278-
for p in self.__params:
279-
if p.name not in bound_params:
280-
new_params.append(p)
270+
# Update tool params in schema
271+
new_schema = copy.deepcopy(self.__schema)
272+
for param in new_schema.parameters:
273+
if param.name in bound_params:
274+
new_schema.parameters.remove(param)
281275

282276
return self.__copy(
283-
params=new_params,
284-
bound_params=bound_params,
277+
schema=new_schema,
278+
bound_params=types.MappingProxyType(dict(self.__bound_parameters, **bound_params))
285279
)
286280

287281

@@ -309,4 +303,4 @@ def identify_required_authn_params(
309303
required = not any(s in services for s in auth_service_names)
310304
if required:
311305
required_params[param] = services
312-
return required_params
306+
return required_params

0 commit comments

Comments
 (0)