|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | from asyncio import to_thread
|
16 |
| -from typing import Any, Callable, Union, Mapping, Sequence, Awaitable |
| 16 | +from typing import Any, Callable, Union, Mapping, Sequence, Awaitable, Optional |
| 17 | +from langchain_core.runnables import RunnableConfig |
17 | 18 |
|
18 | 19 | from deprecated import deprecated
|
19 | 20 | from langchain_core.tools import BaseTool
|
@@ -73,11 +74,50 @@ def _client_headers(
|
73 | 74 | ) -> Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]]:
|
74 | 75 | return self.__core_tool._client_headers
|
75 | 76 |
|
76 |
| - def _run(self, **kwargs: Any) -> str: |
77 |
| - return self.__core_tool(**kwargs) |
78 |
| - |
79 |
| - async def _arun(self, **kwargs: Any) -> str: |
80 |
| - return await to_thread(self.__core_tool, **kwargs) |
| 77 | + def _run( |
| 78 | + self, |
| 79 | + config: RunnableConfig, |
| 80 | + **kwargs: Any, |
| 81 | + ) -> str: |
| 82 | + tool_to_run = self.__core_tool |
| 83 | + if config and "configurable" in config and "auth_token_getters" in config["configurable"]: |
| 84 | + auth_token_getters = config["configurable"]["auth_token_getters"] |
| 85 | + if auth_token_getters: |
| 86 | + |
| 87 | + # The `add_auth_token_getters` method requires that all provided |
| 88 | + # getters are used by the tool. To prevent validation errors, |
| 89 | + # filter the incoming getters to include only those that this |
| 90 | + # specific tool requires. |
| 91 | + required_auth_keys = set(self.__core_tool._required_authz_tokens) |
| 92 | + for auth_list in self.__core_tool._required_authn_params.values(): |
| 93 | + required_auth_keys.update(auth_list) |
| 94 | + filtered_getters = {k: v for k, v in auth_token_getters.items() if k in required_auth_keys} |
| 95 | + if filtered_getters: |
| 96 | + tool_to_run = self.__core_tool.add_auth_token_getters(filtered_getters) |
| 97 | + |
| 98 | + return tool_to_run(**kwargs) |
| 99 | + |
| 100 | + async def _arun( |
| 101 | + self, |
| 102 | + config: RunnableConfig, |
| 103 | + **kwargs: Any) -> str: |
| 104 | + tool_to_run = self.__core_tool |
| 105 | + if config and "configurable" in config and "auth_token_getters" in config["configurable"]: |
| 106 | + auth_token_getters = config["configurable"]["auth_token_getters"] |
| 107 | + if auth_token_getters: |
| 108 | + |
| 109 | + # The `add_auth_token_getters` method requires that all provided |
| 110 | + # getters are used by the tool. To prevent validation errors, |
| 111 | + # filter the incoming getters to include only those that this |
| 112 | + # specific tool requires. |
| 113 | + required_auth_keys = set(self.__core_tool._required_authz_tokens) |
| 114 | + for auth_list in self.__core_tool._required_authn_params.values(): |
| 115 | + required_auth_keys.update(auth_list) |
| 116 | + filtered_getters = {k: v for k, v in auth_token_getters.items() if k in required_auth_keys} |
| 117 | + if filtered_getters: |
| 118 | + tool_to_run = self.__core_tool.add_auth_token_getters(filtered_getters) |
| 119 | + |
| 120 | + return await to_thread(tool_to_run, **kwargs) |
81 | 121 |
|
82 | 122 | def add_auth_token_getters(
|
83 | 123 | self, auth_token_getters: dict[str, Callable[[], str]]
|
|
0 commit comments