Skip to content

Commit c920b97

Browse files
Add more tests
1 parent daa6b10 commit c920b97

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

tests/testing_tests/test_causal_test_suite.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,21 +96,30 @@ def test_execute_test_suite_single_base_test_case(self):
9696
causal_test_engine = self.create_causal_test_engine()
9797

9898
causal_test_results = causal_test_engine.execute_test_suite(test_suite=self.test_suite)
99-
causal_test_case = causal_test_results[self.base_test_case]
100-
self.assertAlmostEqual(causal_test_case['LinearRegressionEstimator'][0].ate, 4, delta=1e-10)
101-
102-
def test_execute_test_suite_multiple_base_test_cases(self):
103-
pass
99+
causal_test_case_result = causal_test_results[self.base_test_case]
100+
self.assertAlmostEqual(causal_test_case_result['LinearRegressionEstimator'][0].ate, 4, delta=1e-10)
104101

105102
def test_execute_test_suite_multiple_estimators(self):
106-
#estimators = [LinearRegressionEstimator, CausalForestEstimator]
107-
#test_suite_2_estimators = CausalTestSuite
108-
#test_suite_2_estimators
103+
estimators = [LinearRegressionEstimator, CausalForestEstimator]
104+
test_suite_2_estimators = CausalTestSuite()
105+
test_list = [CausalTestCase(self.base_test_case,
106+
self.expected_causal_effect,
107+
0,
108+
1)]
109+
test_suite_2_estimators.add_test_object(base_test_case=self.base_test_case,
110+
causal_test_case_list=test_list,
111+
estimators=estimators)
109112
causal_test_engine = self.create_causal_test_engine()
110-
def create_causal_test_engine(self):
113+
causal_test_results = causal_test_engine.execute_test_suite(test_suite=test_suite_2_estimators)
114+
causal_test_case_result = causal_test_results[self.base_test_case]
115+
linear_regression_result = causal_test_case_result['LinearRegressionEstimator'][0]
116+
causal_forrest_result = causal_test_case_result['CausalForestEstimator'][0]
117+
self.assertAlmostEqual(linear_regression_result.ate, 4, delta=1e-1)
118+
self.assertAlmostEqual(causal_forrest_result.ate, 4, delta=1e-1)
111119

120+
def create_causal_test_engine(self):
112121
causal_specification = CausalSpecification(self.scenario, self.causal_dag)
113122

114123
data_collector = ObservationalDataCollector(self.scenario, self.observational_data_csv_path)
115124
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
116-
return causal_test_engine
125+
return causal_test_engine

0 commit comments

Comments
 (0)