Skip to content

Commit d4a24e6

Browse files
authored
feat(langchain-sdk): Add Toolbox SDK for LangChain (#22)
Adds a initial python SDK for interacting with Toolbox from LangChain.
0 parents  commit d4a24e6

File tree

5 files changed

+1224
-0
lines changed

5 files changed

+1224
-0
lines changed

pyproject.toml

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
[project]
2+
name = "toolbox_langchain_sdk"
3+
version="0.0.1"
4+
description = "Python SDK for interacting with the Toolbox service with LangChain"
5+
license = {file = "LICENSE"}
6+
requires-python = ">=3.9"
7+
authors = [
8+
{name = "Google LLC", email = "[email protected]"}
9+
]
10+
dependencies = [
11+
"aiohttp",
12+
"PyYAML",
13+
"langchain-core",
14+
"pydantic",
15+
]
16+
17+
classifiers = [
18+
"Intended Audience :: Developers",
19+
"License :: OSI Approved :: Apache Software License",
20+
"Programming Language :: Python",
21+
"Programming Language :: Python :: 3",
22+
"Programming Language :: Python :: 3.9",
23+
"Programming Language :: Python :: 3.10",
24+
"Programming Language :: Python :: 3.11",
25+
"Programming Language :: Python :: 3.12",
26+
]
27+
28+
[project.urls]
29+
Homepage = "https://github.com/googleapis/genai-toolbox"
30+
Repository = "https://github.com/googleapis/genai-toolbox.git"
31+
"Bug Tracker" = "https://github.com/googleapis/genai-toolbox/issues"
32+
33+
[project.optional-dependencies]
34+
test = [
35+
"black[jupyter]",
36+
"isort",
37+
"mypy",
38+
"pytest-asyncio",
39+
"pytest",
40+
"pytest-cov",
41+
"Pillow"
42+
]
43+
44+
[build-system]
45+
requires = ["setuptools"]
46+
build-backend = "setuptools.build_meta"
47+
48+
[tool.black]
49+
target-version = ['py39']
50+
51+
[tool.isort]
52+
profile = "black"
53+
54+
[tool.mypy]
55+
python_version = "3.9"
56+
warn_unused_configs = true
57+
disallow_incomplete_defs = true

src/toolbox_langchain_sdk/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .client import ToolboxClient
2+
3+
__all__ = ["ToolboxClient"]

src/toolbox_langchain_sdk/client.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from typing import Optional
2+
3+
from aiohttp import ClientSession
4+
from langchain_core.tools import StructuredTool
5+
from pydantic import BaseModel
6+
7+
from .utils import ManifestSchema, _invoke_tool, _load_yaml, _schema_to_model
8+
9+
10+
class ToolboxClient:
11+
def __init__(self, url: str, session: ClientSession):
12+
"""
13+
Initializes the ToolboxClient for the Toolbox service at the given URL.
14+
15+
Args:
16+
url: The base URL of the Toolbox service.
17+
session: The HTTP client session.
18+
"""
19+
self._url: str = url
20+
self._session = session
21+
22+
async def _load_tool_manifest(self, tool_name: str) -> ManifestSchema:
23+
"""
24+
Fetches and parses the YAML manifest for the given tool from the Toolbox service.
25+
26+
Args:
27+
tool_name: The name of the tool to load.
28+
29+
Returns:
30+
The parsed Toolbox manifest.
31+
"""
32+
url = f"{self._url}/api/tool/{tool_name}"
33+
return await _load_yaml(url, self._session)
34+
35+
async def _load_toolset_manifest(
36+
self, toolset_name: Optional[str] = None
37+
) -> ManifestSchema:
38+
"""
39+
Fetches and parses the YAML manifest from the Toolbox service.
40+
41+
Args:
42+
toolset_name: The name of the toolset to load.
43+
Default: None. If not provided, then all the available tools are loaded.
44+
45+
Returns:
46+
The parsed Toolbox manifest.
47+
"""
48+
url = f"{self._url}/api/toolset/{toolset_name or ''}"
49+
return await _load_yaml(url, self._session)
50+
51+
def _generate_tool(
52+
self, tool_name: str, manifest: ManifestSchema
53+
) -> StructuredTool:
54+
"""
55+
Creates a StructuredTool object and a dynamically generated BaseModel for the given tool.
56+
57+
Args:
58+
tool_name: The name of the tool to generate.
59+
manifest: The parsed Toolbox manifest.
60+
61+
Returns:
62+
The generated tool.
63+
"""
64+
tool_schema = manifest.tools[tool_name]
65+
tool_model: BaseModel = _schema_to_model(
66+
model_name=tool_name, schema=tool_schema.parameters
67+
)
68+
69+
async def _tool_func(**kwargs) -> dict:
70+
return await _invoke_tool(self._url, self._session, tool_name, kwargs)
71+
72+
return StructuredTool.from_function(
73+
coroutine=_tool_func,
74+
name=tool_name,
75+
description=tool_schema.description,
76+
args_schema=tool_model,
77+
)
78+
79+
async def load_tool(self, tool_name: str) -> StructuredTool:
80+
"""
81+
Loads the tool, with the given tool name, from the Toolbox service.
82+
83+
Args:
84+
toolset_name: The name of the toolset to load.
85+
Default: None. If not provided, then all the tools are loaded.
86+
87+
Returns:
88+
A tool loaded from the Toolbox
89+
"""
90+
manifest: ManifestSchema = await self._load_tool_manifest(tool_name)
91+
return self._generate_tool(tool_name, manifest)
92+
93+
async def load_toolset(
94+
self, toolset_name: Optional[str] = None
95+
) -> list[StructuredTool]:
96+
"""
97+
Loads tools from the Toolbox service, optionally filtered by toolset name.
98+
99+
Args:
100+
toolset_name: The name of the toolset to load.
101+
Default: None. If not provided, then all the tools are loaded.
102+
103+
Returns:
104+
A list of all tools loaded from the Toolbox.
105+
"""
106+
tools: list[StructuredTool] = []
107+
manifest: ManifestSchema = await self._load_toolset_manifest(toolset_name)
108+
for tool_name in manifest.tools:
109+
tools.append(self._generate_tool(tool_name, manifest))
110+
return tools

src/toolbox_langchain_sdk/utils.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from typing import Any, Type, Optional
2+
3+
import yaml
4+
from aiohttp import ClientSession
5+
from pydantic import BaseModel, Field, create_model
6+
7+
8+
class ParameterSchema(BaseModel):
9+
name: str
10+
type: str
11+
description: str
12+
13+
14+
class ToolSchema(BaseModel):
15+
description: str
16+
parameters: list[ParameterSchema]
17+
18+
19+
class ManifestSchema(BaseModel):
20+
serverVersion: str
21+
tools: dict[str, ToolSchema]
22+
23+
24+
async def _load_yaml(url: str, session: ClientSession) -> ManifestSchema:
25+
"""
26+
Asynchronously fetches and parses the YAML data from the given URL.
27+
28+
Args:
29+
url: The base URL to fetch the YAML from.
30+
session: The HTTP client session
31+
32+
Returns:
33+
The parsed Toolbox manifest.
34+
"""
35+
async with session.get(url) as response:
36+
response.raise_for_status()
37+
parsed_yaml = yaml.safe_load(await response.text())
38+
return ManifestSchema(**parsed_yaml)
39+
40+
41+
def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]:
42+
"""
43+
Converts a schema (from the YAML manifest) to a Pydantic BaseModel class.
44+
45+
Args:
46+
model_name: The name of the model to create.
47+
schema: The schema to convert.
48+
49+
Returns:
50+
A Pydantic BaseModel class.
51+
"""
52+
field_definitions = {}
53+
for field in schema:
54+
field_definitions[field.name] = (
55+
# TODO: Remove the hardcoded optional types once optional fields are supported by Toolbox.
56+
Optional[_parse_type(field.type)],
57+
Field(description=field.description),
58+
)
59+
60+
return create_model(model_name, **field_definitions)
61+
62+
63+
def _parse_type(type_: str) -> Any:
64+
"""
65+
Converts a schema type to a JSON type.
66+
67+
Args:
68+
type_: The type name to convert.
69+
70+
Returns:
71+
A valid JSON type.
72+
"""
73+
74+
if type_ == "string":
75+
return str
76+
elif type_ == "integer":
77+
return int
78+
elif type_ == "number":
79+
return float
80+
elif type_ == "boolean":
81+
return bool
82+
elif type_ == "array":
83+
return list
84+
else:
85+
raise ValueError(f"Unsupported schema type: {type_}")
86+
87+
88+
async def _invoke_tool(
89+
url: str, session: ClientSession, tool_name: str, data: dict
90+
) -> dict:
91+
"""
92+
Asynchronously makes an API call to the Toolbox service to invoke a tool.
93+
94+
Args:
95+
url: The base URL of the Toolbox service.
96+
session: The HTTP client session.
97+
tool_name: The name of the tool to invoke.
98+
data: The input data for the tool.
99+
100+
Returns:
101+
A dictionary containing the parsed JSON response from the tool invocation.
102+
"""
103+
url = f"{url}/api/tool/{tool_name}/invoke"
104+
async with session.post(url, json=_convert_none_to_empty_string(data)) as response:
105+
response.raise_for_status()
106+
return await response.json()
107+
108+
109+
# TODO: Remove this temporary fix once optional fields are supported by Toolbox.
110+
def _convert_none_to_empty_string(input_dict):
111+
new_dict = {}
112+
for key, value in input_dict.items():
113+
if value is None:
114+
new_dict[key] = ""
115+
else:
116+
new_dict[key] = value
117+
return new_dict

0 commit comments

Comments
 (0)