14
14
15
15
16
16
import asyncio
17
+ import copy
17
18
import types
18
- from inspect import Parameter , Signature
19
+ from inspect import Signature
19
20
from typing import (
20
21
Any ,
21
22
Callable ,
22
23
Iterable ,
23
24
Mapping ,
24
25
Optional ,
25
- Sequence ,
26
26
Union ,
27
27
)
28
-
28
+ from toolbox_core . protocol import ToolSchema
29
29
from aiohttp import ClientSession
30
30
31
31
@@ -47,9 +47,7 @@ def __init__(
47
47
session : ClientSession ,
48
48
base_url : str ,
49
49
name : str ,
50
- desc : str ,
51
- params : Sequence [Parameter ],
52
- params_metadata : Mapping [str , tuple [str , str ]],
50
+ schema : ToolSchema ,
53
51
required_authn_params : Mapping [str , list [str ]],
54
52
auth_service_token_getters : Mapping [str , Callable [[], str ]],
55
53
bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
@@ -62,10 +60,7 @@ def __init__(
62
60
session: The `aiohttp.ClientSession` used for making API requests.
63
61
base_url: The base URL of the Toolbox server API.
64
62
name: The name of the remote tool.
65
- desc: The description of the remote tool (used as its docstring).
66
- params: A list of `inspect.Parameter` objects defining the tool's
67
- arguments and their types/defaults.
68
- params_metadata: A mapping of param names to their types and descriptions.
63
+ schema: The schema of the tool.
69
64
required_authn_params: A dict of required authenticated parameters to a list
70
65
of services that provide values for them.
71
66
auth_service_token_getters: A dict of authService -> token (or callables that
@@ -80,15 +75,14 @@ def __init__(
80
75
self .__base_url : str = base_url
81
76
self .__url = f"{ base_url } /api/tool/{ name } /invoke"
82
77
83
- self .__desc = desc
84
- self .__params = params
85
- self .__params_metadata = params_metadata
78
+ self .__params = [param .to_param () for param in schema .parameters ]
79
+ self .__schema = schema
86
80
87
81
# the following properties are set to help anyone that might inspect it determine usage
88
82
self .__name__ = name
89
- self .__doc__ = self ._schema_to_docstring (desc , params , params_metadata )
90
- self .__signature__ = Signature (parameters = params , return_annotation = str )
91
- self .__annotations__ = {p .name : p .annotation for p in params }
83
+ self .__doc__ = self ._schema_to_docstring (self . __schema )
84
+ self .__signature__ = Signature (parameters = self . __params , return_annotation = str )
85
+ self .__annotations__ = {p .name : p .annotation for p in self . __params }
92
86
# TODO: self.__qualname__ ??
93
87
94
88
# map of parameter name to auth service required by it
@@ -100,27 +94,23 @@ def __init__(
100
94
101
95
@staticmethod
102
96
def _schema_to_docstring (
103
- tool_description : str ,
104
- params : Sequence [Parameter ],
105
- params_metadata : Mapping [str , tuple [str , str ]],
97
+ schema : ToolSchema
106
98
) -> str :
107
- """Creates a python function docstring from a tool and it's params. """
108
- docstring = tool_description
109
- if not params :
99
+ """Convert a tool schema into its function docstring """
100
+ docstring = schema . description
101
+ if not schema . parameters :
110
102
return docstring
111
103
docstring += "\n \n Args:"
112
- for p in params :
113
- param_metadata = params_metadata [p .name ]
114
- docstring += f"\n { p .name } ({ param_metadata [0 ]} ): { param_metadata [1 ]} "
104
+ for p in schema .parameters :
105
+ docstring += f"\n { p .name } ({ p .type } ): { p .description } "
115
106
return docstring
116
107
117
108
def __copy (
118
109
self ,
119
110
session : Optional [ClientSession ] = None ,
120
111
base_url : Optional [str ] = None ,
121
112
name : Optional [str ] = None ,
122
- desc : Optional [str ] = None ,
123
- params : Optional [list [Parameter ]] = None ,
113
+ schema : ToolSchema = None ,
124
114
required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
125
115
auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
126
116
bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
@@ -132,9 +122,7 @@ def __copy(
132
122
session: The `aiohttp.ClientSession` used for making API requests.
133
123
base_url: The base URL of the Toolbox server API.
134
124
name: The name of the remote tool.
135
- desc: The description of the remote tool (used as its docstring).
136
- params: A list of `inspect.Parameter` objects defining the tool's
137
- arguments and their types/defaults.
125
+ schema: The schema of the tool.
138
126
required_authn_params: A dict of required authenticated parameters that need
139
127
a auth_service_token_getter set for them yet.
140
128
auth_service_token_getters: A dict of authService -> token (or callables
@@ -148,9 +136,7 @@ def __copy(
148
136
session = check (session , self .__session ),
149
137
base_url = check (base_url , self .__base_url ),
150
138
name = check (name , self .__name__ ),
151
- desc = check (desc , self .__desc ),
152
- params = check (params , self .__params ),
153
- params_metadata = self .__params_metadata ,
139
+ schema = check (schema , self .__schema ),
154
140
required_authn_params = check (
155
141
required_authn_params , self .__required_authn_params
156
142
),
@@ -251,7 +237,14 @@ def add_auth_token_getters(
251
237
)
252
238
)
253
239
240
+ # Update tool params in schema
241
+ new_schema = copy .deepcopy (self .__schema )
242
+ for param in new_schema .parameters :
243
+ if param .name in auth_token_getters .keys ():
244
+ new_schema .parameters .remove (param )
245
+
254
246
return self .__copy (
247
+ schema = new_schema ,
255
248
auth_service_token_getters = new_getters ,
256
249
required_authn_params = new_req_authn_params ,
257
250
)
@@ -269,19 +262,20 @@ def bind_parameters(
269
262
Returns:
270
263
A new ToolboxTool instance with the specified parameters bound.
271
264
"""
272
- param_names = set (p .name for p in self .__params )
265
+ param_names = set (p .name for p in self .__schema . parameters )
273
266
for name in bound_params .keys ():
274
267
if name not in param_names :
275
268
raise Exception (f"unable to bind parameters: no parameter named { name } " )
276
269
277
- new_params = []
278
- for p in self .__params :
279
- if p .name not in bound_params :
280
- new_params .append (p )
270
+ # Update tool params in schema
271
+ new_schema = copy .deepcopy (self .__schema )
272
+ for param in new_schema .parameters :
273
+ if param .name in bound_params :
274
+ new_schema .parameters .remove (param )
281
275
282
276
return self .__copy (
283
- params = new_params ,
284
- bound_params = bound_params ,
277
+ schema = new_schema ,
278
+ bound_params = types . MappingProxyType ( dict ( self . __bound_parameters , ** bound_params ))
285
279
)
286
280
287
281
@@ -309,4 +303,4 @@ def identify_required_authn_params(
309
303
required = not any (s in services for s in auth_service_names )
310
304
if required :
311
305
required_params [param ] = services
312
- return required_params
306
+ return required_params
0 commit comments