Skip to content

Commit 22aa329

Browse files
committed
chore: Refactor getting tool to run into a reusable helper
1 parent 60e08b3 commit 22aa329

File tree

1 file changed

+10
-32
lines changed
  • packages/toolbox-langchain/src/toolbox_langchain

1 file changed

+10
-32
lines changed

packages/toolbox-langchain/src/toolbox_langchain/tools.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from asyncio import to_thread
16-
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
16+
from typing import Any, Awaitable, Callable, Mapping, Sequence, Union
1717

1818
from deprecated import deprecated
1919
from langchain_core.runnables import RunnableConfig
@@ -74,11 +74,7 @@ def _client_headers(
7474
) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
7575
return self.__core_tool._client_headers
7676

77-
def _run(
78-
self,
79-
config: RunnableConfig,
80-
**kwargs: Any,
81-
) -> str:
77+
def __get_tool_to_run(self, config: RunnableConfig) -> ToolboxCoreSyncTool:
8278
tool_to_run = self.__core_tool
8379
if (
8480
config
@@ -104,36 +100,18 @@ def _run(
104100
tool_to_run = self.__core_tool.add_auth_token_getters(
105101
filtered_getters
106102
)
103+
return tool_to_run
107104

105+
def _run(
106+
self,
107+
config: RunnableConfig,
108+
**kwargs: Any,
109+
) -> str:
110+
tool_to_run = self.__get_tool_to_run(config)
108111
return tool_to_run(**kwargs)
109112

110113
async def _arun(self, config: RunnableConfig, **kwargs: Any) -> str:
111-
tool_to_run = self.__core_tool
112-
if (
113-
config
114-
and "configurable" in config
115-
and "auth_token_getters" in config["configurable"]
116-
):
117-
auth_token_getters = config["configurable"]["auth_token_getters"]
118-
if auth_token_getters:
119-
120-
# The `add_auth_token_getters` method requires that all provided
121-
# getters are used by the tool. To prevent validation errors,
122-
# filter the incoming getters to include only those that this
123-
# specific tool requires.
124-
required_auth_keys = set(self.__core_tool._required_authz_tokens)
125-
for auth_list in self.__core_tool._required_authn_params.values():
126-
required_auth_keys.update(auth_list)
127-
filtered_getters = {
128-
k: v
129-
for k, v in auth_token_getters.items()
130-
if k in required_auth_keys
131-
}
132-
if filtered_getters:
133-
tool_to_run = self.__core_tool.add_auth_token_getters(
134-
filtered_getters
135-
)
136-
114+
tool_to_run = self.__get_tool_to_run(config)
137115
return await to_thread(tool_to_run, **kwargs)
138116

139117
def add_auth_token_getters(

0 commit comments

Comments
 (0)