Skip to content

Commit 6fe2e39

Browse files
authored
feat(langchain-sdk)!: Migrate existing state and APIs to a tools level class. (#171)
# `ToolboxTool` Class The newly implemented `ToolboxTool` class manages tool state and supports features like bound params and OAuth. We have also added the logic as well as state related to bound params to `ToolboxTool` in #192. `ToolboxTool` follows a functional approach, meaning it ensures that the internal tool state remains unchanged and a new copy of the tool is updated with the new values of auth tokens (or bound params from #192). ## `ToolboxTool` Class Diagram To better explain, here's a visual diagram of the ToolboxTool class. The boxes have member function names, and the "red" box means a function is mutating the internal state, while "green" means it does not change the internal state (like a C++ `const` function or JS immutable function). The red box also shows the member variable they mutated. ![Image of a diagram showing the `ToolboxTool` class with member functions and their colors indicating whether they mutate the internal state](https://github.com/user-attachments/assets/a59eadf2-4aa5-49f0-9b8d-65864e26df07) ## Constructor From the diagram above, it is evident that all the state changes, or rather setting of the state, is done in the class constructor. For further reference, here's a visual explanation of the functionality of the constructor showing which params we add as auth tokens, which params we add to the underlying schema, and for what the error/warnings are thrown. ![Image of a diagram showing the functionality of the `ToolboxTool` constructor](https://github.com/user-attachments/assets/b4e6c0a1-f0da-4cf5-a56c-976c4c4b006e) > [!NOTE] > Documentation updates are done by #193 and the corresponding changes to the LlamaIndex SDK are done by #203.
1 parent 47ef875 commit 6fe2e39

File tree

7 files changed

+818
-1028
lines changed

7 files changed

+818
-1028
lines changed

src/toolbox_langchain_sdk/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
# limitations under the License.
1414

1515
from .client import ToolboxClient
16+
from .tools import ToolboxTool
1617

17-
__all__ = ["ToolboxClient"]
18+
__all__ = ["ToolboxClient", "ToolboxTool"]

src/toolbox_langchain_sdk/client.py

Lines changed: 43 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@
1313
# limitations under the License.
1414

1515
import asyncio
16-
from typing import Any, Callable, Optional, Type
16+
from typing import Any, Callable, Optional, Union
1717
from warnings import warn
1818

1919
from aiohttp import ClientSession
20-
from deprecated import deprecated
21-
from langchain_core.tools import StructuredTool
22-
from pydantic import BaseModel
2320

24-
from .utils import ManifestSchema, _invoke_tool, _load_manifest, _schema_to_model
21+
from .tools import ToolboxTool
22+
from .utils import ManifestSchema, _load_manifest
2523

2624

2725
class ToolboxClient:
@@ -31,18 +29,16 @@ def __init__(self, url: str, session: Optional[ClientSession] = None):
3129
3230
Args:
3331
url: The base URL of the Toolbox service.
34-
session: The HTTP client session.
35-
Default: None
32+
session: An optional HTTP client session. If not provided, a new
33+
session will be created.
3634
"""
3735
self._url: str = url
3836
self._should_close_session: bool = session is None
39-
self._id_token_getters: dict[str, Callable[[], str]] = {}
40-
self._tool_param_auth: dict[str, dict[str, list[str]]] = {}
4137
self._session: ClientSession = session or ClientSession()
4238

4339
async def close(self) -> None:
4440
"""
45-
Close the Toolbox client and its tools.
41+
Closes the HTTP client session if it was created by this client.
4642
"""
4743
# We check whether _should_close_session is set or not since we do not
4844
# want to close the session in case the user had passed their own
@@ -52,14 +48,18 @@ async def close(self) -> None:
5248
await self._session.close()
5349

5450
def __del__(self):
51+
"""
52+
Ensures the HTTP client session is closed when the client is garbage
53+
collected.
54+
"""
5555
try:
5656
loop = asyncio.get_event_loop()
5757
if loop.is_running():
5858
loop.create_task(self.close())
5959
else:
6060
loop.run_until_complete(self.close())
6161
except Exception:
62-
# We "pass" assuming that the exception is thrown because the event
62+
# We "pass" assuming that the exception is thrown because the event
6363
# loop is no longer running, but at that point the Session should
6464
# have been closed already anyway.
6565
pass
@@ -85,158 +85,32 @@ async def _load_toolset_manifest(
8585
Fetches and parses the manifest schema from the Toolbox service.
8686
8787
Args:
88-
toolset_name: The name of the toolset to load.
89-
Default: None. If not provided, then all the available tools are
90-
loaded.
88+
toolset_name: The name of the toolset to load. If not provided,
89+
the manifest for all available tools is loaded.
9190
9291
Returns:
9392
The parsed Toolbox manifest.
9493
"""
9594
url = f"{self._url}/api/toolset/{toolset_name or ''}"
9695
return await _load_manifest(url, self._session)
9796

98-
def _validate_auth(self, tool_name: str) -> bool:
99-
"""
100-
Helper method that validates the authentication requirements of the tool
101-
with the given tool_name. We consider the validation to pass if at least
102-
one auth sources of each of the auth parameters, of the given tool, is
103-
registered.
104-
105-
Args:
106-
tool_name: Name of the tool to validate auth sources for.
107-
108-
Returns:
109-
True if at least one permitted auth source of each of the auth
110-
params, of the given tool, is registered. Also returns True if the
111-
given tool does not require any auth sources.
112-
"""
113-
114-
if tool_name not in self._tool_param_auth:
115-
return True
116-
117-
for permitted_auth_sources in self._tool_param_auth[tool_name].values():
118-
found_match = False
119-
for registered_auth_source in self._id_token_getters:
120-
if registered_auth_source in permitted_auth_sources:
121-
found_match = True
122-
break
123-
if not found_match:
124-
return False
125-
return True
126-
127-
def _generate_tool(
128-
self, tool_name: str, manifest: ManifestSchema
129-
) -> StructuredTool:
130-
"""
131-
Creates a StructuredTool object and a dynamically generated BaseModel
132-
for the given tool.
133-
134-
Args:
135-
tool_name: The name of the tool to generate.
136-
manifest: The parsed Toolbox manifest.
137-
138-
Returns:
139-
The generated tool.
140-
"""
141-
tool_schema = manifest.tools[tool_name]
142-
tool_model: Type[BaseModel] = _schema_to_model(
143-
model_name=tool_name, schema=tool_schema.parameters
144-
)
145-
146-
# If the tool had parameters that require authentication, then right
147-
# before invoking that tool, we validate whether all these required
148-
# authentication sources have been registered or not.
149-
async def _tool_func(**kwargs: Any) -> dict:
150-
if not self._validate_auth(tool_name):
151-
raise PermissionError(f"Login required before invoking {tool_name}.")
152-
153-
return await _invoke_tool(
154-
self._url, self._session, tool_name, kwargs, self._id_token_getters
155-
)
156-
157-
return StructuredTool.from_function(
158-
coroutine=_tool_func,
159-
name=tool_name,
160-
description=tool_schema.description,
161-
args_schema=tool_model,
162-
)
163-
164-
def _process_auth_params(self, manifest: ManifestSchema) -> None:
165-
"""
166-
Extracts parameters requiring authentication from the manifest.
167-
Verifies each parameter has at least one valid auth source.
168-
169-
Args:
170-
manifest: The manifest to validate and modify.
171-
172-
Warns:
173-
UserWarning: If a parameter in the manifest has no valid sources.
174-
"""
175-
for tool_name, tool_schema in manifest.tools.items():
176-
non_auth_params = []
177-
for param in tool_schema.parameters:
178-
179-
# Extract auth params from the tool schema.
180-
#
181-
# These parameters are removed from the manifest to prevent data
182-
# validation errors since their values are inferred by the
183-
# Toolbox service, not provided by the user.
184-
#
185-
# Store the permitted authentication sources for each parameter
186-
# in '_tool_param_auth' for efficient validation in
187-
# '_validate_auth'.
188-
if not param.authSources:
189-
non_auth_params.append(param)
190-
continue
191-
192-
self._tool_param_auth.setdefault(tool_name, {})[
193-
param.name
194-
] = param.authSources
195-
196-
tool_schema.parameters = non_auth_params
197-
198-
# If none of the permitted auth sources of a parameter are
199-
# registered, raise a warning message to the user.
200-
if not self._validate_auth(tool_name):
201-
warn(
202-
f"Some parameters of tool {tool_name} require authentication, but no valid auth sources are registered. Please register the required sources before use."
203-
)
204-
205-
@deprecated("Please use `add_auth_token` instead.")
206-
def add_auth_header(
207-
self, auth_source: str, get_id_token: Callable[[], str]
208-
) -> None:
209-
self.add_auth_token(auth_source, get_id_token)
210-
211-
def add_auth_token(self, auth_source: str, get_id_token: Callable[[], str]) -> None:
212-
"""
213-
Registers a function to retrieve an ID token for a given authentication
214-
source.
215-
216-
Args:
217-
auth_source : The name of the authentication source.
218-
get_id_token: A function that returns the ID token.
219-
"""
220-
self._id_token_getters[auth_source] = get_id_token
221-
22297
async def load_tool(
22398
self,
22499
tool_name: str,
225100
auth_tokens: dict[str, Callable[[], str]] = {},
226101
auth_headers: Optional[dict[str, Callable[[], str]]] = None,
227-
) -> StructuredTool:
102+
) -> ToolboxTool:
228103
"""
229-
Loads the tool, with the given tool name, from the Toolbox service.
104+
Loads the tool with the given tool name from the Toolbox service.
230105
231106
Args:
232107
tool_name: The name of the tool to load.
233-
auth_tokens: A mapping of authentication source names to
234-
functions that retrieve ID tokens. If provided, these will
235-
override or be added to the existing ID token getters.
236-
Default: Empty.
108+
auth_tokens: An optional mapping of authentication source names to
109+
functions that retrieve ID tokens.
110+
auth_headers: Deprecated. Use `auth_tokens` instead.
237111
238112
Returns:
239-
A tool loaded from the Toolbox
113+
A tool loaded from the Toolbox.
240114
"""
241115
if auth_headers:
242116
if auth_tokens:
@@ -251,32 +125,31 @@ async def load_tool(
251125
)
252126
auth_tokens = auth_headers
253127

