Skip to content

Commit b8ad419

Browse files
Ensure only single ate values are provided in surrogate_models
1 parent 6029edb commit b8ad419

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def search(
3535

3636
# The GA fitness function after including required variables into the function's scope
3737
# Unused arguments are required for pygad's fitness function signature
38-
#pylint: disable=cell-var-from-loop
39-
def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
38+
# pylint: disable=cell-var-from-loop
39+
def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
4040
surrogate.control_value = solution[0] - self.delta
4141
surrogate.treatment_value = solution[0] + self.delta
4242

@@ -45,7 +45,9 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
4545
adjustment_dict[adjustment] = solution[i + 1]
4646

4747
ate = surrogate.estimate_ate_calculated(adjustment_dict)
48-
48+
if len(ate) > 1:
49+
raise ValueError(
50+
"Multiple ate values provided but currently only single values supported in this method")
4951
return contradiction_function(ate[0])
5052

5153
gene_types, gene_space = self.create_gene_types(surrogate, specification)
@@ -82,7 +84,7 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
8284

8385
@staticmethod
8486
def create_gene_types(
85-
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
87+
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
8688
) -> tuple[list, list]:
8789
"""Generate the gene_types and gene_space for a given fitness function and specification
8890
:param surrogate_model: Instance of a CubicSplineRegressionEstimator

0 commit comments

Comments
 (0)