@@ -250,12 +250,15 @@ def check_restriction(restrict, params: dict) -> bool:
250
250
# if it's a tuple, use only the parameters in the second argument to call the restriction
251
251
elif (
252
252
isinstance (restrict , tuple )
253
- and len (restrict ) == 2
253
+ and ( len (restrict ) == 2 or len ( restrict ) == 3 )
254
254
and callable (restrict [0 ])
255
255
and isinstance (restrict [1 ], (list , tuple ))
256
256
):
257
257
# unpack the tuple
258
- restrict , selected_params = restrict
258
+ if len (restrict ) == 2 :
259
+ restrict , selected_params = restrict
260
+ else :
261
+ restrict , selected_params , source = restrict
259
262
# look up the selected parameters and their value
260
263
selected_params = dict ((key , params [key ]) for key in selected_params )
261
264
# call the restriction
@@ -1061,14 +1064,14 @@ def to_equality_constraint(
1061
1064
finalized_constraint = to_equality_constraint (parsed_restriction , params_used )
1062
1065
if finalized_constraint is None :
1063
1066
# we must turn it into a general function
1064
- if format .lower () == "pyatf" :
1067
+ if format is not None and format .lower () == "pyatf" :
1065
1068
finalized_constraint = parsed_restriction
1066
1069
else :
1067
1070
finalized_constraint = f"def r({ ', ' .join (params_used )} ): return { parsed_restriction } \n "
1068
1071
parsed_restrictions .append ((finalized_constraint , params_used ))
1069
1072
1070
1073
# if pyATF, restrictions that are set on the same parameter must be combined into one
1071
- if format .lower () == "pyatf" :
1074
+ if format is not None and format .lower () == "pyatf" :
1072
1075
res_dict = dict ()
1073
1076
registered_params = list ()
1074
1077
registered_restrictions = list ()
0 commit comments