1
1
import asyncio
2
- from typing import Any , Optional , Type
2
+ import warnings
3
+ from typing import Any , Callable , Optional , Type
3
4
4
5
from aiohttp import ClientSession
5
6
from langchain_core .tools import StructuredTool
@@ -20,6 +21,8 @@ def __init__(self, url: str, session: Optional[ClientSession] = None):
20
21
"""
21
22
self ._url : str = url
22
23
self ._should_close_session : bool = session is None
24
+ self ._id_token_getters : dict [str , Callable [[], str ]] = {}
25
+ self ._tool_param_auth : dict [str , dict [str , list [str ]]] = {}
23
26
self ._session : ClientSession = session or ClientSession ()
24
27
25
28
async def close (self ) -> None :
@@ -77,6 +80,35 @@ async def _load_toolset_manifest(
77
80
url = f"{ self ._url } /api/toolset/{ toolset_name or '' } "
78
81
return await _load_yaml (url , self ._session )
79
82
83
+ def _validate_auth (self , tool_name : str ) -> bool :
84
+ """
85
+ Helper method that validates the authentication requirements of the tool
86
+ with the given tool_name. We consider the validation to pass if at least
87
+ one auth sources of each of the auth parameters, of the given tool, is
88
+ registered.
89
+
90
+ Args:
91
+ tool_name: Name of the tool to validate auth sources for.
92
+
93
+ Returns:
94
+ True if at least one permitted auth source of each of the auth
95
+ params, of the given tool, is registered. Also returns True if the
96
+ given tool does not require any auth sources.
97
+ """
98
+
99
+ if tool_name not in self ._tool_param_auth :
100
+ return True
101
+
102
+ for permitted_auth_sources in self ._tool_param_auth [tool_name ].values ():
103
+ found_match = False
104
+ for registered_auth_source in self ._id_token_getters :
105
+ if registered_auth_source in permitted_auth_sources :
106
+ found_match = True
107
+ break
108
+ if not found_match :
109
+ return False
110
+ return True
111
+
80
112
def _generate_tool (
81
113
self , tool_name : str , manifest : ManifestSchema
82
114
) -> StructuredTool :
@@ -96,8 +128,16 @@ def _generate_tool(
96
128
model_name = tool_name , schema = tool_schema .parameters
97
129
)
98
130
131
+ # If the tool had parameters that require authentication, then right
132
+ # before invoking that tool, we validate whether all these required
133
+ # authentication sources have been registered or not.
99
134
async def _tool_func (** kwargs : Any ) -> dict :
100
- return await _invoke_tool (self ._url , self ._session , tool_name , kwargs )
135
+ if not self ._validate_auth (tool_name ):
136
+ raise PermissionError (f"Login required before invoking { tool_name } ." )
137
+
138
+ return await _invoke_tool (
139
+ self ._url , self ._session , tool_name , kwargs , self ._id_token_getters
140
+ )
101
141
102
142
return StructuredTool .from_function (
103
143
coroutine = _tool_func ,
@@ -106,21 +146,89 @@ async def _tool_func(**kwargs: Any) -> dict:
106
146
args_schema = tool_model ,
107
147
)
108
148
109
- async def load_tool (self , tool_name : str ) -> StructuredTool :
149
+ def _process_auth_params (self , manifest : ManifestSchema ) -> None :
150
+ """
151
+ Extracts parameters requiring authentication from the manifest.
152
+ Verifies each parameter has at least one valid auth source.
153
+
154
+ Args:
155
+ manifest: The manifest to validate and modify.
156
+
157
+ Warns:
158
+ UserWarning: If a parameter in the manifest has no valid sources.
159
+ """
160
+ for tool_name , tool_schema in manifest .tools .items ():
161
+ non_auth_params = []
162
+ for param in tool_schema .parameters :
163
+
164
+ # Extract auth params from the tool schema.
165
+ #
166
+ # These parameters are removed from the manifest to prevent data
167
+ # validation errors since their values are inferred by the
168
+ # Toolbox service, not provided by the user.
169
+ #
170
+ # Store the permitted authentication sources for each parameter
171
+ # in '_tool_param_auth' for efficient validation in
172
+ # '_validate_auth'.
173
+ if not param .authSources :
174
+ non_auth_params .append (param )
175
+ continue
176
+
177
+ self ._tool_param_auth .setdefault (tool_name , {})[
178
+ param .name
179
+ ] = param .authSources
180
+
181
+ tool_schema .parameters = non_auth_params
182
+
183
+ # If none of the permitted auth sources of a parameter are
184
+ # registered, raise a warning message to the user.
185
+ if not self ._validate_auth (tool_name ):
186
+ warnings .warn (
187
+ f"Some parameters of tool { tool_name } require authentication, but no valid auth sources are registered. Please register the required sources before use."
188
+ )
189
+
190
+ def add_auth_header (
191
+ self , auth_source : str , get_id_token : Callable [[], str ]
192
+ ) -> None :
193
+ """
194
+ Registers a function to retrieve an ID token for a given authentication
195
+ source.
196
+
197
+ Args:
198
+ auth_source : The name of the authentication source.
199
+ get_id_token: A function that returns the ID token.
200
+ """
201
+ self ._id_token_getters [auth_source ] = get_id_token
202
+
203
+ async def load_tool (
204
+ self , tool_name : str , auth_headers : dict [str , Callable [[], str ]] = {}
205
+ ) -> StructuredTool :
110
206
"""
111
207
Loads the tool, with the given tool name, from the Toolbox service.
112
208
113
209
Args:
114
210
tool_name: The name of the tool to load.
211
+ auth_headers: A mapping of authentication source names to
212
+ functions that retrieve ID tokens. If provided, these will
213
+ override or be added to the existing ID token getters.
214
+ Default: Empty.
115
215
116
216
Returns:
117
217
A tool loaded from the Toolbox
118
218
"""
219
+ for auth_source , get_id_token in auth_headers .items ():
220
+ self .add_auth_header (auth_source , get_id_token )
221
+
119
222
manifest : ManifestSchema = await self ._load_tool_manifest (tool_name )
223
+
224
+ self ._process_auth_params (manifest )
225
+
120
226
return self ._generate_tool (tool_name , manifest )
121
227
122
228
async def load_toolset (
123
- self , toolset_name : Optional [str ] = None
229
+ self ,
230
+ toolset_name : Optional [str ] = None ,
231
+ auth_headers : dict [str , Callable [[], str ]] = {},
124
232
) -> list [StructuredTool ]:
125
233
"""
126
234
Loads tools from the Toolbox service, optionally filtered by toolset
@@ -129,12 +237,22 @@ async def load_toolset(
129
237
Args:
130
238
toolset_name: The name of the toolset to load.
131
239
Default: None. If not provided, then all the tools are loaded.
240
+ auth_headers: A mapping of authentication source names to
241
+ functions that retrieve ID tokens. If provided, these will
242
+ override or be added to the existing ID token getters.
243
+ Default: Empty.
132
244
133
245
Returns:
134
246
A list of all tools loaded from the Toolbox.
135
247
"""
248
+ for auth_source , get_id_token in auth_headers .items ():
249
+ self .add_auth_header (auth_source , get_id_token )
250
+
136
251
tools : list [StructuredTool ] = []
137
252
manifest : ManifestSchema = await self ._load_toolset_manifest (toolset_name )
253
+
254
+ self ._process_auth_params (manifest )
255
+
138
256
for tool_name in manifest .tools :
139
257
tools .append (self ._generate_tool (tool_name , manifest ))
140
258
return tools
0 commit comments