Skip to content

Commit bb7e8d3

Browse files
committed
feat!: Add the Base SDK support to the Toolbox SDK
This is done by updating the underlying `AsyncToolboxTool` class to return a function-like tool instead of a specific LangChain orchestration's `BaseTool` type.
1 parent f4935c0 commit bb7e8d3

File tree

3 files changed

+80
-47
lines changed

3 files changed

+80
-47
lines changed

src/toolbox_langchain/async_tools.py

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,33 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
1516
from copy import deepcopy
16-
from typing import Any, Callable, TypeVar, Union
17+
from typing import Any, Callable, Type, Union
1718
from warnings import warn
1819

1920
from aiohttp import ClientSession
20-
from langchain_core.tools import BaseTool
21+
from pydantic import BaseModel
2122

2223
from .utils import (
24+
ParameterSchema,
2325
ToolSchema,
2426
_find_auth_params,
2527
_find_bound_params,
2628
_invoke_tool,
29+
_parse_type,
30+
_schema_to_docstring,
2731
_schema_to_model,
2832
)
2933

30-
T = TypeVar("T")
31-
3234

3335
# This class is an internal implementation detail and is not exposed to the
3436
# end-user. It should not be used directly by external code. Changes to this
3537
# class will not be considered breaking changes to the public API.
36-
class AsyncToolboxTool(BaseTool):
38+
class AsyncToolboxTool:
3739
"""
38-
A subclass of LangChain's BaseTool that supports features specific to
39-
Toolbox, like bound parameters and authenticated tools.
40+
A class that supports features specific to Toolbox, like bound parameters
41+
and authenticated tools.
4042
"""
4143

