Skip to content

Commit c8ebb29

Browse files
committed
Test coverage
1 parent ae8b86c commit c8ebb29

File tree

1 file changed

+112
-3
lines changed

1 file changed

+112
-3
lines changed

tests/surrogate_tests/test_causal_surrogate_assisted.py

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from causal_testing.specification.causal_specification import CausalSpecification
55
from causal_testing.specification.scenario import Scenario
66
from causal_testing.specification.variable import Input
7-
from causal_testing.surrogate.causal_surrogate_assisted import SimulationResult, SearchFitnessFunction, CausalSurrogateAssistedTestCase, Simulator
7+
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm, SimulationResult, SearchFitnessFunction, CausalSurrogateAssistedTestCase, Simulator
88
from causal_testing.surrogate.surrogate_search_algorithms import GeneticSearchAlgorithm
99
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
1010
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
@@ -113,6 +113,101 @@ def test_causal_surrogate_assisted_execution(self):
113113
self.assertEqual(iterations, 1)
114114
self.assertEqual(len(result_data), 17)
115115

116+
def test_causal_surrogate_assisted_execution_failure(self):
117+
df = self.class_df.copy()
118+
119+
causal_dag = CausalDAG(self.dag_dot_path)
120+
z = Input("Z", int)
121+
x = Input("X", int)
122+
m = Input("M", int)
123+
y = Input("Y", int)
124+
scenario = Scenario(variables={z, x, m, y}, constraints={
125+
z <= 0, z >= 3,
126+
x <= 0, x >= 3,
127+
m <= 0, m >= 3
128+
})
129+
specification = CausalSpecification(scenario, causal_dag)
130+
131+
search_algorithm = GeneticSearchAlgorithm(config= {
132+
"parent_selection_type": "tournament",
133+
"K_tournament": 4,
134+
"mutation_type": "random",
135+
"mutation_percent_genes": 50,
136+
"mutation_by_replacement": True,
137+
})
138+
simulator = TestSimulatorFailing()
139+
140+
c_s_a_test_case = CausalSurrogateAssistedTestCase(specification, search_algorithm, simulator)
141+
142+
result, iterations, result_data = c_s_a_test_case.execute(ObservationalDataCollector(scenario, df), 1)
143+
144+
self.assertIsInstance(result, str)
145+
self.assertEqual(iterations, 1)
146+
self.assertEqual(len(result_data), 17)
147+
148+
def test_causal_surrogate_assisted_execution_custom_aggregator(self):
149+
df = self.class_df.copy()
150+
151+
causal_dag = CausalDAG(self.dag_dot_path)
152+
z = Input("Z", int)
153+
x = Input("X", int)
154+
m = Input("M", int)
155+
y = Input("Y", int)
156+
scenario = Scenario(variables={z, x, m, y}, constraints={
157+
z <= 0, z >= 3,
158+
x <= 0, x >= 3,
159+
m <= 0, m >= 3
160+
})
161+
specification = CausalSpecification(scenario, causal_dag)
162+
163+
search_algorithm = GeneticSearchAlgorithm(config= {
164+
"parent_selection_type": "tournament",
165+
"K_tournament": 4,
166+
"mutation_type": "random",
167+
"mutation_percent_genes": 50,
168+
"mutation_by_replacement": True,
169+
})
170+
simulator = TestSimulator()
171+
172+
c_s_a_test_case = CausalSurrogateAssistedTestCase(specification, search_algorithm, simulator)
173+
174+
result, iterations, result_data = c_s_a_test_case.execute(ObservationalDataCollector(scenario, df),
175+
custom_data_aggregator=data_double_aggregator)
176+
177+
self.assertIsInstance(result, SimulationResult)
178+
self.assertEqual(iterations, 1)
179+
self.assertEqual(len(result_data), 18)
180+
181+
def test_causal_surrogate_assisted_execution_incorrect_search_config(self):
182+
df = self.class_df.copy()
183+
184+
causal_dag = CausalDAG(self.dag_dot_path)
185+
z = Input("Z", int)
186+
x = Input("X", int)
187+
m = Input("M", int)
188+
y = Input("Y", int)
189+
scenario = Scenario(variables={z, x, m, y}, constraints={
190+
z <= 0, z >= 3,
191+
x <= 0, x >= 3,
192+
m <= 0, m >= 3
193+
})
194+
specification = CausalSpecification(scenario, causal_dag)
195+
196+
search_algorithm = GeneticSearchAlgorithm(config= {
197+
"parent_selection_type": "tournament",
198+
"K_tournament": 4,
199+
"mutation_type": "random",
200+
"mutation_percent_genes": 50,
201+
"mutation_by_replacement": True,
202+
"gene_space": "Something"
203+
})
204+
simulator = TestSimulator()
205+
206+
c_s_a_test_case = CausalSurrogateAssistedTestCase(specification, search_algorithm, simulator)
207+
208+
self.assertRaises(c_s_a_test_case.execute(ValueError, ObservationalDataCollector(scenario, df),
209+
custom_data_aggregator=data_double_aggregator))
210+
116211
def tearDown(self) -> None:
117212
remove_temp_dir_if_existent()
118213

@@ -122,7 +217,7 @@ def load_class_df():
122217
class_df = pd.DataFrame({"Z": np.arange(16), "X": np.arange(16), "M": np.arange(16, 32), "Y": np.arange(32,16,-1)})
123218
return class_df
124219

125-
class TestSimulator():
220+
class TestSimulator(Simulator):
126221

127222
def run_with_config(self, configuration: dict) -> SimulationResult:
128223
return SimulationResult({"Z": 1, "X": 1, "M": 1, "Y": 1}, True, None)
@@ -131,4 +226,18 @@ def startup(self):
131226
pass
132227

133228
def shutdown(self):
134-
pass
229+
pass
230+
231+
class TestSimulatorFailing(Simulator):
232+
233+
def run_with_config(self, configuration: dict) -> SimulationResult:
234+
return SimulationResult({"Z": 1, "X": 1, "M": 1, "Y": 1}, False, None)
235+
236+
def startup(self):
237+
pass
238+
239+
def shutdown(self):
240+
pass
241+
242+
def data_double_aggregator(data, new_data):
243+
return data.append(new_data, ignore_index=True).append(new_data, ignore_index=True)

0 commit comments

Comments
 (0)