4
4
from causal_testing .specification .causal_specification import CausalSpecification
5
5
from causal_testing .specification .scenario import Scenario
6
6
from causal_testing .specification .variable import Input
7
- from causal_testing .surrogate .causal_surrogate_assisted import SimulationResult , SearchFitnessFunction , CausalSurrogateAssistedTestCase , Simulator
7
+ from causal_testing .surrogate .causal_surrogate_assisted import SearchAlgorithm , SimulationResult , SearchFitnessFunction , CausalSurrogateAssistedTestCase , Simulator
8
8
from causal_testing .surrogate .surrogate_search_algorithms import GeneticSearchAlgorithm
9
9
from causal_testing .testing .estimators import CubicSplineRegressionEstimator
10
10
from tests .test_helpers import create_temp_dir_if_non_existent , remove_temp_dir_if_existent
@@ -113,6 +113,101 @@ def test_causal_surrogate_assisted_execution(self):
113
113
self .assertEqual (iterations , 1 )
114
114
self .assertEqual (len (result_data ), 17 )
115
115
116
+ def test_causal_surrogate_assisted_execution_failure (self ):
117
+ df = self .class_df .copy ()
118
+
119
+ causal_dag = CausalDAG (self .dag_dot_path )
120
+ z = Input ("Z" , int )
121
+ x = Input ("X" , int )
122
+ m = Input ("M" , int )
123
+ y = Input ("Y" , int )
124
+ scenario = Scenario (variables = {z , x , m , y }, constraints = {
125
+ z <= 0 , z >= 3 ,
126
+ x <= 0 , x >= 3 ,
127
+ m <= 0 , m >= 3
128
+ })
129
+ specification = CausalSpecification (scenario , causal_dag )
130
+
131
+ search_algorithm = GeneticSearchAlgorithm (config = {
132
+ "parent_selection_type" : "tournament" ,
133
+ "K_tournament" : 4 ,
134
+ "mutation_type" : "random" ,
135
+ "mutation_percent_genes" : 50 ,
136
+ "mutation_by_replacement" : True ,
137
+ })
138
+ simulator = TestSimulatorFailing ()
139
+
140
+ c_s_a_test_case = CausalSurrogateAssistedTestCase (specification , search_algorithm , simulator )
141
+
142
+ result , iterations , result_data = c_s_a_test_case .execute (ObservationalDataCollector (scenario , df ), 1 )
143
+
144
+ self .assertIsInstance (result , str )
145
+ self .assertEqual (iterations , 1 )
146
+ self .assertEqual (len (result_data ), 17 )
147
+
148
+ def test_causal_surrogate_assisted_execution_custom_aggregator (self ):
149
+ df = self .class_df .copy ()
150
+
151
+ causal_dag = CausalDAG (self .dag_dot_path )
152
+ z = Input ("Z" , int )
153
+ x = Input ("X" , int )
154
+ m = Input ("M" , int )
155
+ y = Input ("Y" , int )
156
+ scenario = Scenario (variables = {z , x , m , y }, constraints = {
157
+ z <= 0 , z >= 3 ,
158
+ x <= 0 , x >= 3 ,
159
+ m <= 0 , m >= 3
160
+ })
161
+ specification = CausalSpecification (scenario , causal_dag )
162
+
163
+ search_algorithm = GeneticSearchAlgorithm (config = {
164
+ "parent_selection_type" : "tournament" ,
165
+ "K_tournament" : 4 ,
166
+ "mutation_type" : "random" ,
167
+ "mutation_percent_genes" : 50 ,
168
+ "mutation_by_replacement" : True ,
169
+ })
170
+ simulator = TestSimulator ()
171
+
172
+ c_s_a_test_case = CausalSurrogateAssistedTestCase (specification , search_algorithm , simulator )
173
+
174
+ result , iterations , result_data = c_s_a_test_case .execute (ObservationalDataCollector (scenario , df ),
175
+ custom_data_aggregator = data_double_aggregator )
176
+
177
+ self .assertIsInstance (result , SimulationResult )
178
+ self .assertEqual (iterations , 1 )
179
+ self .assertEqual (len (result_data ), 18 )
180
+
181
+ def test_causal_surrogate_assisted_execution_incorrect_search_config (self ):
182
+ df = self .class_df .copy ()
183
+
184
+ causal_dag = CausalDAG (self .dag_dot_path )
185
+ z = Input ("Z" , int )
186
+ x = Input ("X" , int )
187
+ m = Input ("M" , int )
188
+ y = Input ("Y" , int )
189
+ scenario = Scenario (variables = {z , x , m , y }, constraints = {
190
+ z <= 0 , z >= 3 ,
191
+ x <= 0 , x >= 3 ,
192
+ m <= 0 , m >= 3
193
+ })
194
+ specification = CausalSpecification (scenario , causal_dag )
195
+
196
+ search_algorithm = GeneticSearchAlgorithm (config = {
197
+ "parent_selection_type" : "tournament" ,
198
+ "K_tournament" : 4 ,
199
+ "mutation_type" : "random" ,
200
+ "mutation_percent_genes" : 50 ,
201
+ "mutation_by_replacement" : True ,
202
+ "gene_space" : "Something"
203
+ })
204
+ simulator = TestSimulator ()
205
+
206
+ c_s_a_test_case = CausalSurrogateAssistedTestCase (specification , search_algorithm , simulator )
207
+
208
+ self .assertRaises (c_s_a_test_case .execute (ValueError , ObservationalDataCollector (scenario , df ),
209
+ custom_data_aggregator = data_double_aggregator ))
210
+
116
211
def tearDown (self ) -> None :
117
212
remove_temp_dir_if_existent ()
118
213
@@ -122,7 +217,7 @@ def load_class_df():
122
217
class_df = pd .DataFrame ({"Z" : np .arange (16 ), "X" : np .arange (16 ), "M" : np .arange (16 , 32 ), "Y" : np .arange (32 ,16 ,- 1 )})
123
218
return class_df
124
219
125
- class TestSimulator ():
220
+ class TestSimulator (Simulator ):
126
221
127
222
def run_with_config (self , configuration : dict ) -> SimulationResult :
128
223
return SimulationResult ({"Z" : 1 , "X" : 1 , "M" : 1 , "Y" : 1 }, True , None )
@@ -131,4 +226,18 @@ def startup(self):
131
226
pass
132
227
133
228
def shutdown (self ):
134
- pass
229
+ pass
230
+
231
+ class TestSimulatorFailing (Simulator ):
232
+
233
+ def run_with_config (self , configuration : dict ) -> SimulationResult :
234
+ return SimulationResult ({"Z" : 1 , "X" : 1 , "M" : 1 , "Y" : 1 }, False , None )
235
+
236
+ def startup (self ):
237
+ pass
238
+
239
+ def shutdown (self ):
240
+ pass
241
+
242
+ def data_double_aggregator (data , new_data ):
243
+ return data .append (new_data , ignore_index = True ).append (new_data , ignore_index = True )
0 commit comments