Skip to content

Commit b28bdb5

Browse files
authored
feat(langchain-sdk): Support authentication in LangChain Toolbox SDK. (#133)
This PR adds support for authentication in the `ToolboxClient`. Note: This is a resubmission of #117
1 parent 1cf78bf commit b28bdb5

File tree

4 files changed

+821
-14
lines changed

4 files changed

+821
-14
lines changed

src/toolbox_langchain_sdk/client.py

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
2-
from typing import Any, Optional, Type
2+
import warnings
3+
from typing import Any, Callable, Optional, Type
34

45
from aiohttp import ClientSession
56
from langchain_core.tools import StructuredTool
@@ -20,6 +21,8 @@ def __init__(self, url: str, session: Optional[ClientSession] = None):
2021
"""
2122
self._url: str = url
2223
self._should_close_session: bool = session is None
24+
self._id_token_getters: dict[str, Callable[[], str]] = {}
25+
self._tool_param_auth: dict[str, dict[str, list[str]]] = {}
2326
self._session: ClientSession = session or ClientSession()
2427

2528
async def close(self) -> None:
@@ -77,6 +80,35 @@ async def _load_toolset_manifest(
7780
url = f"{self._url}/api/toolset/{toolset_name or ''}"
7881
return await _load_yaml(url, self._session)
7982

83+
def _validate_auth(self, tool_name: str) -> bool:
84+
"""
85+
Helper method that validates the authentication requirements of the tool
86+
with the given tool_name. We consider the validation to pass if at least
87+
one auth sources of each of the auth parameters, of the given tool, is
88+
registered.
89+
90+
Args:
91+
tool_name: Name of the tool to validate auth sources for.
92+
93+
Returns:
94+
True if at least one permitted auth source of each of the auth
95+
params, of the given tool, is registered. Also returns True if the
96+
given tool does not require any auth sources.
97+
"""
98+
99+
if tool_name not in self._tool_param_auth:
100+
return True
101+
102+
for permitted_auth_sources in self._tool_param_auth[tool_name].values():
103+
found_match = False
104+
for registered_auth_source in self._id_token_getters:
105+
if registered_auth_source in permitted_auth_sources:
106+
found_match = True
107+
break
108+
if not found_match:
109+
return False
110+
return True
111+
80112
def _generate_tool(
81113
self, tool_name: str, manifest: ManifestSchema
82114
) -> StructuredTool:
@@ -96,8 +128,16 @@ def _generate_tool(
96128
model_name=tool_name, schema=tool_schema.parameters
97129
)
98130

131+
# If the tool had parameters that require authentication, then right
132+
# before invoking that tool, we validate whether all these required
133+
# authentication sources have been registered or not.
99134
async def _tool_func(**kwargs: Any) -> dict:
100-
return await _invoke_tool(self._url, self._session, tool_name, kwargs)
135+
if not self._validate_auth(tool_name):
136+
raise PermissionError(f"Login required before invoking {tool_name}.")
137+
138+
return await _invoke_tool(
139+
self._url, self._session, tool_name, kwargs, self._id_token_getters
140+
)
101141

102142
return StructuredTool.from_function(
103143
coroutine=_tool_func,
@@ -106,21 +146,89 @@ async def _tool_func(**kwargs: Any) -> dict:
106146
args_schema=tool_model,
107147
)
108148

109-
async def load_tool(self, tool_name: str) -> StructuredTool:
149+
def _process_auth_params(self, manifest: ManifestSchema) -> None:
150+
"""
151+
Extracts parameters requiring authentication from the manifest.
152+
Verifies each parameter has at least one valid auth source.
153+
154+
Args:
155+
manifest: The manifest to validate and modify.
156+
157+
Warns:
158+
UserWarning: If a parameter in the manifest has no valid sources.
159+
"""
160+
for tool_name, tool_schema in manifest.tools.items():
161+
non_auth_params = []
162+
for param in tool_schema.parameters:
163+
164+
# Extract auth params from the tool schema.
165+
#
166+
# These parameters are removed from the manifest to prevent data
167+
# validation errors since their values are inferred by the
168+
# Toolbox service, not provided by the user.
169+
#
170+
# Store the permitted authentication sources for each parameter
171+
# in '_tool_param_auth' for efficient validation in
172+
# '_validate_auth'.
173+
if not param.authSources:
174+
non_auth_params.append(param)
175+
continue
176+
177+
self._tool_param_auth.setdefault(tool_name, {})[
178+
param.name
179+
] = param.authSources
180+
181+
tool_schema.parameters = non_auth_params
182+
183+
# If none of the permitted auth sources of a parameter are
184+
# registered, raise a warning message to the user.
185+
if not self._validate_auth(tool_name):
186+
warnings.warn(
187+
f"Some parameters of tool {tool_name} require authentication, but no valid auth sources are registered. Please register the required sources before use."
188+
)
189+
190+
def add_auth_header(
191+
self, auth_source: str, get_id_token: Callable[[], str]
192+
) -> None:
193+
"""
194+
Registers a function to retrieve an ID token for a given authentication
195+
source.
196+
197+
Args:
198+
auth_source : The name of the authentication source.
199+
get_id_token: A function that returns the ID token.
200+
"""
201+
self._id_token_getters[auth_source] = get_id_token
202+
203+
async def load_tool(
204+
self, tool_name: str, auth_headers: dict[str, Callable[[], str]] = {}
205+
) -> StructuredTool:
110206
"""
111207
Loads the tool, with the given tool name, from the Toolbox service.
112208
113209
Args:
114210
tool_name: The name of the tool to load.
211+
auth_headers: A mapping of authentication source names to
212+
functions that retrieve ID tokens. If provided, these will
213+
override or be added to the existing ID token getters.
214+
Default: Empty.
115215
116216
Returns:
117217
A tool loaded from the Toolbox
118218
"""
219+
for auth_source, get_id_token in auth_headers.items():
220+
self.add_auth_header(auth_source, get_id_token)
221+
119222
manifest: ManifestSchema = await self._load_tool_manifest(tool_name)
223+
224+
self._process_auth_params(manifest)
225+
120226
return self._generate_tool(tool_name, manifest)
121227

122228
async def load_toolset(
123-
self, toolset_name: Optional[str] = None
229+
self,
230+
toolset_name: Optional[str] = None,
231+
auth_headers: dict[str, Callable[[], str]] = {},
124232
) -> list[StructuredTool]:
125233
"""
126234
Loads tools from the Toolbox service, optionally filtered by toolset
@@ -129,12 +237,22 @@ async def load_toolset(
129237
Args:
130238
toolset_name: The name of the toolset to load.
131239
Default: None. If not provided, then all the tools are loaded.
240+
auth_headers: A mapping of authentication source names to
241+
functions that retrieve ID tokens. If provided, these will
242+
override or be added to the existing ID token getters.
243+
Default: Empty.
132244
133245
Returns:
134246
A list of all tools loaded from the Toolbox.
135247
"""
248+
for auth_source, get_id_token in auth_headers.items():
249+
self.add_auth_header(auth_source, get_id_token)
250+
136251
tools: list[StructuredTool] = []
137252
manifest: ManifestSchema = await self._load_toolset_manifest(toolset_name)
253+
254+
self._process_auth_params(manifest)
255+
138256
for tool_name in manifest.tools:
139257
tools.append(self._generate_tool(tool_name, manifest))
140258
return tools

src/toolbox_langchain_sdk/utils.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Optional, Type, cast
1+
import warnings
2+
from typing import Any, Callable, Optional, Type, cast
23

34
import yaml
45
from aiohttp import ClientSession
@@ -9,6 +10,7 @@ class ParameterSchema(BaseModel):
910
name: str
1011
type: str
1112
description: str
13+
authSources: Optional[list[str]] = None
1214

1315

1416
class ToolSchema(BaseModel):
@@ -34,8 +36,14 @@ async def _load_yaml(url: str, session: ClientSession) -> ManifestSchema:
3436
"""
3537
async with session.get(url) as response:
3638
response.raise_for_status()
37-
parsed_yaml = yaml.safe_load(await response.text())
38-
return ManifestSchema(**parsed_yaml)
39+
try:
40+
parsed_yaml = yaml.safe_load(await response.text())
41+
except yaml.YAMLError as e:
42+
raise yaml.YAMLError(f"Failed to parse YAML from {url}: {e}") from e
43+
try:
44+
return ManifestSchema(**parsed_yaml)
45+
except ValueError as e:
46+
raise ValueError(f"Invalid YAML data from {url}: {e}") from e
3947

4048

4149
def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[BaseModel]:
@@ -54,7 +62,8 @@ def _schema_to_model(model_name: str, schema: list[ParameterSchema]) -> Type[Bas
5462
field_definitions[field.name] = cast(
5563
Any,
5664
(
57-
# TODO: Remove the hardcoded optional types once optional fields are supported by Toolbox.
65+
# TODO: Remove the hardcoded optional types once optional fields
66+
# are supported by Toolbox.
5867
Optional[_parse_type(field.type)],
5968
Field(description=field.description),
6069
),
@@ -88,8 +97,30 @@ def _parse_type(type_: str) -> Any:
8897
raise ValueError(f"Unsupported schema type: {type_}")
8998

9099

100+
def _get_auth_headers(id_token_getters: dict[str, Callable[[], str]]) -> dict[str, str]:
101+
"""
102+
Gets id tokens for the given auth sources in the getters map and returns
103+
headers to be included in tool invocation.
104+
105+
Args:
106+
id_token_getters: A dict that maps auth source names to the functions
107+
that return its ID token.
108+
109+
Returns:
110+
A dictionary of headers to be included in the tool invocation.
111+
"""
112+
auth_headers = {}
113+
for auth_source, get_id_token in id_token_getters.items():
114+
auth_headers[f"{auth_source}_token"] = get_id_token()
115+
return auth_headers
116+
117+
91118
async def _invoke_tool(
92-
url: str, session: ClientSession, tool_name: str, data: dict
119+
url: str,
120+
session: ClientSession,
121+
tool_name: str,
122+
data: dict,
123+
id_token_getters: dict[str, Callable[[], str]],
93124
) -> dict:
94125
"""
95126
Asynchronously makes an API call to the Toolbox service to invoke a tool.
@@ -99,12 +130,29 @@ async def _invoke_tool(
99130
session: The HTTP client session.
100131
tool_name: The name of the tool to invoke.
101132
data: The input data for the tool.
133+
id_token_getters: A dict that maps auth source names to the functions
134+
that return its ID token.
102135
103136
Returns:
104-
A dictionary containing the parsed JSON response from the tool invocation.
137+
A dictionary containing the parsed JSON response from the tool
138+
invocation.
105139
"""
106140
url = f"{url}/api/tool/{tool_name}/invoke"
107-
async with session.post(url, json=_convert_none_to_empty_string(data)) as response:
141+
auth_headers = _get_auth_headers(id_token_getters)
142+
143+
# ID tokens contain sensitive user information (claims). Transmitting these
144+
# over HTTP exposes the data to interception and unauthorized access. Always
145+
# use HTTPS to ensure secure communication and protect user privacy.
146+
if auth_headers and not url.startswith("https://"):
147+
warnings.warn(
148+
"Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication."
149+
)
150+
151+
async with session.post(
152+
url,
153+
json=_convert_none_to_empty_string(data),
154+
headers=auth_headers,
155+
) as response:
108156
response.raise_for_status()
109157
return await response.json()
110158

0 commit comments

Comments
 (0)