254-
for auth_source, get_id_token in auth_tokens.items():
255-
self.add_auth_token(auth_source, get_id_token)
256-
257128
manifest: ManifestSchema = await self._load_tool_manifest(tool_name)
258-
259-
self._process_auth_params(manifest)
260-
261-
return self._generate_tool(tool_name, manifest)
129+
return ToolboxTool(
130+
tool_name,
131+
manifest.tools[tool_name],
132+
self._url,
133+
self._session,
134+
auth_tokens,
135+
)
262136

263137
async def load_toolset(
264138
self,
265139
toolset_name: Optional[str] = None,
266140
auth_tokens: dict[str, Callable[[], str]] = {},
267141
auth_headers: Optional[dict[str, Callable[[], str]]] = None,
268-
) -> list[StructuredTool]:
142+
) -> list[ToolboxTool]:
269143
"""
270144
Loads tools from the Toolbox service, optionally filtered by toolset
271145
name.
272146
273147
Args:
274-
toolset_name: The name of the toolset to load.
275-
Default: None. If not provided, then all the tools are loaded.
276-
auth_tokens: A mapping of authentication source names to
277-
functions that retrieve ID tokens. If provided, these will
278-
override or be added to the existing ID token getters.
279-
Default: Empty.
148+
toolset_name: The name of the toolset to load. If not provided,
149+
all tools are loaded.
150+
auth_tokens: An optional mapping of authentication source names to
151+
functions that retrieve ID tokens.
152+
auth_headers: Deprecated. Use `auth_tokens` instead.
280153
281154
Returns:
282155
A list of all tools loaded from the Toolbox.
@@ -294,14 +167,17 @@ async def load_toolset(
294167
)
295168
auth_tokens = auth_headers
296169

297-
for auth_source, get_id_token in auth_tokens.items():
298-
self.add_auth_token(auth_source, get_id_token)
299-
300-
tools: list[StructuredTool] = []
170+
tools: list[ToolboxTool] = []
301171
manifest: ManifestSchema = await self._load_toolset_manifest(toolset_name)
302172

303-
self._process_auth_params(manifest)
304-
305-
for tool_name in manifest.tools:
306-
tools.append(self._generate_tool(tool_name, manifest))
173+
for tool_name, tool_schema in manifest.tools.items():
174+
tools.append(
175+
ToolboxTool(
176+
tool_name,
177+
tool_schema,
178+
self._url,
179+
self._session,
180+
auth_tokens,
181+
)
182+
)
307183
return tools

0 commit comments

Comments
 (0)