Skip to content

Commit dedc60a

Browse files
committed
feat!: Add the Base SDK support to the Toolbox SDK
This is done by updating the underlying `AsyncToolboxTool` class to return a function-like tool instead of a specific LangChain orchestration's `BaseTool` type.
1 parent f4935c0 commit dedc60a

File tree

3 files changed

+75
-41
lines changed

3 files changed

+75
-41
lines changed

src/toolbox_langchain/async_tools.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,33 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
1516
from copy import deepcopy
16-
from typing import Any, Callable, TypeVar, Union
17+
from typing import Any, Callable, Type, Union
1718
from warnings import warn
1819

1920
from aiohttp import ClientSession
20-
from langchain_core.tools import BaseTool
21+
from pydantic import BaseModel
2122

2223
from .utils import (
24+
ParameterSchema,
2325
ToolSchema,
2426
_find_auth_params,
2527
_find_bound_params,
2628
_invoke_tool,
29+
_parse_type,
30+
_schema_to_docstring,
2731
_schema_to_model,
2832
)
2933

30-
T = TypeVar("T")
31-
3234

3335
# This class is an internal implementation detail and is not exposed to the
3436
# end-user. It should not be used directly by external code. Changes to this
3537
# class will not be considered breaking changes to the public API.
36-
class AsyncToolboxTool(BaseTool):
38+
class AsyncToolboxTool:
3739
"""
38-
A subclass of LangChain's BaseTool that supports features specific to
39-
Toolbox, like bound parameters and authenticated tools.
40+
A class that supports features specific to Toolbox, like bound parameters
41+
and authenticated tools.
4042
"""
4143

