Skip to content

Commit 4fcfc35

Browse files
authored
feat(langchain-sdk): Add features for binding parameters to ToolboxTool class. (#192)
The newly implemented `ToolboxTool` class manages tool state and supports this new feature of binding parameters along wiith the existing OAuth. > [!NOTE] > These changes are done in the LlamaIndex SDK as well in #203, along with documentation updates in #193.
1 parent 6fe2e39 commit 4fcfc35

File tree

5 files changed

+372
-14
lines changed

5 files changed

+372
-14
lines changed

src/toolbox_langchain_sdk/client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ async def load_tool(
9999
tool_name: str,
100100
auth_tokens: dict[str, Callable[[], str]] = {},
101101
auth_headers: Optional[dict[str, Callable[[], str]]] = None,
102+
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
103+
strict: bool = True,
102104
) -> ToolboxTool:
103105
"""
104106
Loads the tool with the given tool name from the Toolbox service.
@@ -108,6 +110,11 @@ async def load_tool(
108110
auth_tokens: An optional mapping of authentication source names to
109111
functions that retrieve ID tokens.
110112
auth_headers: Deprecated. Use `auth_tokens` instead.
113+
bound_params: An optional mapping of parameter names to their
114+
bound values.
115+
strict: If True, raises a ValueError if any of the given bound
116+
parameters are missing from the schema or require
117+
authentication. If False, only issues a warning.
111118
112119
Returns:
113120
A tool loaded from the Toolbox.
@@ -132,13 +139,17 @@ async def load_tool(
132139
self._url,
133140
self._session,
134141
auth_tokens,
142+
bound_params,
143+
strict,
135144
)
136145

137146
async def load_toolset(
138147
self,
139148
toolset_name: Optional[str] = None,
140149
auth_tokens: dict[str, Callable[[], str]] = {},
141150
auth_headers: Optional[dict[str, Callable[[], str]]] = None,
151+
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
152+
strict: bool = True,
142153
) -> list[ToolboxTool]:
143154
"""
144155
Loads tools from the Toolbox service, optionally filtered by toolset
@@ -150,6 +161,11 @@ async def load_toolset(
150161
auth_tokens: An optional mapping of authentication source names to
151162
functions that retrieve ID tokens.
152163
auth_headers: Deprecated. Use `auth_tokens` instead.
164+
bound_params: An optional mapping of parameter names to their
165+
bound values.
166+
strict: If True, raises a ValueError if any of the given bound
167+
parameters are missing from the schema or require
168+
authentication. If False, only issues a warning.
153169
154170
Returns:
155171
A list of all tools loaded from the Toolbox.
@@ -178,6 +194,8 @@ async def load_toolset(
178194
self._url,
179195
self._session,
180196
auth_tokens,
197+
bound_params,
198+
strict,
181199
)
182200
)
183201
return tools

src/toolbox_langchain_sdk/tools.py

Lines changed: 162 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from copy import deepcopy
16-
from typing import Any, Callable
16+
from typing import Any, Callable, Union
1717
from warnings import warn
1818

1919
from aiohttp import ClientSession
@@ -24,6 +24,7 @@
2424
ParameterSchema,
2525
ToolSchema,
2626
_find_auth_params,
27+
_find_bound_params,
2728
_invoke_tool,
2829
_schema_to_model,
2930
)
@@ -32,7 +33,7 @@
3233
class ToolboxTool(StructuredTool):
3334
"""
3435
A subclass of LangChain's StructuredTool that supports features specific to
35-
Toolbox, like authenticated tools.
36+
Toolbox, like bound parameters and authenticated tools.
3637
"""
3738

