Skip to content

Commit cc0485f

Browse files
committed
Fixed bug in fitness function scope
1 parent 0a816a3 commit cc0485f

File tree

3 files changed

+27
-63
lines changed

3 files changed

+27
-63
lines changed

causal_testing/surrogate/causal_surrogate_assisted.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,15 @@ class SimulationResult:
1919
relationship: str
2020

2121

22-
@dataclass
23-
class SearchFitnessFunction:
24-
"""Data class containing the Fitness function and related model"""
25-
26-
fitness_function: Any
27-
surrogate_model: CubicSplineRegressionEstimator
28-
29-
3022
class SearchAlgorithm(ABC):
3123
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
3224
space to be searched"""
3325

3426
@abstractmethod
35-
def generate_fitness_functions(self, surrogate_models: list[Estimator]) -> list[SearchFitnessFunction]:
36-
"""Generates the fitness function of the search space
37-
:param surrogate_models: A list of CubicSplineRegressionEstimator generated for each edge of the DAG
38-
:return: A list of fitness functions mapping to each of the surrogate models in the input"""
39-
40-
@abstractmethod
41-
def search(self, fitness_functions: list[SearchFitnessFunction], specification: CausalSpecification) -> list:
27+
def search(self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification) -> list:
4228
"""Function which implements a search routine which searches for the optimal fitness value for the specified
4329
scenario
44-
:param fitness_functions: The fitness function to be optimised
30+
:param surrogate_models: The surrogate models to be searched
4531
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""
4632

4733

