13
13
# limitations under the License.
14
14
15
15
16
+ import types
17
+ from collections import defaultdict
16
18
from inspect import Parameter , Signature
17
- from typing import Any
19
+ from typing import Any , Callable , DefaultDict , Iterable , Mapping , Optional , Sequence
18
20
19
21
from aiohttp import ClientSession
22
+ from pytest import Session
20
23
21
24
22
25
class ToolboxTool :
@@ -32,20 +35,19 @@ class ToolboxTool:
32
35
and `inspect` work as expected.
33
36
"""
34
37
35
- __url : str
36
- __session : ClientSession
37
- __signature__ : Signature
38
-
39
38
def __init__ (
40
39
self ,
41
40
session : ClientSession ,
42
41
base_url : str ,
43
42
name : str ,
44
43
desc : str ,
45
- params : list [Parameter ],
44
+ params : Sequence [Parameter ],
45
+ required_authn_params : Mapping [str , list [str ]],
46
+ auth_service_token_getters : Mapping [str , Callable [[], str ]],
46
47
):
47
48
"""
48
- Initializes a callable that will trigger the tool invocation through the Toolbox server.
49
+ Initializes a callable that will trigger the tool invocation through the
50
+ Toolbox server.
49
51
50
52
Args:
51
53
session: The `aiohttp.ClientSession` used for making API requests.
@@ -54,19 +56,73 @@ def __init__(
54
56
desc: The description of the remote tool (used as its docstring).
55
57
params: A list of `inspect.Parameter` objects defining the tool's
56
58
arguments and their types/defaults.
59
+ required_authn_params: A dict of required authenticated parameters to a list
60
+ of services that provide values for them.
61
+ auth_service_token_getters: A dict of authService -> token (or callables that
62
+ produce a token)
57
63
"""
58
64
59
65
# used to invoke the toolbox API
60
- self .__session = session
66
+ self .__session : ClientSession = session
67
+ self .__base_url : str = base_url
61
68
self .__url = f"{ base_url } /api/tool/{ name } /invoke"
62
69
63
- # the following properties are set to help anyone that might inspect it determine
70
+ self .__desc = desc
71
+ self .__params = params
72
+
73
+ # the following properties are set to help anyone that might inspect it determine usage
64
74
self .__name__ = name
65
75
self .__doc__ = desc
66
76
self .__signature__ = Signature (parameters = params , return_annotation = str )
67
77
self .__annotations__ = {p .name : p .annotation for p in params }
68
78
# TODO: self.__qualname__ ??
69
79
80
+ # map of parameter name to auth service required by it
81
+ self .__required_authn_params = required_authn_params
82
+ # map of authService -> token_getter
83
+ self .__auth_service_token_getters = auth_service_token_getters
84
+
85
+ def __copy (
86
+ self ,
87
+ session : Optional [ClientSession ] = None ,
88
+ base_url : Optional [str ] = None ,
89
+ name : Optional [str ] = None ,
90
+ desc : Optional [str ] = None ,
91
+ params : Optional [list [Parameter ]] = None ,
92
+ required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
93
+ auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
94
+ ) -> "ToolboxTool" :
95
+ """
96
+ Creates a copy of the ToolboxTool, overriding specific fields.
97
+
98
+ Args:
99
+ session: The `aiohttp.ClientSession` used for making API requests.
100
+ base_url: The base URL of the Toolbox server API.
101
+ name: The name of the remote tool.
102
+ desc: The description of the remote tool (used as its docstring).
103
+ params: A list of `inspect.Parameter` objects defining the tool's
104
+ arguments and their types/defaults.
105
+ required_authn_params: A dict of required authenticated parameters that need
106
+ a auth_service_token_getter set for them yet.
107
+ auth_service_token_getters: A dict of authService -> token (or callables
108
+ that produce a token)
109
+
110
+ """
111
+ check = lambda val , default : val if val is not None else default
112
+ return ToolboxTool (
113
+ session = check (session , self .__session ),
114
+ base_url = check (base_url , self .__base_url ),
115
+ name = check (name , self .__name__ ),
116
+ desc = check (desc , self .__desc ),
117
+ params = check (params , self .__params ),
118
+ required_authn_params = check (
119
+ required_authn_params , self .__required_authn_params
120
+ ),
121
+ auth_service_token_getters = check (
122
+ auth_service_token_getters , self .__auth_service_token_getters
123
+ ),
124
+ )
125
+
70
126
async def __call__ (self , * args : Any , ** kwargs : Any ) -> str :
71
127
"""
72
128
Asynchronously calls the remote tool with the provided arguments.
@@ -81,16 +137,103 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
81
137
Returns:
82
138
The string result returned by the remote tool execution.
83
139
"""
140
+
141
+ # check if any auth services need to be specified yet
142
+ if len (self .__required_authn_params ) > 0 :
143
+ # Gather all the required auth services into a set
144
+ req_auth_services = set ()
145
+ for s in self .__required_authn_params .values ():
146
+ req_auth_services .update (s )
147
+ raise Exception (
148
+ f"One or more of the following authn services are required to invoke this tool: { ',' .join (req_auth_services )} "
149
+ )
150
+
151
+ # validate inputs to this call using the signature
84
152
all_args = self .__signature__ .bind (* args , ** kwargs )
85
153
all_args .apply_defaults () # Include default values if not provided
86
154
payload = all_args .arguments
87
155
156
+ # create headers for auth services
157
+ headers = {}
158
+ for auth_service , token_getter in self .__auth_service_token_getters .items ():
159
+ headers [f"{ auth_service } _token" ] = token_getter ()
160
+
88
161
async with self .__session .post (
89
162
self .__url ,
90
163
json = payload ,
164
+ headers = headers ,
91
165
) as resp :
92
- ret = await resp .json ()
93
- if "error" in ret :
94
- # TODO: better error
95
- raise Exception (ret ["error" ])
96
- return ret .get ("result" , ret )
166
+ body = await resp .json ()
167
+ if resp .status < 200 or resp .status >= 300 :
168
+ err = body .get ("error" , f"unexpected status from server: { resp .status } " )
169
+ raise Exception (err )
170
+ return body .get ("result" , body )
171
+
172
+ def add_auth_token_getters (
173
+ self ,
174
+ auth_token_getters : Mapping [str , Callable [[], str ]],
175
+ ) -> "ToolboxTool" :
176
+ """
177
+ Registers an auth token getter function that is used for AuthServices when tools
178
+ are invoked.
179
+
180
+ Args:
181
+ auth_token_getters: A mapping of authentication service names to
182
+ callables that return the corresponding authentication token.
183
+
184
+ Returns:
185
+ A new ToolboxTool instance with the specified authentication token
186
+ getters registered.
187
+ """
188
+
189
+ # throw an error if the authentication source is already registered
190
+ existing_services = self .__auth_service_token_getters .keys ()
191
+ incoming_services = auth_token_getters .keys ()
192
+ duplicates = existing_services & incoming_services
193
+ if duplicates :
194
+ raise ValueError (
195
+ f"Authentication source(s) `{ ', ' .join (duplicates )} ` already registered in tool `{ self .__name__ } `."
196
+ )
197
+
198
+ # create a read-only updated value for new_getters
199
+ new_getters = types .MappingProxyType (
200
+ dict (self .__auth_service_token_getters , ** auth_token_getters )
201
+ )
202
+ # create a read-only updated for params that are still required
203
+ new_req_authn_params = types .MappingProxyType (
204
+ identify_required_authn_params (
205
+ self .__required_authn_params , auth_token_getters .keys ()
206
+ )
207
+ )
208
+
209
+ return self .__copy (
210
+ auth_service_token_getters = new_getters ,
211
+ required_authn_params = new_req_authn_params ,
212
+ )
213
+
214
+
215
+ def identify_required_authn_params (
216
+ req_authn_params : Mapping [str , list [str ]], auth_service_names : Iterable [str ]
217
+ ) -> dict [str , list [str ]]:
218
+ """
219
+ Identifies authentication parameters that are still required; or not covered by
220
+ the provided `auth_service_names`.
221
+
222
+ Args:
223
+ req_authn_params: A mapping of parameter names to sets of required
224
+ authentication services.
225
+ auth_service_names: An iterable of authentication service names for which
226
+ token getters are available.
227
+
228
+ Returns:
229
+ A new dictionary representing the subset of required authentication
230
+ parameters that are not covered by the provided `auth_service_names`.
231
+ """
232
+ required_params = {} # params that are still required with provided auth_services
233
+ for param , services in req_authn_params .items ():
234
+ # if we don't have a token_getter for any of the services required by the param,
235
+ # the param is still required
236
+ required = not any (s in services for s in auth_service_names )
237
+ if required :
238
+ required_params [param ] = services
239
+ return required_params
0 commit comments