Skip to content

Commit 78f392c

Browse files
committed
feat!: Add the Base SDK support to the Toolbox SDK
This is done by updating the underlying `AsyncToolboxTool` class to return a function-like tool instead of a specific LangChain orchestration's `BaseTool` type.
1 parent f4935c0 commit 78f392c

File tree

3 files changed

+84
-49
lines changed

3 files changed

+84
-49
lines changed

src/toolbox_langchain/async_tools.py

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,31 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
1516
from copy import deepcopy
16-
from typing import Any, Callable, TypeVar, Union
17+
from typing import Any, Callable, Union
1718
from warnings import warn
1819

1920
from aiohttp import ClientSession
20-
from langchain_core.tools import BaseTool
2121

2222
from .utils import (
2323
ToolSchema,
2424
_find_auth_params,
2525
_find_bound_params,
2626
_invoke_tool,
27+
_parse_type,
28+
_schema_to_docstring,
2729
_schema_to_model,
2830
)
2931

30-
T = TypeVar("T")
31-
3232

3333
# This class is an internal implementation detail and is not exposed to the
3434
# end-user. It should not be used directly by external code. Changes to this
3535
# class will not be considered breaking changes to the public API.
36-
class AsyncToolboxTool(BaseTool):
36+
class AsyncToolboxTool:
3737
"""
38-
A subclass of LangChain's BaseTool that supports features specific to
39-
Toolbox, like bound parameters and authenticated tools.
38+
A class that supports features specific to Toolbox, like bound parameters
39+
and authenticated tools.
4040
"""
4141

