Skip to content

Commit b01ab8a

Browse files
Update test_suite tests for new index
1 parent efe57c0 commit b01ab8a

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

tests/testing_tests/test_causal_test_suite.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ def setUp(self) -> None:
3838
self.expected_causal_effect,
3939
0,
4040
2)]
41-
estimators = [LinearRegressionEstimator]
41+
self.estimators = [LinearRegressionEstimator]
4242
self.test_suite.add_test_object(base_test_case=self.base_test_case,
4343
causal_test_case_list=test_list,
44-
estimators=estimators)
44+
estimators=self.estimators)
4545

4646
temp_dir_path = create_temp_dir_if_non_existent()
4747
dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
@@ -91,18 +91,24 @@ def test_return_single_test_object(self):
9191

9292
self.assertEqual(test_case, manual_test_case)
9393

94-
def test_execute_test_suite_single_test_case(self):
95-
causal_specification = CausalSpecification(self.scenario, self.causal_dag)
96-
97-
data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
98-
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
94+
def test_execute_test_suite_single_base_test_case(self):
95+
"""Check that the test suite can return the correct results from dummy data for a single base_test-case"""
96+
causal_test_engine = self.create_causal_test_engine()
9997

10098
causal_test_results = causal_test_engine.execute_test_suite(test_suite=self.test_suite)
101-
breakpoint()
102-
self.assertAlmostEqual(causal_test_results[self.base_test_case][0][0].ate, 4, delta=1e-10)
99+
causal_test_case = causal_test_results[self.base_test_case]
100+
self.assertAlmostEqual(causal_test_case['LinearRegressionEstimator'][0].ate, 4, delta=1e-10)
103101

104-
def test_execute_test_suite_multiple_test_case(self):
102+
def test_execute_test_suite_multiple_base_test_cases(self):
105103
pass
106104

107105
def test_execute_test_suite_multiple_estimators(self):
108106
pass
107+
108+
def create_causal_test_engine(self):
109+
110+
causal_specification = CausalSpecification(self.scenario, self.causal_dag)
111+
112+
data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
113+
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
114+
return causal_test_engine

0 commit comments

Comments
 (0)