3839
def __init__(
@@ -42,6 +43,8 @@ def __init__(
4243
url: str,
4344
session: ClientSession,
4445
auth_tokens: dict[str, Callable[[], str]] = {},
46+
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
47+
strict: bool = True,
4548
) -> None:
4649
"""
4750
Initializes a ToolboxTool instance.
@@ -53,6 +56,11 @@ def __init__(
5356
session: The HTTP client session.
5457
auth_tokens: A mapping of authentication source names to functions
5558
that retrieve ID tokens.
59+
bound_params: A mapping of parameter names to their bound
60+
values.
61+
strict: If True, raises a ValueError if any of the given bound
62+
parameters are missing from the schema or require
63+
authentication. If False, only issues a warning.
5664
"""
5765

5866
# If the schema is not already a ToolSchema instance, we create one from
@@ -63,10 +71,51 @@ def __init__(
6371
schema = ToolSchema(**schema)
6472

6573
auth_params, non_auth_params = _find_auth_params(schema.parameters)
74+
non_auth_bound_params, non_auth_non_bound_params = _find_bound_params(
75+
non_auth_params, list(bound_params)
76+
)
77+
78+
# Check if the user is trying to bind a param that is authenticated or
79+
# is missing from the given schema.
80+
auth_bound_params: list[str] = []
81+
missing_bound_params: list[str] = []
82+
for bound_param in bound_params:
83+
if bound_param in [param.name for param in auth_params]:
84+
auth_bound_params.append(bound_param)
85+
elif bound_param not in [param.name for param in non_auth_params]:
86+
missing_bound_params.append(bound_param)
87+
88+
# Create error messages for any params that are found to be
89+
# authenticated or missing.
90+
messages: list[str] = []
91+
if auth_bound_params:
92+
messages.append(
93+
f"Parameter(s) {', '.join(auth_bound_params)} already authenticated and cannot be bound."
94+
)
95+
if missing_bound_params:
96+
messages.append(
97+
f"Parameter(s) {', '.join(missing_bound_params)} missing and cannot be bound."
98+
)
99+
100+
# Join any error messages and raise them as an error or warning,
101+
# depending on the value of the strict flag.
102+
if messages:
103+
message = "\n\n".join(messages)
104+
if strict:
105+
raise ValueError(message)
106+
warn(message)
107+
108+
# Bind values for parameters present in the schema that don't require
109+
# authentication.
110+
bound_params = {
111+
param_name: param_value
112+
for param_name, param_value in bound_params.items()
113+
if param_name in [param.name for param in non_auth_bound_params]
114+
}
66115

67116
# Update the tools schema to validate only the presence of parameters
68-
# that do not require authentication.
69-
schema.parameters = non_auth_params
117+
# that neither require authentication nor are bound.
118+
schema.parameters = non_auth_non_bound_params
70119

71120
# Due to how pydantic works, we must initialize the underlying
72121
# StructuredTool class before assigning values to member variables.
@@ -84,6 +133,7 @@ def __init__(
84133
self._session: ClientSession = session
85134
self._auth_tokens: dict[str, Callable[[], str]] = auth_tokens
86135
self._auth_params: list[ParameterSchema] = auth_params
136+
self._bound_params: dict[str, Union[Any, Callable[[], Any]]] = bound_params
87137

88138
# Warn users about any missing authentication so they can add it before
89139
# tool invocation.
@@ -106,6 +156,17 @@ async def __tool_func(self, **kwargs: Any) -> dict:
106156
# authentication sources have been registered or not.
107157
self.__validate_auth()
108158

159+
# Evaluate dynamic parameter values if any
160+
evaluated_params = {}
161+
for param_name, param_value in self._bound_params.items():
162+
if callable(param_value):
163+
evaluated_params[param_name] = param_value()
164+
else:
165+
evaluated_params[param_name] = param_value
166+
167+
# Merge bound parameters with the provided arguments
168+
kwargs.update(evaluated_params)
169+
109170
return await _invoke_tool(
110171
self._url, self._session, self._name, kwargs, self._auth_tokens
111172
)
@@ -154,42 +215,66 @@ def __create_copy(
154215
self,
155216
*,
156217
auth_tokens: dict[str, Callable[[], str]] = {},
218+
bound_params: dict[str, Union[Any, Callable[[], Any]]] = {},
219+
strict: bool,
157220
) -> Self:
158221
"""
159222
Creates a deep copy of the current ToolboxTool instance, allowing for
160-
modification of auth tokens.
223+
modification of auth tokens and bound params.
161224
162225
This method enables the creation of new tool instances with inherited
163226
properties from the current instance, while optionally updating the auth
164-
tokens. This is useful for creating variations of the tool with
165-
additional auth tokens without modifying the original instance, ensuring
166-
immutability.
227+
tokens and bound params. This is useful for creating variations of the
228+
tool with additional auth tokens or bound params without modifying the
229+
original instance, ensuring immutability.
167230
168231
Args:
169232
auth_tokens: A dictionary of auth source names to functions that
170233
retrieve ID tokens. These tokens will be merged with the
171234
existing auth tokens.
235+
bound_params: A dictionary of parameter names to their
236+
bound values or functions to retrieve the values. These params
237+
will be merged with the existing bound params.
238+
strict: If True, raises a ValueError if any of the given bound
239+
parameters are missing from the schema or require
240+
authentication. If False, only issues a warning.
172241
173242
Returns:
174243
A new ToolboxTool instance that is a deep copy of the current
175-
instance, with optionally updated auth tokens.
244+
instance, with added auth tokens or bound params.
176245
"""
246+
new_schema = deepcopy(self._schema)
247+
248+
# Reconstruct the complete parameter schema by merging the auth
249+
# parameters back with the non-auth parameters. This is necessary to
250+
# accurately validate the new combination of auth tokens and bound
251+
# params in the constructor of the new ToolboxTool instance, ensuring
252+
# that any overlaps or conflicts are correctly identified and reported
253+
# as errors or warnings, depending on the given `strict` flag.
254+
new_schema.parameters += self._auth_params
177255
return type(self)(
178256
name=self._name,
179-
schema=deepcopy(self._schema),
257+
schema=new_schema,
180258
url=self._url,
181259
session=self._session,
182260
auth_tokens={**self._auth_tokens, **auth_tokens},
261+
bound_params={**self._bound_params, **bound_params},
262+
strict=strict,
183263
)
184264

185-
def add_auth_tokens(self, auth_tokens: dict[str, Callable[[], str]]) -> Self:
265+
def add_auth_tokens(
266+
self, auth_tokens: dict[str, Callable[[], str]], strict: bool = True
267+
) -> Self:
186268
"""
187269
Registers functions to retrieve ID tokens for the corresponding
188270
authentication sources.
189271
190272
Args:
191273
auth_tokens: A dictionary of authentication source names to the
192274
functions that return corresponding ID token.
275+
strict: If True, a ValueError is raised if any of the provided auth
276+
tokens are already registered, or are already bound. If False,
277+
only a warning is issued.
193278
194279
Returns:
195280
A new ToolboxTool instance that is a deep copy of the current
@@ -207,19 +292,82 @@ def add_auth_tokens(self, auth_tokens: dict[str, Callable[[], str]]) -> Self:
207292
f"Authentication source(s) `{', '.join(dupe_tokens)}` already registered in tool `{self._name}`."
208293
)
209294

210-
return self.__create_copy(auth_tokens=auth_tokens)
295+
return self.__create_copy(auth_tokens=auth_tokens, strict=strict)
211296

212-
def add_auth_token(self, auth_source: str, get_id_token: Callable[[], str]) -> Self:
297+
def add_auth_token(
298+
self, auth_source: str, get_id_token: Callable[[], str], strict: bool = True
299+
) -> Self:
213300
"""
214301
Registers a function to retrieve an ID token for a given authentication
215302
source.
216303
217304
Args:
218305
auth_source: The name of the authentication source.
219306
get_id_token: A function that returns the ID token.
307+
strict: If True, a ValueError is raised if any of the provided auth
308+
tokens are already registered, or are already bound. If False,
309+
only a warning is issued.
220310
221311
Returns:
222312
A new ToolboxTool instance that is a deep copy of the current
223313
instance, with added auth tokens.
224314
"""
225-
return self.add_auth_tokens({auth_source: get_id_token})
315+
return self.add_auth_tokens({auth_source: get_id_token}, strict=strict)
316+
317+
def bind_params(
318+
self,
319+
bound_params: dict[str, Union[Any, Callable[[], Any]]],
320+
strict: bool = True,
321+
) -> Self:
322+
"""
323+
Registers values or functions to retrieve the value for the
324+
corresponding bound parameters.
325+
326+
Args:
327+
bound_params: A dictionary of the bound parameter name to the
328+
value or function of the bound value.
329+
strict: If True, a ValueError is raised if any of the provided bound
330+
params are already bound, not defined in the tool's schema, or
331+
require authentication. If False, only a warning is issued.
332+
333+
Returns:
334+
A new ToolboxTool instance that is a deep copy of the current
335+
instance, with added bound params.
336+
"""
337+
338+
# Check if the parameter is already bound.
339+
dupe_params: list[str] = []
340+
for param_name, _ in bound_params.items():
341+
if param_name in self._bound_params:
342+
dupe_params.append(param_name)
343+
344+
if dupe_params:
345+
raise ValueError(
346+
f"Parameter(s) `{', '.join(dupe_params)}` already bound in tool `{self._name}`."
347+
)
348+
349+
return self.__create_copy(bound_params=bound_params, strict=strict)
350+
351+
def bind_param(
352+
self,
353+
param_name: str,
354+
param_value: Union[Any, Callable[[], Any]],
355+
strict: bool = True,
356+
) -> Self:
357+
"""
358+
Registers a value or a function to retrieve the value for a given
359+
bound parameter.
360+
361+
Args:
362+
param_name: The name of the bound parameter.
363+
param_value: The value of the bound parameter, or a callable
364+
that returns the value.
365+
strict: If True, a ValueError is raised if any of the provided bound
366+
params are already bound, not defined in the tool's schema, or
367+
require authentication. If False, only a warning is issued.
368+
369+
Returns:
370+
A new ToolboxTool instance that is a deep copy of the current
371+
instance, with added bound params.
372+
"""
373+
return self.bind_params({param_name: param_value}, strict)

src/toolbox_langchain_sdk/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,18 @@ def _find_auth_params(
238238
_non_auth_params.append(param)
239239

240240
return (_auth_params, _non_auth_params)
241+
242+
243+
def _find_bound_params(
244+
params: list[ParameterSchema], bound_params: list[str]
245+
) -> tuple[list[ParameterSchema], list[ParameterSchema]]:
246+
_bound_params: list[ParameterSchema] = []
247+
_non_bound_params: list[ParameterSchema] = []
248+
249+
for param in params:
250+
if param.name in bound_params:
251+
_bound_params.append(param)
252+
else:
253+
_non_bound_params.append(param)
254+
255+
return (_bound_params, _non_bound_params)

0 commit comments

Comments
 (0)