4244
def __init__(
@@ -110,51 +112,69 @@ def __init__(
110112

111113
# Bind values for parameters present in the schema that don't require
112114
# authentication.
113-
bound_params = {
115+
__bound_params = {
114116
param_name: param_value
115117
for param_name, param_value in bound_params.items()
116118
if param_name in [param.name for param in non_auth_bound_params]
117119
}
118120

119121
# Update the tools schema to validate only the presence of parameters
120122
# that neither require authentication nor are bound.
121-
schema.parameters = non_auth_non_bound_params
122-
123-
# Due to how pydantic works, we must initialize the underlying
124-
# BaseTool class before assigning values to member variables.
125-
super().__init__(
126-
name=name,
127-
description=schema.description,
128-
args_schema=_schema_to_model(model_name=name, schema=schema.parameters),
129-
)
123+
__updated_schema = deepcopy(schema)
124+
__updated_schema.parameters = non_auth_non_bound_params
130125

131-
self.__name = name
132-
self.__schema = schema
126+
self.__schema = __updated_schema
127+
self.__model = _schema_to_model(name, self.__schema.parameters)
133128
self.__url = url
134129
self.__session = session
135130
self.__auth_tokens = auth_tokens
136131
self.__auth_params = auth_params
137-
self.__bound_params = bound_params
132+
self.__bound_params = __bound_params
138133

139134
# Warn users about any missing authentication so they can add it before
140135
# tool invocation.
141136
self.__validate_auth(strict=False)
142137

143-
def _run(self, **kwargs: Any) -> dict[str, Any]:
144-
raise NotImplementedError("Synchronous methods not supported by async tools.")
138+
# Store parameter definitions for the function signature and annotations
139+
sig_params = []
140+
annotations = {}
141+
for param in self.__schema.parameters:
142+
param_type = _parse_type(param)
143+
annotations[param.name] = param_type
144+
sig_params.append(
145+
inspect.Parameter(
146+
param.name,
147+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
148+
annotation=param_type,
149+
)
150+
)
151+
152+
# Set function name, docstring, signature and annotations
153+
self.__name__ = name
154+
self.__qualname__ = name
155+
self.__doc__ = _schema_to_docstring(self.__schema)
156+
self.__signature__ = inspect.Signature(
157+
parameters=sig_params, return_annotation=dict[str, Any]
158+
)
159+
self.__annotations__ = annotations
145160

146-
async def _arun(self, **kwargs: Any) -> dict[str, Any]:
161+
async def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
147162
"""
148163
The coroutine that invokes the tool with the given arguments.
149164
150165
Args:
151-
**kwargs: The arguments to the tool.
166+
**args: The positional arguments to the tool.
167+
**kwargs: The keyword arguments to the tool.
152168
153169
Returns:
154170
A dictionary containing the parsed JSON response from the tool
155171
invocation.
156172
"""
157173

174+
# Validate arguments
175+
validated_args = self.__signature__.bind(*args, **kwargs).arguments
176+
self.__model.model_validate(validated_args)
177+
158178
# If the tool had parameters that require authentication, then right
159179
# before invoking that tool, we check whether all these required
160180
# authentication sources have been registered or not.
@@ -169,10 +189,10 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]:
169189
evaluated_params[param_name] = param_value
170190

171191
# Merge bound parameters with the provided arguments
172-
kwargs.update(evaluated_params)
192+
validated_args.update(evaluated_params)
173193

174194
return await _invoke_tool(
175-
self.__url, self.__session, self.__name, kwargs, self.__auth_tokens
195+
self.__url, self.__session, self.__name__, validated_args, self.__auth_tokens
176196
)
177197

178198
def __validate_auth(self, strict: bool = True) -> None:
@@ -221,12 +241,12 @@ def __validate_auth(self, strict: bool = True) -> None:
221241

222242
if not is_authenticated:
223243
messages.append(
224-
f"Tool {self.__name} requires authentication, but no valid authentication sources are registered. Please register the required sources before use."
244+
f"Tool {self.__name__} requires authentication, but no valid authentication sources are registered. Please register the required sources before use."
225245
)
226246

227247
if params_missing_auth:
228248
messages.append(
229-
f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name} require authentication, but no valid authentication sources are registered. Please register the required sources before use."
249+
f"Parameter(s) `{', '.join(params_missing_auth)}` of tool {self.__name__} require authentication, but no valid authentication sources are registered. Please register the required sources before use."
230250
)
231251

232252
if messages:
@@ -277,7 +297,7 @@ def __create_copy(
277297
# as errors or warnings, depending on the given `strict` flag.
278298
new_schema.parameters += self.__auth_params
279299
return AsyncToolboxTool(
280-
name=self.__name,
300+
name=self.__name__,
281301
schema=new_schema,
282302
url=self.__url,
283303
session=self.__session,
@@ -317,7 +337,7 @@ def add_auth_tokens(
317337

318338
if dupe_tokens:
319339
raise ValueError(
320-
f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name}`."
340+
f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self.__name__}`."
321341
)
322342

323343
return self.__create_copy(auth_tokens=auth_tokens, strict=strict)
@@ -380,7 +400,7 @@ def bind_params(
380400

381401
if dupe_params:
382402
raise ValueError(
383-
f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name}`."
403+
f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self.__name__}`."
384404
)
385405

386406
return self.__create_copy(bound_params=bound_params, strict=strict)

src/toolbox_langchain/tools.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@
1717
from threading import Thread
1818
from typing import Any, Awaitable, Callable, TypeVar, Union
1919

20-
from langchain_core.tools import BaseTool
21-
2220
from .async_tools import AsyncToolboxTool
2321

2422
T = TypeVar("T")
2523

2624

27-
class ToolboxTool(BaseTool):
25+
class ToolboxTool:
2826
"""
2927
A subclass of LangChain's BaseTool that supports features specific to
3028
Toolbox, like bound parameters and authenticated tools.
@@ -45,14 +43,6 @@ def __init__(
4543
thread: The thread to run blocking operations in.
4644
"""
4745

48-
# Due to how pydantic works, we must initialize the underlying
49-
# BaseTool class before assigning values to member variables.
50-
super().__init__(
51-
name=async_tool.name,
52-
description=async_tool.description,
53-
args_schema=async_tool.args_schema,
54-
)
55-
5646
self.__async_tool = async_tool
5747
self.__loop = loop
5848
self.__thread = thread
@@ -77,11 +67,8 @@ async def __run_as_async(self, coro: Awaitable[T]) -> T:
7767
asyncio.run_coroutine_threadsafe(coro, self.__loop)
7868
)
7969

80-
def _run(self, **kwargs: Any) -> dict[str, Any]:
81-
return self.__run_as_sync(self.__async_tool._arun(**kwargs))
82-
83-
async def _arun(self, **kwargs: Any) -> dict[str, Any]:
84-
return await self.__run_as_async(self.__async_tool._arun(**kwargs))
70+
def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
71+
return self.__run_as_sync(self.__async_tool(*args, **kwargs))
8572

8673
def add_auth_tokens(
8774
self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True

src/toolbox_langchain/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,29 @@ def _find_bound_params(
266266
_non_bound_params.append(param)
267267

268268
return (_bound_params, _non_bound_params)
269+
270+
271+
def _schema_to_docstring(tool_schema: ToolSchema) -> str:
272+
"""Generates a Google Style docstring from a ToolSchema object.
273+
274+
If the schema has parameters, the docstring includes an 'Args:' section
275+
detailing each parameter's name, type, and description. If no parameters are
276+
present, only the tool's description is returned.
277+
278+
Args:
279+
tool_schema: The schema object defining the tool's interface,
280+
including its description and parameters.
281+
282+
Returns:
283+
str: A Google Style formatted docstring.
284+
"""
285+
286+
if not tool_schema.parameters:
287+
return tool_schema.description
288+
289+
docstring = f"{tool_schema.description}\n\nArgs:"
290+
for param in tool_schema.parameters:
291+
docstring += (
292+
f"\n {param.name} ({_parse_type(param).__name__}): {param.description}"
293+
)
294+
return docstring

0 commit comments

Comments
 (0)