4242
def __init__(
@@ -110,51 +110,69 @@ def __init__(
110110

111111
# Bind values for parameters present in the schema that don't require
112112
# authentication.
113-
bound_params = {
113+
__bound_params = {
114114
param_name: param_value
115115
for param_name, param_value in bound_params.items()
116116
if param_name in [param.name for param in non_auth_bound_params]
117117
}
118118

119119
# Update the tools schema to validate only the presence of parameters
120120
# that neither require authentication nor are bound.
121-
schema.parameters = non_auth_non_bound_params
122-
123-
# Due to how pydantic works, we must initialize the underlying
124-
# BaseTool class before assigning values to member variables.
125-
super().__init__(
126-
name=name,
127-
description=schema.description,
128-
args_schema=_schema_to_model(model_name=name, schema=schema.parameters),
129-
)
121+
__updated_schema = deepcopy(schema)
122+
__updated_schema.parameters = non_auth_non_bound_params
130123

131-
self.__name = name
132-
self.__schema = schema
124+
self.__schema = __updated_schema
125+
self.__model = _schema_to_model(name, self.__schema.parameters)
133126
self.__url = url
134127
self.__session = session
135128
self.__auth_tokens = auth_tokens
136129
self.__auth_params = auth_params
137-
self.__bound_params = bound_params
130+
self.__bound_params = __bound_params
138131

139132
# Warn users about any missing authentication so they can add it before
140133
# tool invocation.
141134
self.__validate_auth(strict=False)
142135

143-
def _run(self, **kwargs: Any) -> dict[str, Any]:
144-
raise NotImplementedError("Synchronous methods not supported by async tools.")
136+
# Store parameter definitions for the function signature and annotations
137+
sig_params = []
138+
annotations = {}
139+
for param in self.__schema.parameters:
140+
param_type = _parse_type(param)
141+
annotations[param.name] = param_type
142+
sig_params.append(
143+
inspect.Parameter(
144+
param.name,
145+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
146+
annotation=param_type,
147+
)
148+
)
149+
150+
# Set function name, docstring, signature and annotations
151+
self.__name__ = name
152+
self.__qualname__ = name
153+
self.__doc__ = _schema_to_docstring(self.__schema)
154+
self.__signature__ = inspect.Signature(
155+
parameters=sig_params, return_annotation=dict[str, Any]
156+
)
157+
self.__annotations__ = annotations
145158

146-
async def _arun(self, **kwargs: Any) -> dict[str, Any]:
159+
async def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
147160
"""
148161
The coroutine that invokes the tool with the given arguments.
149162
150163
Args:
151-
**kwargs: The arguments to the tool.
164+
**args: The positional arguments to the tool.
165+
**kwargs: The keyword arguments to the tool.
152166
153167
Returns:
154168
A dictionary containing the parsed JSON response from the tool
155169
invocation.
156170
"""
157171

172+
# Validate arguments
173+
validated_args = self.__signature__.bind(*args, **kwargs).arguments
174+
self.__model.model_validate(validated_args)
175+
158176
# If the tool had parameters that require authentication, then right
159177
# before invoking that tool, we check whether all these required
160178
# authentication sources have been registered or not.
@@ -169,10 +187,14 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]:
169187
evaluated_params[param_name] = param_value
170188

171189
# Merge bound parameters with the provided arguments
172-
kwargs.update(evaluated_params)
190+
validated_args.update(evaluated_params)
173191

174192
return await _invoke_tool(
175-
self.__url, self.__session, self.__name, kwargs, self.__auth_tokens
193+
self.__url,
194+
self.__session,
195+
self.__name__,
196+
validated_args,
197+
self.__auth_tokens,
176198
)
177199

178200
def __validate_auth(self, strict: bool = True) -> None:
@@ -221,12 +243,12 @@ def __validate_auth(self, strict: bool = True) -> None:
221243

222244
if not is_authenticated:
223245
messages.append(
224-
f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use."
246+
f"Tool {self.__name__} requires authentication, but no valid authentication sources are registered. Please register the required sources before use."
225247
)
226248

227249
if params_missing_auth:
228250
messages.append(
229-
f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use."
251+
f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name__} require authentication, but no valid authentication sources are registered. Please register the required sources before use."
230252
)
231253

232254
if messages:
@@ -277,7 +299,7 @@ def __create_copy(
277299
# as errors or warnings, depending on the given `strict` flag.
278300
new_schema.parameters += self.__auth_params
279301
return AsyncToolboxTool(
280-
name=self.__name,
302+
name=self.__name__,
281303
schema=new_schema,
282304
url=self.__url,
283305
session=self.__session,
@@ -317,7 +339,7 @@ def add_auth_tokens(
317339

318340
if dupe_tokens:
319341
raise ValueError(
320-
f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`."
342+
f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name__}`."
321343
)
322344

323345
return self.__create_copy(auth_tokens=auth_tokens, strict=strict)
@@ -380,7 +402,7 @@ def bind_params(
380402

381403
if dupe_params:
382404
raise ValueError(
383-
f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`."
405+
f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name__}`."
384406
)
385407

386408
return self.__create_copy(bound_params=bound_params, strict=strict)

src/toolbox_langchain/tools.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,15 @@
1717
from threading import Thread
1818
from typing import Any, Awaitable, Callable, TypeVar, Union
1919

20-
from langchain_core.tools import BaseTool
21-
2220
from .async_tools import AsyncToolboxTool
2321

2422
T = TypeVar("T")
2523

2624

27-
class ToolboxTool(BaseTool):
25+
class ToolboxTool:
2826
"""
29-
A subclass of LangChain's BaseTool that supports features specific to
30-
Toolbox, like bound parameters and authenticated tools.
27+
A class that supports features specific to Toolbox, like bound parameters
28+
and authenticated tools.
3129
"""
3230

3331
def __init__(
@@ -45,14 +43,6 @@ def __init__(
4543
thread: The thread to run blocking operations in.
4644
"""
4745

48-
# Due to how pydantic works, we must initialize the underlying
49-
# BaseTool class before assigning values to member variables.
50-
super().__init__(
51-
name=async_tool.name,
52-
description=async_tool.description,
53-
args_schema=async_tool.args_schema,
54-
)
55-
5646
self.__async_tool = async_tool
5747
self.__loop = loop
5848
self.__thread = thread
@@ -77,11 +67,8 @@ async def __run_as_async(self, coro: Awaitable[T]) -> T:
7767
asyncio.run_coroutine_threadsafe(coro, self.__loop)
7868
)
7969

80-
def _run(self, **kwargs: Any) -> dict[str, Any]:
81-
return self.__run_as_sync(self.__async_tool._arun(**kwargs))
82-
83-
async def _arun(self, **kwargs: Any) -> dict[str, Any]:
84-
return await self.__run_as_async(self.__async_tool._arun(**kwargs))
70+
def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
71+
return self.__run_as_sync(self.__async_tool(*args, **kwargs))
8572

8673
def add_auth_tokens(
8774
self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True

src/toolbox_langchain/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,29 @@ def _find_bound_params(
266266
_non_bound_params.append(param)
267267

268268
return (_bound_params, _non_bound_params)
269+
270+
271+
def _schema_to_docstring(tool_schema: ToolSchema) -> str:
272+
"""Generates a Google Style docstring from a ToolSchema object.
273+
274+
If the schema has parameters, the docstring includes an 'Args:' section
275+
detailing each parameter's name, type, and description. If no parameters are
276+
present, only the tool's description is returned.
277+
278+
Args:
279+
tool_schema: The schema object defining the tool's interface,
280+
including its description and parameters.
281+
282+
Returns:
283+
str: A Google Style formatted docstring.
284+
"""
285+
286+
if not tool_schema.parameters:
287+
return tool_schema.description
288+
289+
docstring = f"{tool_schema.description}\n\nArgs:"
290+
for param in tool_schema.parameters:
291+
docstring += (
292+
f"\n {param.name} ({_parse_type(param).__name__}): {param.description}"
293+
)
294+
return docstring

0 commit comments

Comments
 (0)