1
- from causal_testing . specification . causal_specification import CausalSpecification
2
- from causal_testing . testing . estimators import Estimator , PolynomialRegressionEstimator
3
- from causal_testing . surrogate . causal_surrogate_assisted import SearchAlgorithm , SearchFitnessFunction
1
+ """Module containing implementation of search algorithm for surrogate search """
2
+ # pylint: disable=cell-var-from-loop
3
+ # Fitness functions are required to be iteratively defined, including all variables within.
4
4
5
- from pygad import GA
6
5
from operator import itemgetter
6
+ from pygad import GA
7
+
8
+ from causal_testing .specification .causal_specification import CausalSpecification
9
+ from causal_testing .testing .estimators import CubicSplineRegressionEstimator
10
+ from causal_testing .surrogate .causal_surrogate_assisted import SearchAlgorithm
7
11
8
12
9
13
class GeneticSearchAlgorithm (SearchAlgorithm ):
14
+ """Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models."""
15
+
10
16
def __init__ (self , delta = 0.05 , config : dict = None ) -> None :
11
17
super ().__init__ ()
12
18
@@ -15,91 +21,93 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
15
21
self .contradiction_functions = {
16
22
"positive" : lambda x : - 1 * x ,
17
23
"negative" : lambda x : x ,
18
- "no_effect" : lambda x : abs ( x ) ,
24
+ "no_effect" : abs ,
19
25
"some_effect" : lambda x : abs (1 / x ),
20
26
}
21
27
22
- def generate_fitness_functions (
23
- self , surrogate_models : list [PolynomialRegressionEstimator ]
24
- ) -> list [ SearchFitnessFunction ] :
25
- fitness_functions = []
28
+ def search (
29
+ self , surrogate_models : list [CubicSplineRegressionEstimator ], specification : CausalSpecification
30
+ ) -> list :
31
+ solutions = []
26
32
27
33
for surrogate in surrogate_models :
28
34
contradiction_function = self .contradiction_functions [surrogate .expected_relationship ]
29
35
30
- # The returned fitness function after including required variables into the function's scope
31
- def fitness_function (_ga , solution , idx ):
36
+ # The GA fitness function after including required variables into the function's scope
37
+ # Unused arguments are required for pygad's fitness function signature
38
+ def fitness_function (ga , solution , idx ): # pylint: disable=unused-argument
32
39
surrogate .control_value = solution [0 ] - self .delta
33
40
surrogate .treatment_value = solution [0 ] + self .delta
34
41
35
- adjustment_dict = dict ()
42
+ adjustment_dict = {}
36
43
for i , adjustment in enumerate (surrogate .adjustment_set ):
37
44
adjustment_dict [adjustment ] = solution [i + 1 ]
38
45
39
46
ate = surrogate .estimate_ate_calculated (adjustment_dict )
40
47
41
48
return contradiction_function (ate )
42
49
43
- search_fitness_function = SearchFitnessFunction (fitness_function , surrogate )
44
-
45
- fitness_functions .append (search_fitness_function )
46
-
47
- return fitness_functions
48
-
49
- def search (self , fitness_functions : list [SearchFitnessFunction ], specification : CausalSpecification ) -> list :
50
- solutions = []
51
-
52
- for fitness_function in fitness_functions :
53
- var_space = dict ()
54
- var_space [fitness_function .surrogate_model .treatment ] = dict ()
55
- for adj in fitness_function .surrogate_model .adjustment_set :
56
- var_space [adj ] = dict ()
57
-
58
- for relationship in list (specification .scenario .constraints ):
59
- rel_split = str (relationship ).split (" " )
60
-
61
- if rel_split [1 ] == ">=" :
62
- var_space [rel_split [0 ]]["low" ] = int (rel_split [2 ])
63
- elif rel_split [1 ] == "<=" :
64
- var_space [rel_split [0 ]]["high" ] = int (rel_split [2 ])
65
-
66
- gene_space = []
67
- gene_space .append (var_space [fitness_function .surrogate_model .treatment ])
68
- for adj in fitness_function .surrogate_model .adjustment_set :
69
- gene_space .append (var_space [adj ])
70
-
71
- gene_types = []
72
- gene_types .append (specification .scenario .variables .get (fitness_function .surrogate_model .treatment ).datatype )
73
- for adj in fitness_function .surrogate_model .adjustment_set :
74
- gene_types .append (specification .scenario .variables .get (adj ).datatype )
50
+ gene_types , gene_space = self .create_gene_types (surrogate , specification )
75
51
76
52
ga = GA (
77
53
num_generations = 200 ,
78
54
num_parents_mating = 4 ,
79
- fitness_func = fitness_function . fitness_function ,
55
+ fitness_func = fitness_function ,
80
56
sol_per_pop = 10 ,
81
- num_genes = 1 + len (fitness_function . surrogate_model .adjustment_set ),
57
+ num_genes = 1 + len (surrogate .adjustment_set ),
82
58
gene_space = gene_space ,
83
59
gene_type = gene_types ,
84
60
)
85
61
86
62
if self .config is not None :
87
63
for k , v in self .config .items ():
88
64
if k == "gene_space" :
89
- raise Exception (
90
- "Gene space should not be set through config. This is generated from the causal specification"
65
+ raise ValueError (
66
+ "Gene space should not be set through config. This is generated from the causal "
67
+ "specification"
91
68
)
92
69
setattr (ga , k , v )
93
70
94
71
ga .run ()
95
- solution , fitness , _idx = ga .best_solution ()
72
+ solution , fitness , _ = ga .best_solution ()
96
73
97
- solution_dict = dict ()
98
- solution_dict [fitness_function . surrogate_model .treatment ] = solution [0 ]
99
- for idx , adj in enumerate (fitness_function . surrogate_model .adjustment_set ):
74
+ solution_dict = {}
75
+ solution_dict [surrogate .treatment ] = solution [0 ]
76
+ for idx , adj in enumerate (surrogate .adjustment_set ):
100
77
solution_dict [adj ] = solution [idx + 1 ]
101
- solutions .append ((solution_dict , fitness , fitness_function .surrogate_model ))
78
+ solutions .append ((solution_dict , fitness , surrogate ))
79
+
80
+ return max (solutions , key = itemgetter (1 )) # This can be done better with fitness normalisation between edges
81
+
82
+ @staticmethod
83
+ def create_gene_types (
84
+ surrogate_model : CubicSplineRegressionEstimator , specification : CausalSpecification
85
+ ) -> tuple [list , list ]:
86
+ """Generate the gene_types and gene_space for a given fitness function and specification
87
+ :param surrogate_model: Instance of a CubicSplineRegressionEstimator
88
+ :param specification: The Causal Specification (combination of Scenario and Causal Dag)"""
89
+
90
+ var_space = {}
91
+ var_space [surrogate_model .treatment ] = {}
92
+ for adj in surrogate_model .adjustment_set :
93
+ var_space [adj ] = {}
94
+
95
+ for relationship in list (specification .scenario .constraints ):
96
+ rel_split = str (relationship ).split (" " )
97
+
98
+ if rel_split [0 ] in var_space :
99
+ if rel_split [1 ] == ">=" :
100
+ var_space [rel_split [0 ]]["low" ] = int (rel_split [2 ])
101
+ elif rel_split [1 ] == "<=" :
102
+ var_space [rel_split [0 ]]["high" ] = int (rel_split [2 ])
103
+
104
+ gene_space = []
105
+ gene_space .append (var_space [surrogate_model .treatment ])
106
+ for adj in surrogate_model .adjustment_set :
107
+ gene_space .append (var_space [adj ])
102
108
103
- return max (
104
- solutions , key = itemgetter (1 )
105
- ) # TODO This can be done better with fitness normalisation between edges
109
+ gene_types = []
110
+ gene_types .append (specification .scenario .variables .get (surrogate_model .treatment ).datatype )
111
+ for adj in surrogate_model .adjustment_set :
112
+ gene_types .append (specification .scenario .variables .get (adj ).datatype )
113
+ return gene_types , gene_space
0 commit comments