7
7
8
8
from causal_testing .specification .causal_specification import CausalSpecification
9
9
from causal_testing .testing .estimators import CubicSplineRegressionEstimator
10
- from causal_testing .surrogate .causal_surrogate_assisted import SearchAlgorithm , SearchFitnessFunction
10
+ from causal_testing .surrogate .causal_surrogate_assisted import SearchAlgorithm
11
11
12
12
13
13
class GeneticSearchAlgorithm (SearchAlgorithm ):
@@ -25,15 +25,15 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
25
25
"some_effect" : lambda x : abs (1 / x ),
26
26
}
27
27
28
- def generate_fitness_functions (
29
- self , surrogate_models : list [CubicSplineRegressionEstimator ]
30
- ) -> list [ SearchFitnessFunction ] :
31
- fitness_functions = []
28
+ def search (
29
+ self , surrogate_models : list [CubicSplineRegressionEstimator ], specification : CausalSpecification
30
+ ) -> list :
31
+ solutions = []
32
32
33
33
for surrogate in surrogate_models :
34
34
contradiction_function = self .contradiction_functions [surrogate .expected_relationship ]
35
35
36
- # The returned fitness function after including required variables into the function's scope
36
+ # The GA fitness function after including required variables into the function's scope
37
37
# Unused arguments are required for pygad's fitness function signature
38
38
def fitness_function (ga , solution , idx ): # pylint: disable=unused-argument
39
39
surrogate .control_value = solution [0 ] - self .delta
@@ -46,25 +46,15 @@ def fitness_function(ga, solution, idx): # pylint: disable=unused-argument
46
46
ate = surrogate .estimate_ate_calculated (adjustment_dict )
47
47
48
48
return contradiction_function (ate )
49
-
50
- search_fitness_function = SearchFitnessFunction (fitness_function , surrogate )
51
-
52
- fitness_functions .append (search_fitness_function )
53
-
54
- return fitness_functions
55
-
56
- def search (self , fitness_functions : list [SearchFitnessFunction ], specification : CausalSpecification ) -> list :
57
- solutions = []
58
-
59
- for fitness_function in fitness_functions :
60
- gene_types , gene_space = self .create_gene_types (fitness_function , specification )
49
+
50
+ gene_types , gene_space = self .create_gene_types (surrogate , specification )
61
51
62
52
ga = GA (
63
53
num_generations = 200 ,
64
54
num_parents_mating = 4 ,
65
- fitness_func = fitness_function . fitness_function ,
55
+ fitness_func = fitness_function ,
66
56
sol_per_pop = 10 ,
67
- num_genes = 1 + len (fitness_function . surrogate_model .adjustment_set ),
57
+ num_genes = 1 + len (surrogate .adjustment_set ),
68
58
gene_space = gene_space ,
69
59
gene_type = gene_types ,
70
60
)
@@ -82,24 +72,24 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
82
72
solution , fitness , _ = ga .best_solution ()
83
73
84
74
solution_dict = {}
85
- solution_dict [fitness_function . surrogate_model .treatment ] = solution [0 ]
86
- for idx , adj in enumerate (fitness_function . surrogate_model .adjustment_set ):
75
+ solution_dict [surrogate .treatment ] = solution [0 ]
76
+ for idx , adj in enumerate (surrogate .adjustment_set ):
87
77
solution_dict [adj ] = solution [idx + 1 ]
88
- solutions .append ((solution_dict , fitness , fitness_function . surrogate_model ))
78
+ solutions .append ((solution_dict , fitness , surrogate ))
89
79
90
80
return max (solutions , key = itemgetter (1 )) # This can be done better with fitness normalisation between edges
91
81
92
82
@staticmethod
93
83
def create_gene_types (
94
- fitness_function : SearchFitnessFunction , specification : CausalSpecification
84
+ surrogate_model : CubicSplineRegressionEstimator , specification : CausalSpecification
95
85
) -> tuple [list , list ]:
96
86
"""Generate the gene_types and gene_space for a given fitness function and specification
97
- :param fitness_function : Instance of SearchFitnessFunction
87
+ :param surrogate_model : Instance of a CubicSplineRegressionEstimator
98
88
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""
99
89
100
90
var_space = {}
101
- var_space [fitness_function . surrogate_model .treatment ] = {}
102
- for adj in fitness_function . surrogate_model .adjustment_set :
91
+ var_space [surrogate_model .treatment ] = {}
92
+ for adj in surrogate_model .adjustment_set :
103
93
var_space [adj ] = {}
104
94
105
95
for relationship in list (specification .scenario .constraints ):
@@ -112,12 +102,12 @@ def create_gene_types(
112
102
var_space [rel_split [0 ]]["high" ] = int (rel_split [2 ])
113
103
114
104
gene_space = []
115
- gene_space .append (var_space [fitness_function . surrogate_model .treatment ])
116
- for adj in fitness_function . surrogate_model .adjustment_set :
105
+ gene_space .append (var_space [surrogate_model .treatment ])
106
+ for adj in surrogate_model .adjustment_set :
117
107
gene_space .append (var_space [adj ])
118
108
119
109
gene_types = []
120
- gene_types .append (specification .scenario .variables .get (fitness_function . surrogate_model .treatment ).datatype )
121
- for adj in fitness_function . surrogate_model .adjustment_set :
110
+ gene_types .append (specification .scenario .variables .get (surrogate_model .treatment ).datatype )
111
+ for adj in surrogate_model .adjustment_set :
122
112
gene_types .append (specification .scenario .variables .get (adj ).datatype )
123
113
return gene_types , gene_space
0 commit comments