@@ -95,8 +81,7 @@ def execute(
9581

9682
for i in range(max_executions):
9783
surrogate_models = self.generate_surrogates(self.specification, data_collector)
98-
fitness_functions = self.search_algorithm.generate_fitness_functions(surrogate_models)
99-
candidate_test_case, _, surrogate = self.search_algorithm.search(fitness_functions, self.specification)
84+
candidate_test_case, _, surrogate = self.search_algorithm.search(surrogate_models, self.specification)
10085

10186
self.simulator.startup()
10287
test_result = self.simulator.run_with_config(candidate_test_case)
@@ -123,7 +108,7 @@ def execute(
123108

124109
def generate_surrogates(
125110
self, specification: CausalSpecification, data_collector: ObservationalDataCollector
126-
) -> list[SearchFitnessFunction]:
111+
) -> list[CubicSplineRegressionEstimator]:
127112
"""Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
128113
:param specification: The Causal Specification (combination of Scenario and Causal Dag)
129114
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from causal_testing.specification.causal_specification import CausalSpecification
99
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
10-
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm, SearchFitnessFunction
10+
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm
1111

1212

1313
class GeneticSearchAlgorithm(SearchAlgorithm):
@@ -25,15 +25,15 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
2525
"some_effect": lambda x: abs(1 / x),
2626
}
2727

28-
def generate_fitness_functions(
29-
self, surrogate_models: list[CubicSplineRegressionEstimator]
30-
) -> list[SearchFitnessFunction]:
31-
fitness_functions = []
28+
def search(
29+
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
30+
) -> list:
31+
solutions = []
3232

3333
for surrogate in surrogate_models:
3434
contradiction_function = self.contradiction_functions[surrogate.expected_relationship]
3535

36-
# The returned fitness function after including required variables into the function's scope
36+
# The GA fitness function after including required variables into the function's scope
3737
# Unused arguments are required for pygad's fitness function signature
3838
def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
3939
surrogate.control_value = solution[0] - self.delta
@@ -46,25 +46,15 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
4646
ate = surrogate.estimate_ate_calculated(adjustment_dict)
4747

4848
return contradiction_function(ate)
49-
50-
search_fitness_function = SearchFitnessFunction(fitness_function, surrogate)
51-
52-
fitness_functions.append(search_fitness_function)
53-
54-
return fitness_functions
55-
56-
def search(self, fitness_functions: list[SearchFitnessFunction], specification: CausalSpecification) -> list:
57-
solutions = []
58-
59-
for fitness_function in fitness_functions:
60-
gene_types, gene_space = self.create_gene_types(fitness_function, specification)
49+
50+
gene_types, gene_space = self.create_gene_types(surrogate, specification)
6151

6252
ga = GA(
6353
num_generations=200,
6454
num_parents_mating=4,
65-
fitness_func=fitness_function.fitness_function,
55+
fitness_func=fitness_function,
6656
sol_per_pop=10,
67-
num_genes=1 + len(fitness_function.surrogate_model.adjustment_set),
57+
num_genes=1 + len(surrogate.adjustment_set),
6858
gene_space=gene_space,
6959
gene_type=gene_types,
7060
)
@@ -82,24 +72,24 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
8272
solution, fitness, _ = ga.best_solution()
8373

8474
solution_dict = {}
85-
solution_dict[fitness_function.surrogate_model.treatment] = solution[0]
86-
for idx, adj in enumerate(fitness_function.surrogate_model.adjustment_set):
75+
solution_dict[surrogate.treatment] = solution[0]
76+
for idx, adj in enumerate(surrogate.adjustment_set):
8777
solution_dict[adj] = solution[idx + 1]
88-
solutions.append((solution_dict, fitness, fitness_function.surrogate_model))
78+
solutions.append((solution_dict, fitness, surrogate))
8979

9080
return max(solutions, key=itemgetter(1)) # This can be done better with fitness normalisation between edges
9181

9282
@staticmethod
9383
def create_gene_types(
94-
fitness_function: SearchFitnessFunction, specification: CausalSpecification
84+
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
9585
) -> tuple[list, list]:
9686
"""Generate the gene_types and gene_space for a given fitness function and specification
97-
:param fitness_function: Instance of SearchFitnessFunction
87+
:param surrogate_model: Instance of a CubicSplineRegressionEstimator
9888
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""
9989

10090
var_space = {}
101-
var_space[fitness_function.surrogate_model.treatment] = {}
102-
for adj in fitness_function.surrogate_model.adjustment_set:
91+
var_space[surrogate_model.treatment] = {}
92+
for adj in surrogate_model.adjustment_set:
10393
var_space[adj] = {}
10494

10595
for relationship in list(specification.scenario.constraints):
@@ -112,12 +102,12 @@ def create_gene_types(
112102
var_space[rel_split[0]]["high"] = int(rel_split[2])
113103

114104
gene_space = []
115-
gene_space.append(var_space[fitness_function.surrogate_model.treatment])
116-
for adj in fitness_function.surrogate_model.adjustment_set:
105+
gene_space.append(var_space[surrogate_model.treatment])
106+
for adj in surrogate_model.adjustment_set:
117107
gene_space.append(var_space[adj])
118108

119109
gene_types = []
120-
gene_types.append(specification.scenario.variables.get(fitness_function.surrogate_model.treatment).datatype)
121-
for adj in fitness_function.surrogate_model.adjustment_set:
110+
gene_types.append(specification.scenario.variables.get(surrogate_model.treatment).datatype)
111+
for adj in surrogate_model.adjustment_set:
122112
gene_types.append(specification.scenario.variables.get(adj).datatype)
123113
return gene_types, gene_space

tests/surrogate_tests/test_causal_surrogate_assisted.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from causal_testing.specification.causal_specification import CausalSpecification
55
from causal_testing.specification.scenario import Scenario
66
from causal_testing.specification.variable import Input
7-
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm, SimulationResult, SearchFitnessFunction, CausalSurrogateAssistedTestCase, Simulator
7+
from causal_testing.surrogate.causal_surrogate_assisted import SimulationResult, CausalSurrogateAssistedTestCase, Simulator
88
from causal_testing.surrogate.surrogate_search_algorithms import GeneticSearchAlgorithm
99
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
1010
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
@@ -36,7 +36,7 @@ def test_inputs(self):
3636

3737
self.assertEqual(result.relationship, relationship)
3838

39-
class TestSearchFitnessFunction(unittest.TestCase):
39+
class TestCausalSurrogate(unittest.TestCase):
4040

4141
@classmethod
4242
def setUpClass(cls) -> None:
@@ -49,17 +49,6 @@ def setUp(self):
4949
with open(self.dag_dot_path, "w") as f:
5050
f.write(dag_dot)
5151

52-
def test_init_valid_values(self):
53-
54-
test_function = lambda x: x **2
55-
56-
surrogate_model = CubicSplineRegressionEstimator("", 0, 0, set(), "", 4)
57-
58-
search_function = SearchFitnessFunction(fitness_function=test_function, surrogate_model=surrogate_model)
59-
60-
self.assertTrue(callable(search_function.fitness_function))
61-
self.assertIsInstance(search_function.surrogate_model, CubicSplineRegressionEstimator)
62-
6352
def test_surrogate_model_generation(self):
6453
c_s_a_test_case = CausalSurrogateAssistedTestCase(None, None, None)
6554

0 commit comments

Comments
 (0)