1313# limitations under the License.
1414
1515
16+ import asyncio
1617import types
1718from collections import defaultdict
1819from inspect import Parameter , Signature
19- from typing import Any , Callable , DefaultDict , Iterable , Mapping , Optional , Sequence
20+ from typing import (
21+ Any ,
22+ Callable ,
23+ DefaultDict ,
24+ Iterable ,
25+ Mapping ,
26+ Optional ,
27+ Sequence ,
28+ Union ,
29+ )
2030
2131from aiohttp import ClientSession
2232from pytest import Session
@@ -44,6 +54,7 @@ def __init__(
4454 params : Sequence [Parameter ],
4555 required_authn_params : Mapping [str , list [str ]],
4656 auth_service_token_getters : Mapping [str , Callable [[], str ]],
57+ bound_params : Mapping [str , Union [Callable [[], Any ], Any ]] = {},
4758 ):
4859 """
4960 Initializes a callable that will trigger the tool invocation through the
@@ -81,6 +92,8 @@ def __init__(
8192 self .__required_authn_params = required_authn_params
8293 # map of authService -> token_getter
8394 self .__auth_service_token_getters = auth_service_token_getters
95+ # map of parameter name to value or Callable
96+ self .__bound_parameters = bound_params
8497
8598 def __copy (
8699 self ,
@@ -91,6 +104,7 @@ def __copy(
91104 params : Optional [list [Parameter ]] = None ,
92105 required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
93106 auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
107+ bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
94108 ) -> "ToolboxTool" :
95109 """
96110 Creates a copy of the ToolboxTool, overriding specific fields.
@@ -117,6 +131,7 @@ def __copy(
117131 required_authn_params = required_authn_params or self .__required_authn_params ,
118132 auth_service_token_getters = auth_service_token_getters
119133 or self .__auth_service_token_getters ,
134+ bound_params = bound_params or self .__bound_parameters ,
120135 )
121136
122137 async def __call__ (self , * args : Any , ** kwargs : Any ) -> str :
@@ -146,6 +161,14 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
146161 all_args .apply_defaults () # Include default values if not provided
147162 payload = all_args .arguments
148163
164+ # apply bounded parameters
165+ for param , value in self .__bound_parameters .items ():
166+ if asyncio .iscoroutinefunction (value ):
167+ value = await value ()
168+ elif callable (value ):
169+ value = value ()
170+ payload [param ] = value
171+
149172 # create headers for auth services
150173 headers = {}
151174 for auth_service , token_getter in self .__auth_service_token_getters .items ():
@@ -202,12 +225,41 @@ def add_auth_token_getters(
202225 required_authn_params = new_req_authn_params ,
203226 )
204227
228+ def bind_parameters (
229+ self , bound_params : Mapping [str , Callable [[], str ]]
230+ ) -> "ToolboxTool" :
231+ """
232+ Binds parameters to values or callables that produce values.
233+
234+ Args:
235+ bound_params: A mapping of parameter names to values or callables that
236+ produce values.
237+
238+ Returns:
239+ A new ToolboxTool instance with the specified parameters bound.
240+ """
241+ all_params = set (p .name for p in self .__params )
242+ for name in bound_params .keys ():
243+ if name not in all_params :
244+ raise Exception (f"unable to bind parameters: no parameter named { name } " )
245+
246+ new_params = []
247+ for p in self .__params :
248+ if p .name not in bound_params :
249+ new_params .append (p )
250+
251+ return self .__copy (
252+ params = new_params ,
253+ bound_params = bound_params ,
254+ )
255+
205256
206257def filter_required_authn_params (
207258 req_authn_params : Mapping [str , list [str ]], auth_services : Iterable [str ]
208259) -> dict [str , list [str ]]:
209260 """
210- Utility function for reducing 'req_authn_params' to a subset of parameters that aren't supplied by a least one service in auth_services.
261+ Utility function for reducing 'req_authn_params' to a subset of parameters that
262+ aren't supplied by a least one service in auth_services.
211263
212264 Args:
213265 req_authn_params: A mapping of parameter names to sets of required
@@ -216,8 +268,8 @@ def filter_required_authn_params(
216268 token getters are available.
217269
218270 Returns:
219- A new dictionary representing the subset of required authentication
220- parameters that are not covered by the provided `auth_services`.
271+ A new dictionary representing the subset of required authentication parameters
272+ that are not covered by the provided `auth_services`.
221273 """
222274 req_params = {}
223275 for param , services in req_authn_params .items ():
0 commit comments