12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import inspect
15
16
from copy import deepcopy
16
- from typing import Any , Callable , TypeVar , Union
17
+ from typing import Any , Callable , Type , Union
17
18
from warnings import warn
18
19
19
20
from aiohttp import ClientSession
20
- from langchain_core . tools import BaseTool
21
+ from pydantic import BaseModel
21
22
22
23
from .utils import (
24
+ ParameterSchema ,
23
25
ToolSchema ,
24
26
_find_auth_params ,
25
27
_find_bound_params ,
26
28
_invoke_tool ,
29
+ _parse_type ,
30
+ _schema_to_docstring ,
27
31
_schema_to_model ,
28
32
)
29
33
30
- T = TypeVar ("T" )
31
-
32
34
33
35
# This class is an internal implementation detail and is not exposed to the
34
36
# end-user. It should not be used directly by external code. Changes to this
35
37
# class will not be considered breaking changes to the public API.
36
- class AsyncToolboxTool ( BaseTool ) :
38
+ class AsyncToolboxTool :
37
39
"""
38
- A subclass of LangChain's BaseTool that supports features specific to
39
- Toolbox, like bound parameters and authenticated tools.
40
+ A class that supports features specific to Toolbox, like bound parameters
41
+ and authenticated tools.
40
42
"""
41
43
42
44
def __init__ (
@@ -110,51 +112,70 @@ def __init__(
110
112
111
113
# Bind values for parameters present in the schema that don't require
112
114
# authentication.
113
- bound_params = {
115
+ __bound_params = {
114
116
param_name : param_value
115
117
for param_name , param_value in bound_params .items ()
116
118
if param_name in [param .name for param in non_auth_bound_params ]
117
119
}
118
120
119
121
# Update the tools schema to validate only the presence of parameters
120
122
# that neither require authentication nor are bound.
121
- schema .parameters = non_auth_non_bound_params
122
-
123
- # Due to how pydantic works, we must initialize the underlying
124
- # BaseTool class before assigning values to member variables.
125
- super ().__init__ (
126
- name = name ,
127
- description = schema .description ,
128
- args_schema = _schema_to_model (model_name = name , schema = schema .parameters ),
129
- )
123
+ __updated_schema = deepcopy (schema )
124
+ __updated_schema .parameters = non_auth_non_bound_params
130
125
131
126
self .__name = name
132
- self .__schema = schema
127
+ self .__schema = __updated_schema
128
+ self .__model = _schema_to_model (self .__name , self .__schema .parameters )
133
129
self .__url = url
134
130
self .__session = session
135
131
self .__auth_tokens = auth_tokens
136
132
self .__auth_params = auth_params
137
- self .__bound_params = bound_params
133
+ self .__bound_params = __bound_params
138
134
139
135
# Warn users about any missing authentication so they can add it before
140
136
# tool invocation.
141
137
self .__validate_auth (strict = False )
142
138
143
- def _run (self , ** kwargs : Any ) -> dict [str , Any ]:
144
- raise NotImplementedError ("Synchronous methods not supported by async tools." )
139
+ # Store parameter definitions for the function signature and annotations
140
+ sig_params = []
141
+ annotations = {}
142
+ for param in self .__schema .parameters :
143
+ param_type = _parse_type (param )
144
+ annotations [param .name ] = param_type
145
+ sig_params .append (
146
+ inspect .Parameter (
147
+ param .name ,
148
+ inspect .Parameter .POSITIONAL_OR_KEYWORD ,
149
+ annotation = param_type ,
150
+ )
151
+ )
152
+
153
+ # Set function name, docstring, signature and annotations
154
+ self .__name__ = self .__name
155
+ self .__qualname__ = self .__name
156
+ self .__doc__ = _schema_to_docstring (self .__schema )
157
+ self .__signature__ = inspect .Signature (
158
+ parameters = sig_params , return_annotation = dict [str , Any ]
159
+ )
160
+ self .__annotations__ = annotations
145
161
146
- async def _arun (self , ** kwargs : Any ) -> dict [str , Any ]:
162
+ async def __call__ (self , * args : Any , ** kwargs : Any ) -> dict [str , Any ]:
147
163
"""
148
164
The coroutine that invokes the tool with the given arguments.
149
165
150
166
Args:
151
- **kwargs: The arguments to the tool.
167
+ **args: The positional arguments to the tool.
168
+ **kwargs: The keyword arguments to the tool.
152
169
153
170
Returns:
154
171
A dictionary containing the parsed JSON response from the tool
155
172
invocation.
156
173
"""
157
174
175
+ # Validate arguments
176
+ validated_args = self .__signature__ .bind (* args , ** kwargs ).arguments
177
+ self .__model .model_validate (validated_args )
178
+
158
179
# If the tool had parameters that require authentication, then right
159
180
# before invoking that tool, we check whether all these required
160
181
# authentication sources have been registered or not.
@@ -169,10 +190,10 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]:
169
190
evaluated_params [param_name ] = param_value
170
191
171
192
# Merge bound parameters with the provided arguments
172
- kwargs .update (evaluated_params )
193
+ validated_args .update (evaluated_params )
173
194
174
195
return await _invoke_tool (
175
- self .__url , self .__session , self .__name , kwargs , self .__auth_tokens
196
+ self .__url , self .__session , self .__name , validated_args , self .__auth_tokens
176
197
)
177
198
178
199
def __validate_auth (self , strict : bool = True ) -> None :
0 commit comments