@@ -96,21 +96,30 @@ def test_execute_test_suite_single_base_test_case(self):
96
96
causal_test_engine = self .create_causal_test_engine ()
97
97
98
98
causal_test_results = causal_test_engine .execute_test_suite (test_suite = self .test_suite )
99
- causal_test_case = causal_test_results [self .base_test_case ]
100
- self .assertAlmostEqual (causal_test_case ['LinearRegressionEstimator' ][0 ].ate , 4 , delta = 1e-10 )
101
-
102
- def test_execute_test_suite_multiple_base_test_cases (self ):
103
- pass
99
+ causal_test_case_result = causal_test_results [self .base_test_case ]
100
+ self .assertAlmostEqual (causal_test_case_result ['LinearRegressionEstimator' ][0 ].ate , 4 , delta = 1e-10 )
104
101
105
102
def test_execute_test_suite_multiple_estimators (self ):
106
- #estimators = [LinearRegressionEstimator, CausalForestEstimator]
107
- #test_suite_2_estimators = CausalTestSuite
108
- #test_suite_2_estimators
103
+ estimators = [LinearRegressionEstimator , CausalForestEstimator ]
104
+ test_suite_2_estimators = CausalTestSuite ()
105
+ test_list = [CausalTestCase (self .base_test_case ,
106
+ self .expected_causal_effect ,
107
+ 0 ,
108
+ 1 )]
109
+ test_suite_2_estimators .add_test_object (base_test_case = self .base_test_case ,
110
+ causal_test_case_list = test_list ,
111
+ estimators = estimators )
109
112
causal_test_engine = self .create_causal_test_engine ()
110
- def create_causal_test_engine (self ):
113
+ causal_test_results = causal_test_engine .execute_test_suite (test_suite = test_suite_2_estimators )
114
+ causal_test_case_result = causal_test_results [self .base_test_case ]
115
+ linear_regression_result = causal_test_case_result ['LinearRegressionEstimator' ][0 ]
116
+ causal_forrest_result = causal_test_case_result ['CausalForestEstimator' ][0 ]
117
+ self .assertAlmostEqual (linear_regression_result .ate , 4 , delta = 1e-1 )
118
+ self .assertAlmostEqual (causal_forrest_result .ate , 4 , delta = 1e-1 )
111
119
120
+ def create_causal_test_engine (self ):
112
121
causal_specification = CausalSpecification (self .scenario , self .causal_dag )
113
122
114
123
data_collector = ObservationalDataCollector (self .scenario , self .observational_data_csv_path )
115
124
causal_test_engine = CausalTestEngine (causal_specification , data_collector )
116
- return causal_test_engine
125
+ return causal_test_engine
0 commit comments