4
4
from pygad import GA
5
5
6
6
from causal_testing .specification .causal_specification import CausalSpecification
7
- from causal_testing .testing .estimators import Estimator , CubicSplineRegressionEstimator
7
+ from causal_testing .testing .estimators import CubicSplineRegressionEstimator
8
8
from causal_testing .surrogate .causal_surrogate_assisted import SearchAlgorithm , SearchFitnessFunction
9
9
10
10
11
11
class GeneticSearchAlgorithm (SearchAlgorithm ):
12
12
""" Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models."""
13
+
13
14
def __init__ (self , delta = 0.05 , config : dict = None ) -> None :
14
15
super ().__init__ ()
15
16
@@ -23,7 +24,7 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
23
24
}
24
25
25
26
def generate_fitness_functions (
26
- self , surrogate_models : list [CubicSplineRegressionEstimator ]
27
+ self , surrogate_models : list [CubicSplineRegressionEstimator ]
27
28
) -> list [SearchFitnessFunction ]:
28
29
fitness_functions = []
29
30
@@ -53,28 +54,8 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
53
54
solutions = []
54
55
55
56
for fitness_function in fitness_functions :
56
- var_space = {}
57
- var_space [fitness_function .surrogate_model .treatment ] = {}
58
- for adj in fitness_function .surrogate_model .adjustment_set :
59
- var_space [adj ] = {}
60
-
61
- for relationship in list (specification .scenario .constraints ):
62
- rel_split = str (relationship ).split (" " )
63
-
64
- if rel_split [1 ] == ">=" :
65
- var_space [rel_split [0 ]]["low" ] = int (rel_split [2 ])
66
- elif rel_split [1 ] == "<=" :
67
- var_space [rel_split [0 ]]["high" ] = int (rel_split [2 ])
68
57
69
- gene_space = []
70
- gene_space .append (var_space [fitness_function .surrogate_model .treatment ])
71
- for adj in fitness_function .surrogate_model .adjustment_set :
72
- gene_space .append (var_space [adj ])
73
-
74
- gene_types = []
75
- gene_types .append (specification .scenario .variables .get (fitness_function .surrogate_model .treatment ).datatype )
76
- for adj in fitness_function .surrogate_model .adjustment_set :
77
- gene_types .append (specification .scenario .variables .get (adj ).datatype )
58
+ gene_types , gene_space = self .create_gene_types (fitness_function , specification )
78
59
79
60
ga = GA (
80
61
num_generations = 200 ,
@@ -105,3 +86,34 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
105
86
solutions .append ((solution_dict , fitness , fitness_function .surrogate_model ))
106
87
107
88
return max (solutions , key = itemgetter (1 )) # This can be done better with fitness normalisation between edges
89
+
90
+ @staticmethod
91
+ def create_gene_types (fitness_function : SearchFitnessFunction , specification : CausalSpecification ) -> tuple [
92
+ list , list ]:
93
+ """Generate the gene_types and gene_space for a given fitness function and specification
94
+ :param fitness_function: Instance of SearchFitnessFunction
95
+ :param specification: The Causal Specification (combination of Scenario and Causal Dag)"""
96
+
97
+ var_space = {}
98
+ var_space [fitness_function .surrogate_model .treatment ] = {}
99
+ for adj in fitness_function .surrogate_model .adjustment_set :
100
+ var_space [adj ] = {}
101
+
102
+ for relationship in list (specification .scenario .constraints ):
103
+ rel_split = str (relationship ).split (" " )
104
+
105
+ if rel_split [1 ] == ">=" :
106
+ var_space [rel_split [0 ]]["low" ] = int (rel_split [2 ])
107
+ elif rel_split [1 ] == "<=" :
108
+ var_space [rel_split [0 ]]["high" ] = int (rel_split [2 ])
109
+
110
+ gene_space = []
111
+ gene_space .append (var_space [fitness_function .surrogate_model .treatment ])
112
+ for adj in fitness_function .surrogate_model .adjustment_set :
113
+ gene_space .append (var_space [adj ])
114
+
115
+ gene_types = []
116
+ gene_types .append (specification .scenario .variables .get (fitness_function .surrogate_model .treatment ).datatype )
117
+ for adj in fitness_function .surrogate_model .adjustment_set :
118
+ gene_types .append (specification .scenario .variables .get (adj ).datatype )
119
+ return gene_types , gene_space
0 commit comments