13
13
# limitations under the License.
14
14
15
15
16
+ import types
17
+ from collections import defaultdict
16
18
from inspect import Parameter , Signature
17
- from typing import Any
19
+ from typing import Any , Callable , DefaultDict , Iterable , Mapping , Optional , Sequence
18
20
19
21
from aiohttp import ClientSession
22
+ from pytest import Session
20
23
21
24
22
25
class ToolboxTool :
@@ -32,20 +35,19 @@ class ToolboxTool:
32
35
and `inspect` work as expected.
33
36
"""
34
37
35
- __url : str
36
- __session : ClientSession
37
- __signature__ : Signature
38
-
39
38
def __init__ (
40
39
self ,
41
40
session : ClientSession ,
42
41
base_url : str ,
43
42
name : str ,
44
43
desc : str ,
45
- params : list [Parameter ],
44
+ params : Sequence [Parameter ],
45
+ required_authn_params : Mapping [str , list [str ]],
46
+ auth_service_token_getters : Mapping [str , Callable [[], str ]],
46
47
):
47
48
"""
48
- Initializes a callable that will trigger the tool invocation through the Toolbox server.
49
+ Initializes a callable that will trigger the tool invocation through the
50
+ Toolbox server.
49
51
50
52
Args:
51
53
session: The `aiohttp.ClientSession` used for making API requests.
@@ -54,19 +56,69 @@ def __init__(
54
56
desc: The description of the remote tool (used as its docstring).
55
57
params: A list of `inspect.Parameter` objects defining the tool's
56
58
arguments and their types/defaults.
59
+ required_authn_params: A dict of required authenticated parameters that
60
+ need a auth_service_token_getter set for them yet.
61
+ auth_service_tokens: A dict of authService -> token (or callables that
62
+ produce a token)
57
63
"""
58
64
59
65
# used to invoke the toolbox API
60
- self .__session = session
66
+ self .__session : ClientSession = session
67
+ self .__base_url : str = base_url
61
68
self .__url = f"{ base_url } /api/tool/{ name } /invoke"
62
69
63
- # the following properties are set to help anyone that might inspect it determine
70
+ self .__desc = desc
71
+ self .__params = params
72
+
73
+ # the following properties are set to help anyone that might inspect it determine usage
64
74
self .__name__ = name
65
75
self .__doc__ = desc
66
76
self .__signature__ = Signature (parameters = params , return_annotation = str )
67
77
self .__annotations__ = {p .name : p .annotation for p in params }
68
78
# TODO: self.__qualname__ ??
69
79
80
+ # map of parameter name to auth service required by it
81
+ self .__required_authn_params = required_authn_params
82
+ # map of authService -> token_getter
83
+ self .__auth_service_token_getters = auth_service_token_getters
84
+
85
+ def __copy (
86
+ self ,
87
+ session : Optional [ClientSession ] = None ,
88
+ base_url : Optional [str ] = None ,
89
+ name : Optional [str ] = None ,
90
+ desc : Optional [str ] = None ,
91
+ params : Optional [list [Parameter ]] = None ,
92
+ required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
93
+ auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
94
+ ):
95
+ """
96
+ Creates a copy of the ToolboxTool, overriding specific fields.
97
+
98
+ Args:
99
+ session: The `aiohttp.ClientSession` used for making API requests.
100
+ base_url: The base URL of the Toolbox server API.
101
+ name: The name of the remote tool.
102
+ desc: The description of the remote tool (used as its docstring).
103
+ params: A list of `inspect.Parameter` objects defining the tool's
104
+ arguments and their types/defaults.
105
+ required_authn_params: A dict of required authenticated parameters that need
106
+ a auth_service_token_getter set for them yet.
107
+ auth_service_token_getters: A dict of authService -> token (or callables
108
+ that produce a token)
109
+
110
+ """
111
+ return ToolboxTool (
112
+ session = session or self .__session ,
113
+ base_url = base_url or self .__base_url ,
114
+ name = name or self .__name__ ,
115
+ desc = desc or self .__desc ,
116
+ params = params or self .__params ,
117
+ required_authn_params = required_authn_params or self .__required_authn_params ,
118
+ auth_service_token_getters = auth_service_token_getters
119
+ or self .__auth_service_token_getters ,
120
+ )
121
+
70
122
async def __call__ (self , * args : Any , ** kwargs : Any ) -> str :
71
123
"""
72
124
Asynchronously calls the remote tool with the provided arguments.
@@ -81,16 +133,96 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
81
133
Returns:
82
134
The string result returned by the remote tool execution.
83
135
"""
136
+
137
+ # check if any auth services need to be specified yet
138
+ if len (self .__required_authn_params ) > 0 :
139
+ req_auth_services = set (l for l in self .__required_authn_params .keys ())
140
+ raise Exception (
141
+ f"One of more of the following authn services are required to invoke this tool: { ',' .join (req_auth_services )} "
142
+ )
143
+
144
+ # validate inputs to this call using the signature
84
145
all_args = self .__signature__ .bind (* args , ** kwargs )
85
146
all_args .apply_defaults () # Include default values if not provided
86
147
payload = all_args .arguments
87
148
149
+ # create headers for auth services
150
+ headers = {}
151
+ for auth_service , token_getter in self .__auth_service_token_getters .items ():
152
+ headers [f"{ auth_service } _token" ] = token_getter ()
153
+
88
154
async with self .__session .post (
89
155
self .__url ,
90
156
json = payload ,
157
+ headers = headers ,
91
158
) as resp :
92
- ret = await resp .json ()
93
- if "error" in ret :
94
- # TODO: better error
95
- raise Exception (ret ["error" ])
96
- return ret .get ("result" , ret )
159
+ body = await resp .json ()
160
+ if resp .status < 200 or resp .status >= 300 :
161
+ err = body .get ("error" , f"unexpected status from server: { resp .status } " )
162
+ raise Exception (err )
163
+ return body .get ("result" , body )
164
+
165
+ def add_auth_token_getters (
166
+ self ,
167
+ auth_token_getters : Mapping [str , Callable [[], str ]],
168
+ ) -> "ToolboxTool" :
169
+ """
170
+ Registers a auth token getter function that is used for AuthServices when tools
171
+ are invoked.
172
+
173
+ Args:
174
+ auth_token_getters: A mapping of authentication service names to
175
+ callables that return the corresponding authentication token.
176
+
177
+ Returns:
178
+ A new ToolboxTool instance with the specified authentication token
179
+ getters registered.
180
+ """
181
+
182
+ # throw an error if the authentication source is already registered
183
+ dupes = auth_token_getters .keys () & self .__auth_service_token_getters .keys ()
184
+ if dupes :
185
+ raise ValueError (
186
+ f"Authentication source(s) `{ ', ' .join (dupes )} ` already registered in tool `{ self .__name__ } `."
187
+ )
188
+
189
+ # create a read-only updated value for new_getters
190
+ new_getters = types .MappingProxyType (
191
+ dict (self .__auth_service_token_getters , ** auth_token_getters )
192
+ )
193
+ # create a read-only updated for params that are still required
194
+ new_req_authn_params = types .MappingProxyType (
195
+ filter_required_authn_params (
196
+ self .__required_authn_params , auth_token_getters .keys ()
197
+ )
198
+ )
199
+
200
+ return self .__copy (
201
+ auth_service_token_getters = new_getters ,
202
+ required_authn_params = new_req_authn_params ,
203
+ )
204
+
205
+
206
+ def filter_required_authn_params (
207
+ req_authn_params : Mapping [str , list [str ]], auth_services : Iterable [str ]
208
+ ) -> dict [str , list [str ]]:
209
+ """
210
+ Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services.
211
+
212
+ Args:
213
+ req_authn_params: A mapping of parameter names to sets of required
214
+ authentication services.
215
+ auth_services: An iterable of authentication service names for which
216
+ token getters are available.
217
+
218
+ Returns:
219
+ A new dictionary representing the subset of required authentication
220
+ parameters that are not covered by the provided `auth_services`.
221
+ """
222
+ req_params = {}
223
+ for param , services in req_authn_params .items ():
224
+ # if we don't have a token_getter for any of the services required by the param, the param is still required
225
+ required = not any (s in services for s in auth_services )
226
+ if required :
227
+ req_params [param ] = services
228
+ return req_params
0 commit comments