14
14
15
15
16
16
import asyncio
17
- import copy
18
17
import types
19
18
from inspect import Signature
20
- from typing import Any , Callable , Iterable , Mapping , Optional , Union
19
+ from typing import (
20
+ Any ,
21
+ Callable ,
22
+ Iterable ,
23
+ Mapping ,
24
+ Optional ,
25
+ Sequence ,
26
+ Union ,
27
+ )
21
28
22
29
from aiohttp import ClientSession
23
30
24
- from toolbox_core .protocol import ToolSchema
31
+ from toolbox_core .protocol import ParameterSchema
25
32
26
33
27
34
class ToolboxTool :
@@ -42,7 +49,8 @@ def __init__(
42
49
session : ClientSession ,
43
50
base_url : str ,
44
51
name : str ,
45
- schema : ToolSchema ,
52
+ description : str ,
53
+ params : Sequence [ParameterSchema ],
46
54
required_authn_params : Mapping [str , list [str ]],
47
55
auth_service_token_getters : Mapping [str , Callable [[], str ]],
48
56
bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
@@ -55,29 +63,30 @@ def __init__(
55
63
session: The `aiohttp.ClientSession` used for making API requests.
56
64
base_url: The base URL of the Toolbox server API.
57
65
name: The name of the remote tool.
58
- schema: The schema of the tool.
66
+ description: The description of the remote tool.
67
+ params: The args of the tool.
59
68
required_authn_params: A dict of required authenticated parameters to a list
60
69
of services that provide values for them.
61
70
auth_service_token_getters: A dict of authService -> token (or callables that
62
71
produce a token)
63
72
bound_params: A mapping of parameter names to bind to specific values or
64
73
callables that are called to produce values as needed.
65
-
66
74
"""
67
-
68
75
# used to invoke the toolbox API
69
76
self .__session : ClientSession = session
70
77
self .__base_url : str = base_url
71
78
self .__url = f"{ base_url } /api/tool/{ name } /invoke"
72
-
73
- self .__params = [ param . to_param () for param in schema . parameters ]
74
- self . __schema = schema
79
+ self . __description = description
80
+ self .__params = params
81
+ inspect_type_params = [ param . to_param () for param in self . __params ]
75
82
76
83
# the following properties are set to help anyone that might inspect it determine usage
77
84
self .__name__ = name
78
- self .__doc__ = self ._schema_to_docstring (self .__schema )
79
- self .__signature__ = Signature (parameters = self .__params , return_annotation = str )
80
- self .__annotations__ = {p .name : p .annotation for p in self .__params }
85
+ self .__doc__ = self ._create_docstring ()
86
+ self .__signature__ = Signature (
87
+ parameters = inspect_type_params , return_annotation = str
88
+ )
89
+ self .__annotations__ = {p .name : p .annotation for p in inspect_type_params }
81
90
# TODO: self.__qualname__ ??
82
91
83
92
# map of parameter name to auth service required by it
@@ -87,23 +96,25 @@ def __init__(
87
96
# map of parameter name to value (or callable that produces that value)
88
97
self .__bound_parameters = bound_params
89
98
90
- @staticmethod
91
- def _schema_to_docstring (schema : ToolSchema ) -> str :
99
+ def _create_docstring (self ) -> str :
92
100
"""Convert a tool schema into its function docstring"""
93
- docstring = schema . description
94
- if not schema . parameters :
101
+ docstring = self . __description
102
+ if not self . __params :
95
103
return docstring
96
104
docstring += "\n \n Args:"
97
- for p in schema .parameters :
98
- docstring += f"\n { p .name } ({ p .type } ): { p .description } "
105
+ for p in self .__params :
106
+ docstring += (
107
+ f"\n { p .name } ({ p .to_param ().annotation .__name__ } ): { p .description } "
108
+ )
99
109
return docstring
100
110
101
111
def __copy (
102
112
self ,
103
113
session : Optional [ClientSession ] = None ,
104
114
base_url : Optional [str ] = None ,
105
115
name : Optional [str ] = None ,
106
- schema : Optional [ToolSchema ] = None ,
116
+ description : Optional [str ] = None ,
117
+ params : Optional [Sequence [ParameterSchema ]] = None ,
107
118
required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
108
119
auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
109
120
bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
@@ -115,7 +126,8 @@ def __copy(
115
126
session: The `aiohttp.ClientSession` used for making API requests.
116
127
base_url: The base URL of the Toolbox server API.
117
128
name: The name of the remote tool.
118
- schema: The schema of the tool.
129
+ description: The description of the remote tool.
130
+ params: The args of the tool.
119
131
required_authn_params: A dict of required authenticated parameters that need
120
132
a auth_service_token_getter set for them yet.
121
133
auth_service_token_getters: A dict of authService -> token (or callables
@@ -129,7 +141,8 @@ def __copy(
129
141
session = check (session , self .__session ),
130
142
base_url = check (base_url , self .__base_url ),
131
143
name = check (name , self .__name__ ),
132
- schema = check (schema , self .__schema ),
144
+ description = check (description , self .__description ),
145
+ params = check (params , self .__params ),
133
146
required_authn_params = check (
134
147
required_authn_params , self .__required_authn_params
135
148
),
@@ -234,14 +247,7 @@ def add_auth_token_getters(
234
247
)
235
248
)
236
249
237
- # Update tool params in schema
238
- new_schema = copy .deepcopy (self .__schema )
239
- for param in new_schema .parameters :
240
- if param .name in auth_token_getters .keys ():
241
- new_schema .parameters .remove (param )
242
-
243
250
return self .__copy (
244
- schema = new_schema ,
245
251
auth_service_token_getters = new_getters ,
246
252
required_authn_params = new_req_authn_params ,
247
253
)
@@ -259,22 +265,21 @@ def bind_parameters(
259
265
Returns:
260
266
A new ToolboxTool instance with the specified parameters bound.
261
267
"""
262
- param_names = set (p .name for p in self .__schema . parameters )
268
+ param_names = set (p .name for p in self .__params )
263
269
for name in bound_params .keys ():
264
270
if name not in param_names :
265
271
raise Exception (f"unable to bind parameters: no parameter named { name } " )
266
272
267
- # Update tool params in schema
268
- new_schema = copy .deepcopy (self .__schema )
269
- for param in new_schema .parameters :
270
- if param .name in bound_params :
271
- new_schema .parameters .remove (param )
273
+ new_params = []
274
+ for p in self .__params :
275
+ if p .name not in bound_params :
276
+ new_params .append (p )
277
+ all_bound_params = dict (self .__bound_parameters )
278
+ all_bound_params .update (bound_params )
272
279
273
280
return self .__copy (
274
- schema = new_schema ,
275
- bound_params = types .MappingProxyType (
276
- dict (self .__bound_parameters , ** bound_params )
277
- ),
281
+ params = new_params ,
282
+ bound_params = types .MappingProxyType (all_bound_params ),
278
283
)
279
284
280
285
0 commit comments