Skip to content

Commit 018cd88

Browse files
authored
Merge branch 'twisha-core-docstring' into twisha-core-pydantic
2 parents 8793fa7 + 5cc8978 commit 018cd88

File tree

6 files changed

+118
-108
lines changed

6 files changed

+118
-108
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import re
1514
import types
1615
from typing import Any, Callable, Mapping, Optional, Union
1716

@@ -79,13 +78,13 @@ def __parse_tool(
7978
authn_params = identify_required_authn_params(
8079
authn_params, auth_token_getters.keys()
8180
)
82-
schema.parameters = params
8381

8482
tool = ToolboxTool(
8583
session=self.__session,
8684
base_url=self.__base_url,
8785
name=name,
88-
schema=schema,
86+
description=schema.description,
87+
params=params,
8988
# create a read-only values for the maps to prevent mutation
9089
required_authn_params=types.MappingProxyType(authn_params),
9190
auth_service_token_getters=types.MappingProxyType(auth_token_getters),

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,21 @@
1414

1515

1616
import asyncio
17-
import copy
1817
import types
1918
from inspect import Signature
20-
from typing import Any, Callable, Iterable, Mapping, Optional, Union
19+
from typing import (
20+
Any,
21+
Callable,
22+
Iterable,
23+
Mapping,
24+
Optional,
25+
Sequence,
26+
Union,
27+
)
2128

2229
from aiohttp import ClientSession
2330

24-
from toolbox_core.protocol import ToolSchema
31+
from toolbox_core.protocol import ParameterSchema
2532

2633

2734
class ToolboxTool:
@@ -42,7 +49,8 @@ def __init__(
4249
session: ClientSession,
4350
base_url: str,
4451
name: str,
45-
schema: ToolSchema,
52+
description: str,
53+
params: Sequence[ParameterSchema],
4654
required_authn_params: Mapping[str, list[str]],
4755
auth_service_token_getters: Mapping[str, Callable[[], str]],
4856
bound_params: Mapping[str, Union[Callable[[], Any], Any]],
@@ -55,29 +63,30 @@ def __init__(
5563
session: The `aiohttp.ClientSession` used for making API requests.
5664
base_url: The base URL of the Toolbox server API.
5765
name: The name of the remote tool.
58-
schema: The schema of the tool.
66+
description: The description of the remote tool.
67+
params: The args of the tool.
5968
required_authn_params: A dict of required authenticated parameters to a list
6069
of services that provide values for them.
6170
auth_service_token_getters: A dict of authService -> token (or callables that
6271
produce a token)
6372
bound_params: A mapping of parameter names to bind to specific values or
6473
callables that are called to produce values as needed.
65-
6674
"""
67-
6875
# used to invoke the toolbox API
6976
self.__session: ClientSession = session
7077
self.__base_url: str = base_url
7178
self.__url = f"{base_url}/api/tool/{name}/invoke"
72-
73-
self.__params = [param.to_param() for param in schema.parameters]
74-
self.__schema = schema
79+
self.__description = description
80+
self.__params = params
81+
inspect_type_params = [param.to_param() for param in self.__params]
7582

7683
# the following properties are set to help anyone that might inspect it determine usage
7784
self.__name__ = name
78-
self.__doc__ = self._schema_to_docstring(self.__schema)
79-
self.__signature__ = Signature(parameters=self.__params, return_annotation=str)
80-
self.__annotations__ = {p.name: p.annotation for p in self.__params}
85+
self.__doc__ = self._create_docstring()
86+
self.__signature__ = Signature(
87+
parameters=inspect_type_params, return_annotation=str
88+
)
89+
self.__annotations__ = {p.name: p.annotation for p in inspect_type_params}
8190
# TODO: self.__qualname__ ??
8291

8392
# map of parameter name to auth service required by it
@@ -87,23 +96,25 @@ def __init__(
8796
# map of parameter name to value (or callable that produces that value)
8897
self.__bound_parameters = bound_params
8998

90-
@staticmethod
91-
def _schema_to_docstring(schema: ToolSchema) -> str:
99+
def _create_docstring(self) -> str:
92100
"""Convert a tool schema into its function docstring"""
93-
docstring = schema.description
94-
if not schema.parameters:
101+
docstring = self.__description
102+
if not self.__params:
95103
return docstring
96104
docstring += "\n\nArgs:"
97-
for p in schema.parameters:
98-
docstring += f"\n {p.name} ({p.type}): {p.description}"
105+
for p in self.__params:
106+
docstring += (
107+
f"\n {p.name} ({p.to_param().annotation.__name__}): {p.description}"
108+
)
99109
return docstring
100110

101111
def __copy(
102112
self,
103113
session: Optional[ClientSession] = None,
104114
base_url: Optional[str] = None,
105115
name: Optional[str] = None,
106-
schema: Optional[ToolSchema] = None,
116+
description: Optional[str] = None,
117+
params: Optional[Sequence[ParameterSchema]] = None,
107118
required_authn_params: Optional[Mapping[str, list[str]]] = None,
108119
auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None,
109120
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
@@ -115,7 +126,8 @@ def __copy(
115126
session: The `aiohttp.ClientSession` used for making API requests.
116127
base_url: The base URL of the Toolbox server API.
117128
name: The name of the remote tool.
118-
schema: The schema of the tool.
129+
description: The description of the remote tool.
130+
params: The args of the tool.
119131
required_authn_params: A dict of required authenticated parameters that need
120132
a auth_service_token_getter set for them yet.
121133
auth_service_token_getters: A dict of authService -> token (or callables
@@ -129,7 +141,8 @@ def __copy(
129141
session=check(session, self.__session),
130142
base_url=check(base_url, self.__base_url),
131143
name=check(name, self.__name__),
132-
schema=check(schema, self.__schema),
144+
description=check(description, self.__description),
145+
params=check(params, self.__params),
133146
required_authn_params=check(
134147
required_authn_params, self.__required_authn_params
135148
),
@@ -234,14 +247,7 @@ def add_auth_token_getters(
234247
)
235248
)
236249

237-
# Update tool params in schema
238-
new_schema = copy.deepcopy(self.__schema)
239-
for param in new_schema.parameters:
240-
if param.name in auth_token_getters.keys():
241-
new_schema.parameters.remove(param)
242-
243250
return self.__copy(
244-
schema=new_schema,
245251
auth_service_token_getters=new_getters,
246252
required_authn_params=new_req_authn_params,
247253
)
@@ -259,22 +265,21 @@ def bind_parameters(
259265
Returns:
260266
A new ToolboxTool instance with the specified parameters bound.
261267
"""
262-
param_names = set(p.name for p in self.__schema.parameters)
268+
param_names = set(p.name for p in self.__params)
263269
for name in bound_params.keys():
264270
if name not in param_names:
265271
raise Exception(f"unable to bind parameters: no parameter named {name}")
266272

267-
# Update tool params in schema
268-
new_schema = copy.deepcopy(self.__schema)
269-
for param in new_schema.parameters:
270-
if param.name in bound_params:
271-
new_schema.parameters.remove(param)
273+
new_params = []
274+
for p in self.__params:
275+
if p.name not in bound_params:
276+
new_params.append(p)
277+
all_bound_params = dict(self.__bound_parameters)
278+
all_bound_params.update(bound_params)
272279

273280
return self.__copy(
274-
schema=new_schema,
275-
bound_params=types.MappingProxyType(
276-
dict(self.__bound_parameters, **bound_params)
277-
),
281+
params=new_params,
282+
bound_params=types.MappingProxyType(all_bound_params),
278283
)
279284

280285

packages/toolbox-core/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,4 @@ def toolbox_server(toolbox_version: str, tools_file_path: str) -> Generator[None
163163

164164
# Clean up toolbox server
165165
toolbox_server.terminate()
166-
toolbox_server.wait()
166+
toolbox_server.wait(timeout=5)

packages/toolbox-core/tests/test_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@ async def test_load_tool_success(aioresponses, test_tool_str):
9292
assert callable(loaded_tool)
9393
# Assert introspection attributes are set correctly
9494
assert loaded_tool.__name__ == TOOL_NAME
95-
assert (
96-
loaded_tool.__doc__
97-
== test_tool_str.description
98-
+ f"\n\nArgs:\n param1 (string): Description of Param1"
95+
expected_description = (
96+
test_tool_str.description
97+
+ f"\n\nArgs:\n param1 (str): Description of Param1"
9998
)
99+
assert loaded_tool.__doc__ == expected_description
100100

101101
# Assert signature inspection
102102
sig = inspect.signature(loaded_tool)

0 commit comments

Comments
 (0)