13
13
# limitations under the License.
14
14
15
15
16
+ import asyncio
16
17
import types
17
18
from collections import defaultdict
18
19
from inspect import Parameter , Signature
19
- from typing import Any , Callable , DefaultDict , Iterable , Mapping , Optional , Sequence
20
+ from typing import (
21
+ Any ,
22
+ Callable ,
23
+ DefaultDict ,
24
+ Iterable ,
25
+ Mapping ,
26
+ Optional ,
27
+ Sequence ,
28
+ Union ,
29
+ )
20
30
21
31
from aiohttp import ClientSession
22
32
from pytest import Session
@@ -44,6 +54,7 @@ def __init__(
44
54
params : Sequence [Parameter ],
45
55
required_authn_params : Mapping [str , list [str ]],
46
56
auth_service_token_getters : Mapping [str , Callable [[], str ]],
57
+ bound_params : Mapping [str , Union [Callable [[], Any ], Any ]] = {},
47
58
):
48
59
"""
49
60
Initializes a callable that will trigger the tool invocation through the
@@ -81,6 +92,8 @@ def __init__(
81
92
self .__required_authn_params = required_authn_params
82
93
# map of authService -> token_getter
83
94
self .__auth_service_token_getters = auth_service_token_getters
95
+ # map of parameter name to value or Callable
96
+ self .__bound_parameters = bound_params
84
97
85
98
def __copy (
86
99
self ,
@@ -91,6 +104,7 @@ def __copy(
91
104
params : Optional [list [Parameter ]] = None ,
92
105
required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
93
106
auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
107
+ bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
94
108
) -> "ToolboxTool" :
95
109
"""
96
110
Creates a copy of the ToolboxTool, overriding specific fields.
@@ -121,6 +135,7 @@ def __copy(
121
135
auth_service_token_getters = check (
122
136
auth_service_token_getters , self .__auth_service_token_getters
123
137
),
138
+ bound_params = check (bound_params , self .__bound_parameters ),
124
139
)
125
140
126
141
async def __call__ (self , * args : Any , ** kwargs : Any ) -> str :
@@ -153,6 +168,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
153
168
all_args .apply_defaults () # Include default values if not provided
154
169
payload = all_args .arguments
155
170
171
+ # apply bounded parameters
172
+ for param , value in self .__bound_parameters .items ():
173
+ if asyncio .iscoroutinefunction (value ):
174
+ value = await value ()
175
+ elif callable (value ):
176
+ value = value ()
177
+ payload [param ] = value
178
+
156
179
# create headers for auth services
157
180
headers = {}
158
181
for auth_service , token_getter in self .__auth_service_token_getters .items ():
@@ -211,13 +234,41 @@ def add_auth_token_getters(
211
234
required_authn_params = new_req_authn_params ,
212
235
)
213
236
237
+ def bind_parameters (
238
+ self , bound_params : Mapping [str , Callable [[], str ]]
239
+ ) -> "ToolboxTool" :
240
+ """
241
+ Binds parameters to values or callables that produce values.
242
+
243
+ Args:
244
+ bound_params: A mapping of parameter names to values or callables that
245
+ produce values.
246
+
247
+ Returns:
248
+ A new ToolboxTool instance with the specified parameters bound.
249
+ """
250
+ all_params = set (p .name for p in self .__params )
251
+ for name in bound_params .keys ():
252
+ if name not in all_params :
253
+ raise Exception (f"unable to bind parameters: no parameter named { name } " )
254
+
255
+ new_params = []
256
+ for p in self .__params :
257
+ if p .name not in bound_params :
258
+ new_params .append (p )
259
+
260
+ return self .__copy (
261
+ params = new_params ,
262
+ bound_params = bound_params ,
263
+ )
264
+
214
265
215
266
def identify_required_authn_params (
216
267
req_authn_params : Mapping [str , list [str ]], auth_service_names : Iterable [str ]
217
268
) -> dict [str , list [str ]]:
218
269
"""
219
- Identifies authentication parameters that are still required; or not covered by
220
- the provided `auth_service_names`.
270
+ Identifies authentication parameters that are still required; because they
271
+ not covered by the provided `auth_service_names`.
221
272
222
273
Args:
223
274
req_authn_params: A mapping of parameter names to sets of required
@@ -226,8 +277,8 @@ def identify_required_authn_params(
226
277
token getters are available.
227
278
228
279
Returns:
229
- A new dictionary representing the subset of required authentication
230
- parameters that are not covered by the provided `auth_service_names `.
280
+ A new dictionary representing the subset of required authentication parameters
281
+ that are not covered by the provided `auth_services `.
231
282
"""
232
283
required_params = {} # params that are still required with provided auth_services
233
284
for param , services in req_authn_params .items ():
0 commit comments