13
13
# limitations under the License.
14
14
15
15
from copy import deepcopy
16
- from typing import Any , Callable
16
+ from typing import Any , Callable , Union
17
17
from warnings import warn
18
18
19
19
from aiohttp import ClientSession
24
24
ParameterSchema ,
25
25
ToolSchema ,
26
26
_find_auth_params ,
27
+ _find_bound_params ,
27
28
_invoke_tool ,
28
29
_schema_to_model ,
29
30
)
32
33
class ToolboxTool (StructuredTool ):
33
34
"""
34
35
A subclass of LangChain's StructuredTool that supports features specific to
35
- Toolbox, like authenticated tools.
36
+ Toolbox, like bound parameters and authenticated tools.
36
37
"""
37
38
38
39
def __init__ (
@@ -42,6 +43,8 @@ def __init__(
42
43
url : str ,
43
44
session : ClientSession ,
44
45
auth_tokens : dict [str , Callable [[], str ]] = {},
46
+ bound_params : dict [str , Union [Any , Callable [[], Any ]]] = {},
47
+ strict : bool = True ,
45
48
) -> None :
46
49
"""
47
50
Initializes a ToolboxTool instance.
@@ -53,6 +56,11 @@ def __init__(
53
56
session: The HTTP client session.
54
57
auth_tokens: A mapping of authentication source names to functions
55
58
that retrieve ID tokens.
59
+ bound_params: A mapping of parameter names to their bound
60
+ values.
61
+ strict: If True, raises a ValueError if any of the given bound
62
+ parameters are missing from the schema or require
63
+ authentication. If False, only issues a warning.
56
64
"""
57
65
58
66
# If the schema is not already a ToolSchema instance, we create one from
@@ -63,10 +71,51 @@ def __init__(
63
71
schema = ToolSchema (** schema )
64
72
65
73
auth_params , non_auth_params = _find_auth_params (schema .parameters )
74
+ non_auth_bound_params , non_auth_non_bound_params = _find_bound_params (
75
+ non_auth_params , list (bound_params )
76
+ )
77
+
78
+ # Check if the user is trying to bind a param that is authenticated or
79
+ # is missing from the given schema.
80
+ auth_bound_params : list [str ] = []
81
+ missing_bound_params : list [str ] = []
82
+ for bound_param in bound_params :
83
+ if bound_param in [param .name for param in auth_params ]:
84
+ auth_bound_params .append (bound_param )
85
+ elif bound_param not in [param .name for param in non_auth_params ]:
86
+ missing_bound_params .append (bound_param )
87
+
88
+ # Create error messages for any params that are found to be
89
+ # authenticated or missing.
90
+ messages : list [str ] = []
91
+ if auth_bound_params :
92
+ messages .append (
93
+ f"Parameter(s) { ', ' .join (auth_bound_params )} already authenticated and cannot be bound."
94
+ )
95
+ if missing_bound_params :
96
+ messages .append (
97
+ f"Parameter(s) { ', ' .join (missing_bound_params )} missing and cannot be bound."
98
+ )
99
+
100
+ # Join any error messages and raise them as an error or warning,
101
+ # depending on the value of the strict flag.
102
+ if messages :
103
+ message = "\n \n " .join (messages )
104
+ if strict :
105
+ raise ValueError (message )
106
+ warn (message )
107
+
108
+ # Bind values for parameters present in the schema that don't require
109
+ # authentication.
110
+ bound_params = {
111
+ param_name : param_value
112
+ for param_name , param_value in bound_params .items ()
113
+ if param_name in [param .name for param in non_auth_bound_params ]
114
+ }
66
115
67
116
# Update the tools schema to validate only the presence of parameters
68
- # that do not require authentication.
69
- schema .parameters = non_auth_params
117
+ # that neither require authentication nor are bound .
118
+ schema .parameters = non_auth_non_bound_params
70
119
71
120
# Due to how pydantic works, we must initialize the underlying
72
121
# StructuredTool class before assigning values to member variables.
@@ -84,6 +133,7 @@ def __init__(
84
133
self ._session : ClientSession = session
85
134
self ._auth_tokens : dict [str , Callable [[], str ]] = auth_tokens
86
135
self ._auth_params : list [ParameterSchema ] = auth_params
136
+ self ._bound_params : dict [str , Union [Any , Callable [[], Any ]]] = bound_params
87
137
88
138
# Warn users about any missing authentication so they can add it before
89
139
# tool invocation.
@@ -106,6 +156,17 @@ async def __tool_func(self, **kwargs: Any) -> dict:
106
156
# authentication sources have been registered or not.
107
157
self .__validate_auth ()
108
158
159
+ # Evaluate dynamic parameter values if any
160
+ evaluated_params = {}
161
+ for param_name , param_value in self ._bound_params .items ():
162
+ if callable (param_value ):
163
+ evaluated_params [param_name ] = param_value ()
164
+ else :
165
+ evaluated_params [param_name ] = param_value
166
+
167
+ # Merge bound parameters with the provided arguments
168
+ kwargs .update (evaluated_params )
169
+
109
170
return await _invoke_tool (
110
171
self ._url , self ._session , self ._name , kwargs , self ._auth_tokens
111
172
)
@@ -154,42 +215,66 @@ def __create_copy(
154
215
self ,
155
216
* ,
156
217
auth_tokens : dict [str , Callable [[], str ]] = {},
218
+ bound_params : dict [str , Union [Any , Callable [[], Any ]]] = {},
219
+ strict : bool ,
157
220
) -> Self :
158
221
"""
159
222
Creates a deep copy of the current ToolboxTool instance, allowing for
160
- modification of auth tokens.
223
+ modification of auth tokens and bound params .
161
224
162
225
This method enables the creation of new tool instances with inherited
163
226
properties from the current instance, while optionally updating the auth
164
- tokens. This is useful for creating variations of the tool with
165
- additional auth tokens without modifying the original instance, ensuring
166
- immutability.
227
+ tokens and bound params . This is useful for creating variations of the
228
+ tool with additional auth tokens or bound params without modifying the
229
+ original instance, ensuring immutability.
167
230
168
231
Args:
169
232
auth_tokens: A dictionary of auth source names to functions that
170
233
retrieve ID tokens. These tokens will be merged with the
171
234
existing auth tokens.
235
+ bound_params: A dictionary of parameter names to their
236
+ bound values or functions to retrieve the values. These params
237
+ will be merged with the existing bound params.
238
+ strict: If True, raises a ValueError if any of the given bound
239
+ parameters are missing from the schema or require
240
+ authentication. If False, only issues a warning.
172
241
173
242
Returns:
174
243
A new ToolboxTool instance that is a deep copy of the current
175
- instance, with optionally updated auth tokens.
244
+ instance, with added auth tokens or bound params .
176
245
"""
246
+ new_schema = deepcopy (self ._schema )
247
+
248
+ # Reconstruct the complete parameter schema by merging the auth
249
+ # parameters back with the non-auth parameters. This is necessary to
250
+ # accurately validate the new combination of auth tokens and bound
251
+ # params in the constructor of the new ToolboxTool instance, ensuring
252
+ # that any overlaps or conflicts are correctly identified and reported
253
+ # as errors or warnings, depending on the given `strict` flag.
254
+ new_schema .parameters += self ._auth_params
177
255
return type (self )(
178
256
name = self ._name ,
179
- schema = deepcopy ( self . _schema ) ,
257
+ schema = new_schema ,
180
258
url = self ._url ,
181
259
session = self ._session ,
182
260
auth_tokens = {** self ._auth_tokens , ** auth_tokens },
261
+ bound_params = {** self ._bound_params , ** bound_params },
262
+ strict = strict ,
183
263
)
184
264
185
- def add_auth_tokens (self , auth_tokens : dict [str , Callable [[], str ]]) -> Self :
265
+ def add_auth_tokens (
266
+ self , auth_tokens : dict [str , Callable [[], str ]], strict : bool = True
267
+ ) -> Self :
186
268
"""
187
269
Registers functions to retrieve ID tokens for the corresponding
188
270
authentication sources.
189
271
190
272
Args:
191
273
auth_tokens: A dictionary of authentication source names to the
192
274
functions that return corresponding ID token.
275
+ strict: If True, a ValueError is raised if any of the provided auth
276
+ tokens are already registered, or are already bound. If False,
277
+ only a warning is issued.
193
278
194
279
Returns:
195
280
A new ToolboxTool instance that is a deep copy of the current
@@ -207,19 +292,82 @@ def add_auth_tokens(self, auth_tokens: dict[str, Callable[[], str]]) -> Self:
207
292
f"Authentication source(s) `{ ', ' .join (dupe_tokens )} ` already registered in tool `{ self ._name } `."
208
293
)
209
294
210
- return self .__create_copy (auth_tokens = auth_tokens )
295
+ return self .__create_copy (auth_tokens = auth_tokens , strict = strict )
211
296
212
- def add_auth_token (self , auth_source : str , get_id_token : Callable [[], str ]) -> Self :
297
+ def add_auth_token (
298
+ self , auth_source : str , get_id_token : Callable [[], str ], strict : bool = True
299
+ ) -> Self :
213
300
"""
214
301
Registers a function to retrieve an ID token for a given authentication
215
302
source.
216
303
217
304
Args:
218
305
auth_source: The name of the authentication source.
219
306
get_id_token: A function that returns the ID token.
307
+ strict: If True, a ValueError is raised if any of the provided auth
308
+ tokens are already registered, or are already bound. If False,
309
+ only a warning is issued.
220
310
221
311
Returns:
222
312
A new ToolboxTool instance that is a deep copy of the current
223
313
instance, with added auth tokens.
224
314
"""
225
- return self .add_auth_tokens ({auth_source : get_id_token })
315
+ return self .add_auth_tokens ({auth_source : get_id_token }, strict = strict )
316
+
317
+ def bind_params (
318
+ self ,
319
+ bound_params : dict [str , Union [Any , Callable [[], Any ]]],
320
+ strict : bool = True ,
321
+ ) -> Self :
322
+ """
323
+ Registers values or functions to retrieve the value for the
324
+ corresponding bound parameters.
325
+
326
+ Args:
327
+ bound_params: A dictionary of the bound parameter name to the
328
+ value or function of the bound value.
329
+ strict: If True, a ValueError is raised if any of the provided bound
330
+ params are already bound, not defined in the tool's schema, or
331
+ require authentication. If False, only a warning is issued.
332
+
333
+ Returns:
334
+ A new ToolboxTool instance that is a deep copy of the current
335
+ instance, with added bound params.
336
+ """
337
+
338
+ # Check if the parameter is already bound.
339
+ dupe_params : list [str ] = []
340
+ for param_name , _ in bound_params .items ():
341
+ if param_name in self ._bound_params :
342
+ dupe_params .append (param_name )
343
+
344
+ if dupe_params :
345
+ raise ValueError (
346
+ f"Parameter(s) `{ ', ' .join (dupe_params )} ` already bound in tool `{ self ._name } `."
347
+ )
348
+
349
+ return self .__create_copy (bound_params = bound_params , strict = strict )
350
+
351
+ def bind_param (
352
+ self ,
353
+ param_name : str ,
354
+ param_value : Union [Any , Callable [[], Any ]],
355
+ strict : bool = True ,
356
+ ) -> Self :
357
+ """
358
+ Registers a value or a function to retrieve the value for a given
359
+ bound parameter.
360
+
361
+ Args:
362
+ param_name: The name of the bound parameter.
363
+ param_value: The value of the bound parameter, or a callable
364
+ that returns the value.
365
+ strict: If True, a ValueError is raised if any of the provided bound
366
+ params are already bound, not defined in the tool's schema, or
367
+ require authentication. If False, only a warning is issued.
368
+
369
+ Returns:
370
+ A new ToolboxTool instance that is a deep copy of the current
371
+ instance, with added bound params.
372
+ """
373
+ return self .bind_params ({param_name : param_value }, strict )
0 commit comments