14
14
15
15
16
16
import asyncio
17
- import copy
18
17
import types
19
18
from inspect import Signature
20
19
from typing import (
23
22
Iterable ,
24
23
Mapping ,
25
24
Optional ,
25
+ Sequence ,
26
26
Union ,
27
27
)
28
28
29
29
from aiohttp import ClientSession
30
30
31
- from toolbox_core .protocol import ToolSchema
31
+ from toolbox_core .protocol import ParameterSchema
32
32
33
33
34
34
class ToolboxTool :
@@ -49,7 +49,8 @@ def __init__(
49
49
session : ClientSession ,
50
50
base_url : str ,
51
51
name : str ,
52
- schema : ToolSchema ,
52
+ description : str ,
53
+ params : Sequence [ParameterSchema ],
53
54
required_authn_params : Mapping [str , list [str ]],
54
55
auth_service_token_getters : Mapping [str , Callable [[], str ]],
55
56
bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
@@ -62,29 +63,28 @@ def __init__(
62
63
session: The `aiohttp.ClientSession` used for making API requests.
63
64
base_url: The base URL of the Toolbox server API.
64
65
name: The name of the remote tool.
65
- schema: The schema of the tool.
66
+ description: The description of the remote tool.
67
+ params: The args of the tool.
66
68
required_authn_params: A dict of required authenticated parameters to a list
67
69
of services that provide values for them.
68
70
auth_service_token_getters: A dict of authService -> token (or callables that
69
71
produce a token)
70
72
bound_params: A mapping of parameter names to bind to specific values or
71
73
callables that are called to produce values as needed.
72
-
73
74
"""
74
-
75
75
# used to invoke the toolbox API
76
76
self .__session : ClientSession = session
77
77
self .__base_url : str = base_url
78
78
self .__url = f"{ base_url } /api/tool/{ name } /invoke"
79
-
80
- self .__params = [ param . to_param () for param in schema . parameters ]
81
- self . __schema = schema
79
+ self . __description = description
80
+ self .__params = params
81
+ inspect_type_params = [ param . to_param () for param in self . __params ]
82
82
83
83
# the following properties are set to help anyone that might inspect it determine usage
84
84
self .__name__ = name
85
- self .__doc__ = self ._schema_to_docstring ( self . __schema )
86
- self .__signature__ = Signature (parameters = self . __params , return_annotation = str )
87
- self .__annotations__ = {p .name : p .annotation for p in self . __params }
85
+ self .__doc__ = self ._create_docstring ( )
86
+ self .__signature__ = Signature (parameters = inspect_type_params , return_annotation = str )
87
+ self .__annotations__ = {p .name : p .annotation for p in inspect_type_params }
88
88
# TODO: self.__qualname__ ??
89
89
90
90
# map of parameter name to auth service required by it
@@ -94,14 +94,13 @@ def __init__(
94
94
# map of parameter name to value (or callable that produces that value)
95
95
self .__bound_parameters = bound_params
96
96
97
- @staticmethod
98
- def _schema_to_docstring (schema : ToolSchema ) -> str :
97
+ def _create_docstring (self ) -> str :
99
98
"""Convert a tool schema into its function docstring"""
100
- docstring = schema . description
101
- if not schema . parameters :
99
+ docstring = self . __description
100
+ if not self . __params :
102
101
return docstring
103
102
docstring += "\n \n Args:"
104
- for p in schema . parameters :
103
+ for p in self . __params :
105
104
docstring += (
106
105
f"\n { p .name } ({ p .to_param ().annotation .__name__ } ): { p .description } "
107
106
)
@@ -112,7 +111,8 @@ def __copy(
112
111
session : Optional [ClientSession ] = None ,
113
112
base_url : Optional [str ] = None ,
114
113
name : Optional [str ] = None ,
115
- schema : Optional [ToolSchema ] = None ,
114
+ description : Optional [str ] = None ,
115
+ params : Optional [Sequence [ParameterSchema ]] = None ,
116
116
required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
117
117
auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
118
118
bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
@@ -124,7 +124,8 @@ def __copy(
124
124
session: The `aiohttp.ClientSession` used for making API requests.
125
125
base_url: The base URL of the Toolbox server API.
126
126
name: The name of the remote tool.
127
- schema: The schema of the tool.
127
+ description: The description of the remote tool.
128
+ params: The args of the tool.
128
129
required_authn_params: A dict of required authenticated parameters that need
129
130
a auth_service_token_getter set for them yet.
130
131
auth_service_token_getters: A dict of authService -> token (or callables
@@ -138,7 +139,8 @@ def __copy(
138
139
session = check (session , self .__session ),
139
140
base_url = check (base_url , self .__base_url ),
140
141
name = check (name , self .__name__ ),
141
- schema = check (schema , self .__schema ),
142
+ description = check (description , self .__description ),
143
+ params = check (params , self .__params ),
142
144
required_authn_params = check (
143
145
required_authn_params , self .__required_authn_params
144
146
),
@@ -239,14 +241,7 @@ def add_auth_token_getters(
239
241
)
240
242
)
241
243
242
- # Update tool params in schema
243
- new_schema = copy .deepcopy (self .__schema )
244
- for param in new_schema .parameters :
245
- if param .name in auth_token_getters .keys ():
246
- new_schema .parameters .remove (param )
247
-
248
244
return self .__copy (
249
- schema = new_schema ,
250
245
auth_service_token_getters = new_getters ,
251
246
required_authn_params = new_req_authn_params ,
252
247
)
@@ -264,25 +259,23 @@ def bind_parameters(
264
259
Returns:
265
260
A new ToolboxTool instance with the specified parameters bound.
266
261
"""
267
- param_names = set (p .name for p in self .__schema . parameters )
262
+ param_names = set (p .name for p in self .__params )
268
263
for name in bound_params .keys ():
269
264
if name not in param_names :
270
265
raise Exception (f"unable to bind parameters: no parameter named { name } " )
271
266
272
- # Update tool params in schema
273
- new_schema = copy .deepcopy (self .__schema )
274
- for param in new_schema .parameters :
275
- if param .name in bound_params :
276
- new_schema .parameters .remove (param )
267
+ new_params = []
268
+ for p in self .__params :
269
+ if p .name not in bound_params :
270
+ new_params .append (p )
271
+ all_bound_params = dict (self .__bound_parameters )
272
+ all_bound_params .update (bound_params )
277
273
278
274
return self .__copy (
279
- schema = new_schema ,
280
- bound_params = types .MappingProxyType (
281
- dict (self .__bound_parameters , ** bound_params )
282
- ),
275
+ params = new_params ,
276
+ bound_params = types .MappingProxyType (all_bound_params ),
283
277
)
284
278
285
-
286
279
def identify_required_authn_params (
287
280
req_authn_params : Mapping [str , list [str ]], auth_service_names : Iterable [str ]
288
281
) -> dict [str , list [str ]]:
0 commit comments