@@ -54,7 +54,7 @@ def __init__(
54
54
params : Sequence [Parameter ],
55
55
required_authn_params : Mapping [str , list [str ]],
56
56
auth_service_token_getters : Mapping [str , Callable [[], str ]],
57
- bound_params : Mapping [str , Union [Callable [[], Any ], Any ]] = {} ,
57
+ bound_params : Mapping [str , Union [Callable [[], Any ], Any ]],
58
58
):
59
59
"""
60
60
Initializes a callable that will trigger the tool invocation through the
@@ -67,10 +67,13 @@ def __init__(
67
67
desc: The description of the remote tool (used as its docstring).
68
68
params: A list of `inspect.Parameter` objects defining the tool's
69
69
arguments and their types/defaults.
70
- required_authn_params: A dict of required authenticated parameters that
71
- need a auth_service_token_getter set for them yet .
72
- auth_service_tokens : A dict of authService -> token (or callables that
70
+ required_authn_params: A dict of required authenticated parameters to a list
71
+ of services that provide values for them.
72
+ auth_service_token_getters : A dict of authService -> token (or callables that
73
73
produce a token)
74
+ bound_params: A mapping of parameter names to bind to specific values or
75
+ callables that are called to produce values as needed.
76
+
74
77
"""
75
78
76
79
# used to invoke the toolbox API
@@ -92,7 +95,7 @@ def __init__(
92
95
self .__required_authn_params = required_authn_params
93
96
# map of authService -> token_getter
94
97
self .__auth_service_token_getters = auth_service_token_getters
95
- # map of parameter name to value or Callable
98
+ # map of parameter name to value ( or callable that produces that value)
96
99
self .__bound_parameters = bound_params
97
100
98
101
def __copy (
@@ -120,18 +123,24 @@ def __copy(
120
123
a auth_service_token_getter set for them yet.
121
124
auth_service_token_getters: A dict of authService -> token (or callables
122
125
that produce a token)
126
+ bound_params: A mapping of parameter names to bind to specific values or
127
+ callables that are called to produce values as needed.
123
128
124
129
"""
130
+ check = lambda val , default : val if val is not None else default
125
131
return ToolboxTool (
126
- session = session or self .__session ,
127
- base_url = base_url or self .__base_url ,
128
- name = name or self .__name__ ,
129
- desc = desc or self .__desc ,
130
- params = params or self .__params ,
131
- required_authn_params = required_authn_params or self .__required_authn_params ,
132
- auth_service_token_getters = auth_service_token_getters
133
- or self .__auth_service_token_getters ,
134
- bound_params = bound_params or self .__bound_parameters ,
132
+ session = check (session , self .__session ),
133
+ base_url = check (base_url , self .__base_url ),
134
+ name = check (name , self .__name__ ),
135
+ desc = check (desc , self .__desc ),
136
+ params = check (params , self .__params ),
137
+ required_authn_params = check (
138
+ required_authn_params , self .__required_authn_params
139
+ ),
140
+ auth_service_token_getters = check (
141
+ auth_service_token_getters , self .__auth_service_token_getters
142
+ ),
143
+ bound_params = check (bound_params , self .__bound_parameters ),
135
144
)
136
145
137
146
async def __call__ (self , * args : Any , ** kwargs : Any ) -> str :
@@ -151,9 +160,12 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
151
160
152
161
# check if any auth services need to be specified yet
153
162
if len (self .__required_authn_params ) > 0 :
154
- req_auth_services = set (l for l in self .__required_authn_params .keys ())
163
+ # Gather all the required auth services into a set
164
+ req_auth_services = set ()
165
+ for s in self .__required_authn_params .values ():
166
+ req_auth_services .update (s )
155
167
raise Exception (
156
- f"One of more of the following authn services are required to invoke this tool: { ',' .join (req_auth_services )} "
168
+ f"One or more of the following authn services are required to invoke this tool: { ',' .join (req_auth_services )} "
157
169
)
158
170
159
171
# validate inputs to this call using the signature
@@ -190,7 +202,7 @@ def add_auth_token_getters(
190
202
auth_token_getters : Mapping [str , Callable [[], str ]],
191
203
) -> "ToolboxTool" :
192
204
"""
193
- Registers a auth token getter function that is used for AuthServices when tools
205
+ Registers an auth token getter function that is used for AuthServices when tools
194
206
are invoked.
195
207
196
208
Args:
@@ -203,10 +215,12 @@ def add_auth_token_getters(
203
215
"""
204
216
205
217
# throw an error if the authentication source is already registered
206
- dupes = auth_token_getters .keys () & self .__auth_service_token_getters .keys ()
207
- if dupes :
218
+ existing_services = self .__auth_service_token_getters .keys ()
219
+ incoming_services = auth_token_getters .keys ()
220
+ duplicates = existing_services & incoming_services
221
+ if duplicates :
208
222
raise ValueError (
209
- f"Authentication source(s) `{ ', ' .join (dupes )} ` already registered in tool `{ self .__name__ } `."
223
+ f"Authentication source(s) `{ ', ' .join (duplicates )} ` already registered in tool `{ self .__name__ } `."
210
224
)
211
225
212
226
# create a read-only updated value for new_getters
@@ -215,7 +229,7 @@ def add_auth_token_getters(
215
229
)
216
230
# create a read-only updated for params that are still required
217
231
new_req_authn_params = types .MappingProxyType (
218
- filter_required_authn_params (
232
+ identify_required_authn_params (
219
233
self .__required_authn_params , auth_token_getters .keys ()
220
234
)
221
235
)
@@ -226,7 +240,7 @@ def add_auth_token_getters(
226
240
)
227
241
228
242
def bind_parameters (
229
- self , bound_params : Mapping [str , Callable [[], str ]]
243
+ self , bound_params : Mapping [str , Union [ Callable [[], Any ], Any ]]
230
244
) -> "ToolboxTool" :
231
245
"""
232
246
Binds parameters to values or callables that produce values.
@@ -238,9 +252,9 @@ def bind_parameters(
238
252
Returns:
239
253
A new ToolboxTool instance with the specified parameters bound.
240
254
"""
241
- all_params = set (p .name for p in self .__params )
255
+ param_names = set (p .name for p in self .__params )
242
256
for name in bound_params .keys ():
243
- if name not in all_params :
257
+ if name not in param_names :
244
258
raise Exception (f"unable to bind parameters: no parameter named { name } " )
245
259
246
260
new_params = []
@@ -254,27 +268,28 @@ def bind_parameters(
254
268
)
255
269
256
270
257
- def filter_required_authn_params (
258
- req_authn_params : Mapping [str , list [str ]], auth_services : Iterable [str ]
271
+ def identify_required_authn_params (
272
+ req_authn_params : Mapping [str , list [str ]], auth_service_names : Iterable [str ]
259
273
) -> dict [str , list [str ]]:
260
274
"""
261
- Utility function for reducing 'req_authn_params' to a subset of parameters that
262
- aren't supplied by a least one service in auth_services .
275
+ Identifies authentication parameters that are still required; because they
276
+ not covered by the provided `auth_service_names` .
263
277
264
- Args:
265
- req_authn_params: A mapping of parameter names to sets of required
266
- authentication services.
267
- auth_services : An iterable of authentication service names for which
268
- token getters are available.
278
+ Args:
279
+ req_authn_params: A mapping of parameter names to sets of required
280
+ authentication services.
281
+ auth_service_names : An iterable of authentication service names for which
282
+ token getters are available.
269
283
270
284
Returns:
271
285
A new dictionary representing the subset of required authentication parameters
272
286
that are not covered by the provided `auth_services`.
273
287
"""
274
- req_params = {}
288
+ required_params = {} # params that are still required with provided auth_services
275
289
for param , services in req_authn_params .items ():
276
- # if we don't have a token_getter for any of the services required by the param, the param is still required
277
- required = not any (s in services for s in auth_services )
290
+ # if we don't have a token_getter for any of the services required by the param,
291
+ # the param is still required
292
+ required = not any (s in services for s in auth_service_names )
278
293
if required :
279
- req_params [param ] = services
280
- return req_params
294
+ required_params [param ] = services
295
+ return required_params
0 commit comments