@@ -1052,8 +1052,35 @@ def to_equality_constraint(
1052
1052
finalized_constraint = to_equality_constraint (parsed_restriction , params_used )
1053
1053
if finalized_constraint is None :
1054
1054
# we must turn it into a general function
1055
- finalized_constraint = f"def r({ ', ' .join (params_used )} ): return { parsed_restriction } \n "
1055
+ if format .lower () == "pyatf" :
1056
+ finalized_constraint = parsed_restriction
1057
+ else :
1058
+ finalized_constraint = f"def r({ ', ' .join (params_used )} ): return { parsed_restriction } \n "
1056
1059
parsed_restrictions .append ((finalized_constraint , params_used ))
1060
+
1061
+ # if pyATF, restrictions that are set on the same parameter must be combined into one
1062
+ if format .lower () == "pyatf" :
1063
+ res_dict = dict ()
1064
+ registered_params = list ()
1065
+ registered_restrictions = list ()
1066
+ for param in tune_params .keys ():
1067
+ registered_params .append (param )
1068
+ for index , (res , params ) in enumerate (parsed_restrictions ):
1069
+ if index in registered_restrictions :
1070
+ continue
1071
+ if all (p in registered_params for p in params ):
1072
+ if param not in res_dict :
1073
+ res_dict [param ] = (list (), list ())
1074
+ res_dict [param ][0 ].append (res )
1075
+ res_dict [param ][1 ].extend (params )
1076
+ registered_restrictions .append (index )
1077
+ # combine multiple restrictions into one
1078
+ parsed_restrictions_pyatf = list ()
1079
+ for res_tuple in res_dict .values ():
1080
+ res , params_used = res_tuple
1081
+ params_used = list (dict .fromkeys (params_used )) # param_used should only contain unique, dict preserves order
1082
+ parsed_restrictions_pyatf .append ((f"def r({ ', ' .join (params_used )} ): return { ' and ' .join (res )} \n " , params_used ))
1083
+ return parsed_restrictions_pyatf
1057
1084
else :
1058
1085
# create one monolithic function
1059
1086
parsed_restrictions = ") and (" .join (
0 commit comments