Skip to content

Commit 2d473e5

Browse files
Move functionality into static method called create_gene_types
1 parent 1e22537 commit 2d473e5

File tree

1 file changed

+35
-23
lines changed

1 file changed

+35
-23
lines changed

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
from pygad import GA
55

66
from causal_testing.specification.causal_specification import CausalSpecification
7-
from causal_testing.testing.estimators import Estimator, CubicSplineRegressionEstimator
7+
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
88
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm, SearchFitnessFunction
99

1010

1111
class GeneticSearchAlgorithm(SearchAlgorithm):
1212
""" Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models."""
13+
1314
def __init__(self, delta=0.05, config: dict = None) -> None:
1415
super().__init__()
1516

@@ -23,7 +24,7 @@ def __init__(self, delta=0.05, config: dict = None) -> None:
2324
}
2425

2526
def generate_fitness_functions(
26-
self, surrogate_models: list[CubicSplineRegressionEstimator]
27+
self, surrogate_models: list[CubicSplineRegressionEstimator]
2728
) -> list[SearchFitnessFunction]:
2829
fitness_functions = []
2930

@@ -53,28 +54,8 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
5354
solutions = []
5455

5556
for fitness_function in fitness_functions:
56-
var_space = {}
57-
var_space[fitness_function.surrogate_model.treatment] = {}
58-
for adj in fitness_function.surrogate_model.adjustment_set:
59-
var_space[adj] = {}
60-
61-
for relationship in list(specification.scenario.constraints):
62-
rel_split = str(relationship).split(" ")
63-
64-
if rel_split[1] == ">=":
65-
var_space[rel_split[0]]["low"] = int(rel_split[2])
66-
elif rel_split[1] == "<=":
67-
var_space[rel_split[0]]["high"] = int(rel_split[2])
6857

69-
gene_space = []
70-
gene_space.append(var_space[fitness_function.surrogate_model.treatment])
71-
for adj in fitness_function.surrogate_model.adjustment_set:
72-
gene_space.append(var_space[adj])
73-
74-
gene_types = []
75-
gene_types.append(specification.scenario.variables.get(fitness_function.surrogate_model.treatment).datatype)
76-
for adj in fitness_function.surrogate_model.adjustment_set:
77-
gene_types.append(specification.scenario.variables.get(adj).datatype)
58+
gene_types, gene_space = self.create_gene_types(fitness_function, specification)
7859

7960
ga = GA(
8061
num_generations=200,
@@ -105,3 +86,34 @@ def search(self, fitness_functions: list[SearchFitnessFunction], specification:
10586
solutions.append((solution_dict, fitness, fitness_function.surrogate_model))
10687

10788
return max(solutions, key=itemgetter(1)) # This can be done better with fitness normalisation between edges
89+
90+
@staticmethod
91+
def create_gene_types(fitness_function: SearchFitnessFunction, specification: CausalSpecification) -> tuple[
92+
list, list]:
93+
"""Generate the gene_types and gene_space for a given fitness function and specification
94+
:param fitness_function: Instance of SearchFitnessFunction
95+
:param specification: The Causal Specification (combination of Scenario and Causal Dag)"""
96+
97+
var_space = {}
98+
var_space[fitness_function.surrogate_model.treatment] = {}
99+
for adj in fitness_function.surrogate_model.adjustment_set:
100+
var_space[adj] = {}
101+
102+
for relationship in list(specification.scenario.constraints):
103+
rel_split = str(relationship).split(" ")
104+
105+
if rel_split[1] == ">=":
106+
var_space[rel_split[0]]["low"] = int(rel_split[2])
107+
elif rel_split[1] == "<=":
108+
var_space[rel_split[0]]["high"] = int(rel_split[2])
109+
110+
gene_space = []
111+
gene_space.append(var_space[fitness_function.surrogate_model.treatment])
112+
for adj in fitness_function.surrogate_model.adjustment_set:
113+
gene_space.append(var_space[adj])
114+
115+
gene_types = []
116+
gene_types.append(specification.scenario.variables.get(fitness_function.surrogate_model.treatment).datatype)
117+
for adj in fitness_function.surrogate_model.adjustment_set:
118+
gene_types.append(specification.scenario.variables.get(adj).datatype)
119+
return gene_types, gene_space

0 commit comments

Comments
 (0)