@@ -250,12 +250,15 @@ def check_restriction(restrict, params: dict) -> bool:
250250 # if it's a tuple, use only the parameters in the second argument to call the restriction
251251 elif (
252252 isinstance (restrict , tuple )
253- and len (restrict ) == 2
253+ and ( len (restrict ) == 2 or len ( restrict ) == 3 )
254254 and callable (restrict [0 ])
255255 and isinstance (restrict [1 ], (list , tuple ))
256256 ):
257257 # unpack the tuple
258- restrict , selected_params = restrict
258+ if len (restrict ) == 2 :
259+ restrict , selected_params = restrict
260+ else :
261+ restrict , selected_params , source = restrict
259262 # look up the selected parameters and their value
260263 selected_params = dict ((key , params [key ]) for key in selected_params )
261264 # call the restriction
@@ -1061,14 +1064,14 @@ def to_equality_constraint(
10611064 finalized_constraint = to_equality_constraint (parsed_restriction , params_used )
10621065 if finalized_constraint is None :
10631066 # we must turn it into a general function
1064- if format .lower () == "pyatf" :
1067+ if format is not None and format .lower () == "pyatf" :
10651068 finalized_constraint = parsed_restriction
10661069 else :
10671070 finalized_constraint = f"def r({ ', ' .join (params_used )} ): return { parsed_restriction } \n "
10681071 parsed_restrictions .append ((finalized_constraint , params_used ))
10691072
10701073 # if pyATF, restrictions that are set on the same parameter must be combined into one
1071- if format .lower () == "pyatf" :
1074+ if format is not None and format .lower () == "pyatf" :
10721075 res_dict = dict ()
10731076 registered_params = list ()
10741077 registered_restrictions = list ()
0 commit comments