4244
def __init__(
@@ -110,51 +112,70 @@ def __init__(
110112

111113
# Bind values for parameters present in the schema that don't require
112114
# authentication.
113-
bound_params = {
115+
__bound_params = {
114116
param_name: param_value
115117
for param_name, param_value in bound_params.items()
116118
if param_name in [param.name for param in non_auth_bound_params]
117119
}
118120

119121
# Update the tools schema to validate only the presence of parameters
120122
# that neither require authentication nor are bound.
121-
schema.parameters = non_auth_non_bound_params
122-
123-
# Due to how pydantic works, we must initialize the underlying
124-
# BaseTool class before assigning values to member variables.
125-
super().__init__(
126-
name=name,
127-
description=schema.description,
128-
args_schema=_schema_to_model(model_name=name, schema=schema.parameters),
129-
)
123+
__updated_schema = deepcopy(schema)
124+
__updated_schema.parameters = non_auth_non_bound_params
130125

131126
self.__name = name
132-
self.__schema = schema
127+
self.__schema = __updated_schema
128+
self.__model = _schema_to_model(self.__name, self.__schema.parameters)
133129
self.__url = url
134130
self.__session = session
135131
self.__auth_tokens = auth_tokens
136132
self.__auth_params = auth_params
137-
self.__bound_params = bound_params
133+
self.__bound_params = __bound_params
138134

139135
# Warn users about any missing authentication so they can add it before
140136
# tool invocation.
141137
self.__validate_auth(strict=False)
142138

143-
def _run(self, **kwargs: Any) -> dict[str, Any]:
144-
raise NotImplementedError("Synchronous methods not supported by async tools.")
139+
# Store parameter definitions for the function signature and annotations
140+
sig_params = []
141+
annotations = {}
142+
for param in self.__schema.parameters:
143+
param_type = _parse_type(param)
144+
annotations[param.name] = param_type
145+
sig_params.append(
146+
inspect.Parameter(
147+
param.name,
148+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
149+
annotation=param_type,
150+
)
151+
)
152+
153+
# Set function name, docstring, signature and annotations
154+
self.__name__ = self.__name
155+
self.__qualname__ = self.__name
156+
self.__doc__ = _schema_to_docstring(self.__schema)
157+
self.__signature__ = inspect.Signature(
158+
parameters=sig_params, return_annotation=dict[str, Any]
159+
)
160+
self.__annotations__ = annotations
145161

146-
async def _arun(self, **kwargs: Any) -> dict[str, Any]:
162+
async def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
147163
"""
148164
The coroutine that invokes the tool with the given arguments.
149165
150166
Args:
151-
**kwargs: The arguments to the tool.
167+
**args: The positional arguments to the tool.
168+
**kwargs: The keyword arguments to the tool.
152169
153170
Returns:
154171
A dictionary containing the parsed JSON response from the tool
155172
invocation.
156173
"""
157174

175+
# Validate arguments
176+
validated_args = self.__signature__.bind(*args, **kwargs).arguments
177+
self.__model.model_validate(validated_args)
178+
158179
# If the tool had parameters that require authentication, then right
159180
# before invoking that tool, we check whether all these required
160181
# authentication sources have been registered or not.
@@ -169,10 +190,10 @@ async def _arun(self, **kwargs: Any) -> dict[str, Any]:
169190
evaluated_params[param_name] = param_value
170191

171192
# Merge bound parameters with the provided arguments
172-
kwargs.update(evaluated_params)
193+
validated_args.update(evaluated_params)
173194

174195
return await _invoke_tool(
175-
self.__url, self.__session, self.__name, kwargs, self.__auth_tokens
196+
self.__url, self.__session, self.__name, validated_args, self.__auth_tokens
176197
)
177198

178199
def __validate_auth(self, strict: bool = True) -> None:

src/toolbox_langchain/tools.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@
1717
from threading import Thread
1818
from typing import Any, Awaitable, Callable, TypeVar, Union
1919

20-
from langchain_core.tools import BaseTool
21-
2220
from .async_tools import AsyncToolboxTool
2321

2422
T = TypeVar("T")
2523

2624

27-
class ToolboxTool(BaseTool):
25+
class ToolboxTool:
2826
"""
2927
A subclass of LangChain's BaseTool that supports features specific to
3028
Toolbox, like bound parameters and authenticated tools.
@@ -45,14 +43,6 @@ def __init__(
4543
thread: The thread to run blocking operations in.
4644
"""
4745

48-
# Due to how pydantic works, we must initialize the underlying
49-
# BaseTool class before assigning values to member variables.
50-
super().__init__(
51-
name=async_tool.name,
52-
description=async_tool.description,
53-
args_schema=async_tool.args_schema,
54-
)
55-
5646
self.__async_tool = async_tool
5747
self.__loop = loop
5848
self.__thread = thread
@@ -77,11 +67,8 @@ async def __run_as_async(self, coro: Awaitable[T]) -> T:
7767
asyncio.run_coroutine_threadsafe(coro, self.__loop)
7868
)
7969

80-
def _run(self, **kwargs: Any) -> dict[str, Any]:
81-
return self.__run_as_sync(self.__async_tool._arun(**kwargs))
82-
83-
async def _arun(self, **kwargs: Any) -> dict[str, Any]:
84-
return await self.__run_as_async(self.__async_tool._arun(**kwargs))
70+
def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
71+
return self.__run_as_sync(self.__async_tool(*args, **kwargs))
8572

8673
def add_auth_tokens(
8774
self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True

src/toolbox_langchain/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,29 @@ def _find_bound_params(
266266
_non_bound_params.append(param)
267267

268268
return (_bound_params, _non_bound_params)
269+
270+
271+
def _schema_to_docstring(tool_schema: ToolSchema) -> str:
272+
"""Generates a Google Style docstring from a ToolSchema object.
273+
274+
If the schema has parameters, the docstring includes an 'Args:' section
275+
detailing each parameter's name, type, and description. If no parameters are
276+
present, only the tool's description is returned.
277+
278+
Args:
279+
tool_schema: The schema object defining the tool's interface,
280+
including its description and parameters.
281+
282+
Returns:
283+
str: A Google Style formatted docstring.
284+
"""
285+
286+
if not tool_schema.parameters:
287+
return tool_schema.description
288+
289+
docstring = f"{tool_schema.description}\n\nArgs:"
290+
for param in tool_schema.parameters:
291+
docstring += (
292+
f"\n {param.name} ({_parse_type(param).__name__}): {param.description}"
293+
)
294+
return docstring

0 commit comments

Comments
 (0)