Skip to content

Commit 931c9fc

Browse files
committed
feat(toolbox-langchain): Implement self-authenticated tools
1 parent d33d044 commit 931c9fc

File tree

2 files changed

+69
-9
lines changed

2 files changed

+69
-9
lines changed

packages/toolbox-langchain/src/toolbox_langchain/async_tools.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from typing import Any, Callable, Union
16-
16+
from langchain_core.runnables import RunnableConfig
1717
from deprecated import deprecated
1818
from langchain_core.tools import BaseTool
1919
from toolbox_core.tool import ToolboxTool as ToolboxCoreTool
@@ -52,7 +52,11 @@ def __init__(
5252
def _run(self, **kwargs: Any) -> str:
5353
raise NotImplementedError("Synchronous methods not supported by async tools.")
5454

55-
async def _arun(self, **kwargs: Any) -> str:
55+
async def _arun(
56+
self,
57+
config: RunnableConfig,
58+
**kwargs: Any,
59+
) -> str:
5660
"""
5761
The coroutine that invokes the tool with the given arguments.
5862
@@ -63,7 +67,23 @@ async def _arun(self, **kwargs: Any) -> str:
6367
A dictionary containing the parsed JSON response from the tool
6468
invocation.
6569
"""
66-
return await self.__core_tool(**kwargs)
70+
tool_to_run = self.__core_tool
71+
if config and "configurable" in config and "auth_token_getters" in config["configurable"]:
72+
auth_token_getters = config["configurable"]["auth_token_getters"]
73+
if auth_token_getters:
74+
75+
# The `add_auth_token_getters` method requires that all provided
76+
# getters are used by the tool. To prevent validation errors,
77+
# filter the incoming getters to include only those that this
78+
# specific tool requires.
79+
required_auth_keys = set(self.__core_tool._required_authz_tokens)
80+
for auth_list in self.__core_tool._required_authn_params.values():
81+
required_auth_keys.update(auth_list)
82+
filtered_getters = {k: v for k, v in auth_token_getters.items() if k in required_auth_keys}
83+
if filtered_getters:
84+
tool_to_run = self.__core_tool.add_auth_token_getters(filtered_getters)
85+
86+
return await tool_to_run(**kwargs)
6787

6888
def add_auth_token_getters(
6989
self, auth_token_getters: dict[str, Callable[[], str]]

packages/toolbox-langchain/src/toolbox_langchain/tools.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515
from asyncio import to_thread
16-
from typing import Any, Callable, Union, Mapping, Sequence, Awaitable
16+
from typing import Any, Callable, Union, Mapping, Sequence, Awaitable, Optional
17+
from langchain_core.runnables import RunnableConfig
1718

1819
from deprecated import deprecated
1920
from langchain_core.tools import BaseTool
@@ -73,11 +74,50 @@ def _client_headers(
7374
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
7475
return self.__core_tool._client_headers
7576

76-
def _run(self, **kwargs: Any) -> str:
77-
return self.__core_tool(**kwargs)
78-
79-
async def _arun(self, **kwargs: Any) -> str:
80-
return await to_thread(self.__core_tool, **kwargs)
77+
def _run(
78+
self,
79+
config: RunnableConfig,
80+
**kwargs: Any,
81+
) -> str:
82+
tool_to_run = self.__core_tool
83+
if config and "configurable" in config and "auth_token_getters" in config["configurable"]:
84+
auth_token_getters = config["configurable"]["auth_token_getters"]
85+
if auth_token_getters:
86+
87+
# The `add_auth_token_getters` method requires that all provided
88+
# getters are used by the tool. To prevent validation errors,
89+
# filter the incoming getters to include only those that this
90+
# specific tool requires.
91+
required_auth_keys = set(self.__core_tool._required_authz_tokens)
92+
for auth_list in self.__core_tool._required_authn_params.values():
93+
required_auth_keys.update(auth_list)
94+
filtered_getters = {k: v for k, v in auth_token_getters.items() if k in required_auth_keys}
95+
if filtered_getters:
96+
tool_to_run = self.__core_tool.add_auth_token_getters(filtered_getters)
97+
98+
return tool_to_run(**kwargs)
99+
100+
async def _arun(
101+
self,
102+
config: RunnableConfig,
103+
**kwargs: Any) -> str:
104+
tool_to_run = self.__core_tool
105+
if config and "configurable" in config and "auth_token_getters" in config["configurable"]:
106+
auth_token_getters = config["configurable"]["auth_token_getters"]
107+
if auth_token_getters:
108+
109+
# The `add_auth_token_getters` method requires that all provided
110+
# getters are used by the tool. To prevent validation errors,
111+
# filter the incoming getters to include only those that this
112+
# specific tool requires.
113+
required_auth_keys = set(self.__core_tool._required_authz_tokens)
114+
for auth_list in self.__core_tool._required_authn_params.values():
115+
required_auth_keys.update(auth_list)
116+
filtered_getters = {k: v for k, v in auth_token_getters.items() if k in required_auth_keys}
117+
if filtered_getters:
118+
tool_to_run = self.__core_tool.add_auth_token_getters(filtered_getters)
119+
120+
return await to_thread(tool_to_run, **kwargs)
81121

82122
def add_auth_token_getters(
83123
self, auth_token_getters: dict[str, Callable[[], str]]

0 commit comments

Comments
 (0)