@@ -35,8 +35,8 @@ def search(
35
35
36
36
# The GA fitness function after including required variables into the function's scope
37
37
# Unused arguments are required for pygad's fitness function signature
38
- #pylint: disable=cell-var-from-loop
39
- def fitness_function (ga , solution , idx ): # pylint: disable=unused-argument
38
+ # pylint: disable=cell-var-from-loop
39
+ def fitness_function (ga , solution , idx ): # pylint: disable=unused-argument
40
40
surrogate .control_value = solution [0 ] - self .delta
41
41
surrogate .treatment_value = solution [0 ] + self .delta
42
42
@@ -45,7 +45,9 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
45
45
adjustment_dict [adjustment ] = solution [i + 1 ]
46
46
47
47
ate = surrogate .estimate_ate_calculated (adjustment_dict )
48
-
48
+ if len (ate ) > 1 :
49
+ raise ValueError (
50
+ "Multiple ate values provided but currently only single values supported in this method" )
49
51
return contradiction_function (ate [0 ])
50
52
51
53
gene_types , gene_space = self .create_gene_types (surrogate , specification )
@@ -82,7 +84,7 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
82
84
83
85
@staticmethod
84
86
def create_gene_types (
85
- surrogate_model : CubicSplineRegressionEstimator , specification : CausalSpecification
87
+ surrogate_model : CubicSplineRegressionEstimator , specification : CausalSpecification
86
88
) -> tuple [list , list ]:
87
89
"""Generate the gene_types and gene_space for a given fitness function and specification
88
90
:param surrogate_model: Instance of a CubicSplineRegressionEstimator
0 commit comments