@@ -49,6 +49,7 @@ def __init__(
49
49
description : str ,
50
50
params : Sequence [ParameterSchema ],
51
51
required_authn_params : Mapping [str , list [str ]],
52
+ required_authz_tokens : Sequence [str ],
52
53
auth_service_token_getters : Mapping [str , Callable [[], str ]],
53
54
bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
54
55
client_headers : Mapping [str , Union [Callable , Coroutine , str ]],
@@ -63,12 +64,14 @@ def __init__(
63
64
name: The name of the remote tool.
64
65
description: The description of the remote tool.
65
66
params: The args of the tool.
66
- required_authn_params: A map of required authenticated parameters to a list
67
- of alternative services that can provide values for them.
68
- auth_service_token_getters: A dict of authService -> token (or callables that
69
- produce a token)
70
- bound_params: A mapping of parameter names to bind to specific values or
71
- callables that are called to produce values as needed.
67
+ required_authn_params: A map of required authenticated parameters to
68
+ a list of alternative services that can provide values for them.
69
+ required_authz_tokens: A sequence of alternative services for
70
+ providing authorization token for the tool invocation.
71
+ auth_service_token_getters: A dict of authService -> token (or
72
+ callables that produce a token)
73
+ bound_params: A mapping of parameter names to bind to specific
74
+ values or callables that are called to produce values as needed.
72
75
client_headers: Client specific headers bound to the tool.
73
76
"""
74
77
# used to invoke the toolbox API
@@ -106,6 +109,8 @@ def __init__(
106
109
107
110
# map of parameter name to auth service required by it
108
111
self .__required_authn_params = required_authn_params
112
+ # sequence of authorization tokens required by it
113
+ self .__required_authz_tokens = required_authz_tokens
109
114
# map of authService -> token_getter
110
115
self .__auth_service_token_getters = auth_service_token_getters
111
116
# map of parameter name to value (or callable that produces that value)
@@ -121,6 +126,7 @@ def __copy(
121
126
description : Optional [str ] = None ,
122
127
params : Optional [Sequence [ParameterSchema ]] = None ,
123
128
required_authn_params : Optional [Mapping [str , list [str ]]] = None ,
129
+ required_authz_tokens : Optional [Sequence [str ]] = None ,
124
130
auth_service_token_getters : Optional [Mapping [str , Callable [[], str ]]] = None ,
125
131
bound_params : Optional [Mapping [str , Union [Callable [[], Any ], Any ]]] = None ,
126
132
client_headers : Optional [Mapping [str , Union [Callable , Coroutine , str ]]] = None ,
@@ -134,12 +140,14 @@ def __copy(
134
140
name: The name of the remote tool.
135
141
description: The description of the remote tool.
136
142
params: The args of the tool.
137
- required_authn_params: A map of required authenticated parameters to a list
138
- of alternative services that can provide values for them.
139
- auth_service_token_getters: A dict of authService -> token (or callables
140
- that produce a token)
141
- bound_params: A mapping of parameter names to bind to specific values or
142
- callables that are called to produce values as needed.
143
+ required_authn_params: A map of required authenticated parameters to
144
+ a list of alternative services that can provide values for them.
145
+ required_authz_tokens: A sequence of alternative services for
146
+ providing authorization token for the tool invocation.
147
+ auth_service_token_getters: A dict of authService -> token (or
148
+ callables that produce a token)
149
+ bound_params: A mapping of parameter names to bind to specific
150
+ values or callables that are called to produce values as needed.
143
151
client_headers: Client specific headers bound to the tool.
144
152
"""
145
153
check = lambda val , default : val if val is not None else default
@@ -152,6 +160,9 @@ def __copy(
152
160
required_authn_params = check (
153
161
required_authn_params , self .__required_authn_params
154
162
),
163
+ required_authz_tokens = check (
164
+ required_authz_tokens , self .__required_authz_tokens
165
+ ),
155
166
auth_service_token_getters = check (
156
167
auth_service_token_getters , self .__auth_service_token_getters
157
168
),
@@ -179,11 +190,15 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
179
190
"""
180
191
181
192
# check if any auth services need to be specified yet
182
- if len (self .__required_authn_params ) > 0 :
193
+ if (
194
+ len (self .__required_authn_params ) > 0
195
+ or len (self .__required_authz_tokens ) > 0
196
+ ):
183
197
# Gather all the required auth services into a set
184
198
req_auth_services = set ()
185
199
for s in self .__required_authn_params .values ():
186
200
req_auth_services .update (s )
201
+ req_auth_services .update (self .__required_authz_tokens )
187
202
raise ValueError (
188
203
f"One or more of the following authn services are required to invoke this tool"
189
204
f": { ',' .join (req_auth_services )} "
@@ -269,18 +284,20 @@ def add_auth_token_getters(
269
284
dict (self .__auth_service_token_getters , ** auth_token_getters )
270
285
)
271
286
# create a read-only updated for params that are still required
272
- new_req_authn_params = types . MappingProxyType (
287
+ new_req_authn_params , new_req_authz_tokens , used_auth_token_getters = (
273
288
identify_required_authn_params (
274
- # TODO: Add authRequired
275
289
self .__required_authn_params ,
276
- [] ,
290
+ self . __required_authz_tokens ,
277
291
auth_token_getters .keys (),
278
- )[ 0 ]
292
+ )
279
293
)
280
294
295
+ # TODO: Add validation for used_auth_token_getters
296
+
281
297
return self .__copy (
282
298
auth_service_token_getters = new_getters ,
283
- required_authn_params = new_req_authn_params ,
299
+ required_authn_params = types .MappingProxyType (new_req_authn_params ),
300
+ required_authz_tokens = new_req_authz_tokens ,
284
301
)
285
302
286
303
def bind_params (
0 commit comments