13
13
# limitations under the License.
14
14
15
15
import asyncio
16
- from asyncio import AbstractEventLoop
17
- from threading import Thread
18
- from typing import Any , Awaitable , Callable , TypeVar , Union
16
+ from typing import Any , Callable , Union
19
17
20
18
from langchain_core .tools import BaseTool
19
+ from toolbox_core .sync_tool import ToolboxSyncTool as ToolboxCoreSyncTool
21
20
22
- from .async_tools import AsyncToolboxTool
23
-
24
- T = TypeVar ("T" )
25
21
26
22
27
23
class ToolboxTool (BaseTool ):
@@ -32,56 +28,37 @@ class ToolboxTool(BaseTool):
32
28
33
29
def __init__ (
34
30
self ,
35
- async_tool : AsyncToolboxTool ,
36
- loop : AbstractEventLoop ,
37
- thread : Thread ,
31
+ core_sync_tool : ToolboxCoreSyncTool ,
38
32
) -> None :
39
33
"""
40
34
Initializes a ToolboxTool instance.
41
35
42
36
Args:
43
- async_tool: The underlying AsyncToolboxTool instance.
44
- loop: The event loop used to run asynchronous tasks.
45
- thread: The thread to run blocking operations in.
37
+ core_sync_tool: The underlying core sync ToolboxTool instance.
46
38
"""
47
39
48
- # Due to how pydantic works, we must initialize the underlying
49
- # BaseTool class before assigning values to member variables.
40
+ self .__core_sync_tool = core_sync_tool
50
41
super ().__init__ (
51
- name = async_tool . name ,
52
- description = async_tool . description ,
53
- args_schema = async_tool . args_schema ,
42
+ name = self . __core_sync_tool . __name__ ,
43
+ description = self . __core_sync_tool . __doc__ ,
44
+ args_schema = self . __core_sync_tool . _ToolboxSyncTool__pydantic_model ,
54
45
)
55
46
56
- self .__async_tool = async_tool
57
- self .__loop = loop
58
- self .__thread = thread
59
-
60
- def __run_as_sync (self , coro : Awaitable [T ]) -> T :
61
- """Run an async coroutine synchronously"""
62
- if not self .__loop :
63
- raise Exception (
64
- "Cannot call synchronous methods before the background loop is initialized."
65
- )
66
- return asyncio .run_coroutine_threadsafe (coro , self .__loop ).result ()
47
+ def _run (self , ** kwargs : Any ) -> dict [str , Any ]:
48
+ return self .__core_sync_tool (** kwargs )
67
49
68
- async def __run_as_async (self , coro : Awaitable [ T ] ) -> T :
69
- """Run an async coroutine asynchronously"""
50
+ async def _arun (self , ** kwargs : Any ) -> dict [ str , Any ] :
51
+ coro = self . __core_sync_tool . _ToolboxSyncTool__async_tool ( ** kwargs )
70
52
71
53
# If a loop has not been provided, attempt to run in current thread.
72
- if not self .__loop :
54
+ if not self .__core_sync_client . _ToolboxSyncClient__loop :
73
55
return await coro
74
56
75
57
# Otherwise, run in the background thread.
76
- return await asyncio .wrap_future (
77
- asyncio .run_coroutine_threadsafe (coro , self .__loop )
58
+ await asyncio .wrap_future (
59
+ asyncio .run_coroutine_threadsafe (coro , self .__core_sync_client . _ToolboxSyncTool__loop )
78
60
)
79
61
80
- def _run (self , ** kwargs : Any ) -> dict [str , Any ]:
81
- return self .__run_as_sync (self .__async_tool ._arun (** kwargs ))
82
-
83
- async def _arun (self , ** kwargs : Any ) -> dict [str , Any ]:
84
- return await self .__run_as_async (self .__async_tool ._arun (** kwargs ))
85
62
86
63
def add_auth_token_getters (
87
64
self , auth_token_getters : dict [str , Callable [[], str ]], strict : bool = True
@@ -93,27 +70,21 @@ def add_auth_token_getters(
93
70
Args:
94
71
auth_token_getters: A dictionary of authentication source names to
95
72
the functions that return corresponding ID token.
96
- strict: If True, a ValueError is raised if any of the provided auth
97
- parameters is already bound. If False, only a warning is issued.
98
73
99
74
Returns:
100
75
A new ToolboxTool instance that is a deep copy of the current
101
- instance, with added auth tokens .
76
+ instance, with added auth token getters .
102
77
103
78
Raises:
104
79
ValueError: If any of the provided auth parameters is already
105
80
registered.
106
- ValueError: If any of the provided auth parameters is already bound
107
- and strict is True.
108
81
"""
109
- return ToolboxTool (
110
- self .__async_tool .add_auth_token_getters (auth_token_getters , strict ),
111
- self .__loop ,
112
- self .__thread ,
113
- )
82
+ new_core_sync_tool = self .__core_sync_tool .add_auth_token_getters (auth_token_getters )
83
+ return ToolboxTool (core_sync_tool = new_core_sync_tool )
84
+
114
85
115
86
def add_auth_token_getter (
116
- self , auth_source : str , get_id_token : Callable [[], str ], strict : bool = True
87
+ self , auth_source : str , get_id_token : Callable [[], str ]
117
88
) -> "ToolboxTool" :
118
89
"""
119
90
Registers a function to retrieve an ID token for a given authentication
@@ -122,28 +93,19 @@ def add_auth_token_getter(
122
93
Args:
123
94
auth_source: The name of the authentication source.
124
95
get_id_token: A function that returns the ID token.
125
- strict: If True, a ValueError is raised if the provided auth
126
- parameter is already bound. If False, only a warning is issued.
127
96
128
97
Returns:
129
98
A new ToolboxTool instance that is a deep copy of the current
130
99
instance, with added auth token.
131
100
132
101
Raises:
133
102
ValueError: If the provided auth parameter is already registered.
134
- ValueError: If the provided auth parameter is already bound and
135
- strict is True.
136
103
"""
137
- return ToolboxTool (
138
- self .__async_tool .add_auth_token_getter (auth_source , get_id_token , strict ),
139
- self .__loop ,
140
- self .__thread ,
141
- )
104
+ return self .add_auth_token_getters ({auth_source : get_id_token })
142
105
143
106
def bind_params (
144
107
self ,
145
108
bound_params : dict [str , Union [Any , Callable [[], Any ]]],
146
- strict : bool = True ,
147
109
) -> "ToolboxTool" :
148
110
"""
149
111
Registers values or functions to retrieve the value for the
@@ -152,25 +114,16 @@ def bind_params(
152
114
Args:
153
115
bound_params: A dictionary of the bound parameter name to the
154
116
value or function of the bound value.
155
- strict: If True, a ValueError is raised if any of the provided bound
156
- params is not defined in the tool's schema, or requires
157
- authentication. If False, only a warning is issued.
158
117
159
118
Returns:
160
119
A new ToolboxTool instance that is a deep copy of the current
161
120
instance, with added bound params.
162
121
163
122
Raises:
164
123
ValueError: If any of the provided bound params is already bound.
165
- ValueError: if any of the provided bound params is not defined in
166
- the tool's schema, or require authentication, and strict is
167
- True.
168
124
"""
169
- return ToolboxTool (
170
- self .__async_tool .bind_params (bound_params , strict ),
171
- self .__loop ,
172
- self .__thread ,
173
- )
125
+ new_core_sync_tool = self .__core_sync_tool .bind_params (bound_params )
126
+ return ToolboxTool (core_sync_tool = new_core_sync_tool )
174
127
175
128
def bind_param (
176
129
self ,
@@ -186,21 +139,12 @@ def bind_param(
186
139
param_name: The name of the bound parameter.
187
140
param_value: The value of the bound parameter, or a callable that
188
141
returns the value.
189
- strict: If True, a ValueError is raised if the provided bound
190
- param is not defined in the tool's schema, or requires
191
- authentication. If False, only a warning is issued.
192
142
193
143
Returns:
194
144
A new ToolboxTool instance that is a deep copy of the current
195
145
instance, with added bound param.
196
146
197
147
Raises:
198
148
ValueError: If the provided bound param is already bound.
199
- ValueError: if the provided bound param is not defined in the tool's
200
- schema, or requires authentication, and strict is True.
201
149
"""
202
- return ToolboxTool (
203
- self .__async_tool .bind_param (param_name , param_value , strict ),
204
- self .__loop ,
205
- self .__thread ,
206
- )
150
+ return self .bind_params ({param_name : param_value })
0 commit comments