22
22
from .protocol import ParameterSchema
23
23
from .utils import (
24
24
create_func_docstring ,
25
- identify_required_authn_params ,
25
+ identify_auth_requirements ,
26
26
params_to_pydantic_model ,
27
27
resolve_value ,
28
28
)
@@ -49,6 +49,7 @@ def __init__(
49
49
description : str ,
50
50
params : Sequence [ParameterSchema ],
51
51
required_authn_params : Mapping [str , list [str ]],
52
+ required_authz_tokens : Sequence [str ],
52
53
auth_service_token_getters : Mapping [str , Callable [[], str ]],
53
54
bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
54
55
client_headers : Mapping [str , Union [Callable , Coroutine , str ]],
@@ -63,12 +64,14 @@ def __init__(
63
64
name: The name of the remote tool.
64
65
description: The description of the remote tool.
65
66
params: The args of the tool.
66
- required_authn_params: A map of required authenticated parameters to a list
67
- of alternative services that can provide values for them.
68
- auth_service_token_getters: A dict of authService -> token (or callables that
69
- produce a token)
70
- bound_params: A mapping of parameter names to bind to specific values or
71
- callables that are called to produce values as needed.
67
+ required_authn_params: A map of required authenticated parameters to
68
+ a list of alternative services that can provide values for them.
69
+ required_authz_tokens: A sequence of alternative services for
70
+ providing authorization token for the tool invocation.
71
+ auth_service_token_getters: A dict of authService -> token (or
72
+ callables that produce a token)
73
+ bound_params: A mapping of parameter names to bind to specific
74
+ values or callables that are called to produce values as needed.
72
75
client_headers: Client specific headers bound to the tool.
73
76
"""
74
77
# used to invoke the toolbox API
@@ -106,6 +109,8 @@ def __init__(
106
109
107
110
# map of parameter name to auth service required by it
108
111
self .__required_authn_params = required_authn_params
112
+ # sequence of authorization tokens required by it
113
+ self .__required_authz_tokens = required_authz_tokens
109
114
# map of authService -> token_getter
110
115
self .__auth_service_token_getters = auth_service_token_getters
111
116
# map of parameter name to value (or callable that produces that value)
@@ -149,6 +154,7 @@ def __copy(
149
154
description : Optional [str ] = None ,
150
155
params : Optional [Sequence [ParameterSchema ]] = None ,
151
156
required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
157
+ required_authz_tokens : Optional [Sequence [str ]] = None ,
152
158
auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
153
159
bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
154
160
client_headers : Optional [Mapping [str , Union [Callable , Coroutine , str ]]] = None ,
@@ -162,12 +168,14 @@ def __copy(
162
168
name: The name of the remote tool.
163
169
description: The description of the remote tool.
164
170
params: The args of the tool.
165
- required_authn_params: A map of required authenticated parameters to a list
166
- of alternative services that can provide values for them.
167
- auth_service_token_getters: A dict of authService -> token (or callables
168
- that produce a token)
169
- bound_params: A mapping of parameter names to bind to specific values or
170
- callables that are called to produce values as needed.
171
+ required_authn_params: A map of required authenticated parameters to
172
+ a list of alternative services that can provide values for them.
173
+ required_authz_tokens: A sequence of alternative services for
174
+ providing authorization token for the tool invocation.
175
+ auth_service_token_getters: A dict of authService -> token (or
176
+ callables that produce a token)
177
+ bound_params: A mapping of parameter names to bind to specific
178
+ values or callables that are called to produce values as needed.
171
179
client_headers: Client specific headers bound to the tool.
172
180
"""
173
181
check = lambda val , default : val if val is not None else default
@@ -180,6 +188,9 @@ def __copy(
180
188
required_authn_params = check (
181
189
required_authn_params , self .__required_authn_params
182
190
),
191
+ required_authz_tokens = check (
192
+ required_authz_tokens , self .__required_authz_tokens
193
+ ),
183
194
auth_service_token_getters = check (
184
195
auth_service_token_getters , self .__auth_service_token_getters
185
196
),
@@ -207,11 +218,15 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
207
218
"""
208
219
209
220
# check if any auth services need to be specified yet
210
- if len (self .__required_authn_params ) > 0 :
221
+ if (
222
+ len (self .__required_authn_params ) > 0
223
+ or len (self .__required_authz_tokens ) > 0
224
+ ):
211
225
# Gather all the required auth services into a set
212
226
req_auth_services = set ()
213
227
for s in self .__required_authn_params .values ():
214
228
req_auth_services .update (s )
229
+ req_auth_services .update (self .__required_authz_tokens )
215
230
raise ValueError (
216
231
f"One or more of the following authn services are required to invoke this tool"
217
232
f": { ',' .join (req_auth_services )} "
@@ -292,23 +307,24 @@ def add_auth_token_getters(
292
307
f"Cannot register client the same headers in the client as well as tool."
293
308
)
294
309
295
- # create a read-only updated value for new_getters
296
- new_getters = MappingProxyType (
297
- dict (self .__auth_service_token_getters , ** auth_token_getters )
298
- )
299
- # create a read-only updated for params that are still required
300
- new_req_authn_params = MappingProxyType (
301
- identify_required_authn_params (
302
- # TODO: Add authRequired
310
+ new_getters = dict (self .__auth_service_token_getters , ** auth_token_getters )
311
+
312
+ # find the updated requirements
313
+ new_req_authn_params , new_req_authz_tokens , used_auth_token_getters = (
314
+ identify_auth_requirements (
303
315
self .__required_authn_params ,
304
- [] ,
316
+ self . __required_authz_tokens ,
305
317
auth_token_getters .keys (),
306
- )[ 0 ]
318
+ )
307
319
)
308
320
321
+ # TODO: Add validation for used_auth_token_getters
322
+
309
323
return self .__copy (
310
- auth_service_token_getters = new_getters ,
311
- required_authn_params = new_req_authn_params ,
324
+ # create a read-only map for updated getters, params and tokens that are still required
325
+ auth_service_token_getters = MappingProxyType (new_getters ),
326
+ required_authn_params = MappingProxyType (new_req_authn_params ),
327
+ required_authz_tokens = tuple (new_req_authz_tokens ),
312
328
)
313
329
314
330
def bind_params (
0 commit comments