@@ -38,10 +38,10 @@ def setUp(self) -> None:
38
38
self .expected_causal_effect ,
39
39
0 ,
40
40
2 )]
41
- estimators = [LinearRegressionEstimator ]
41
+ self . estimators = [LinearRegressionEstimator ]
42
42
self .test_suite .add_test_object (base_test_case = self .base_test_case ,
43
43
causal_test_case_list = test_list ,
44
- estimators = estimators )
44
+ estimators = self . estimators )
45
45
46
46
temp_dir_path = create_temp_dir_if_non_existent ()
47
47
dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
@@ -91,18 +91,24 @@ def test_return_single_test_object(self):
91
91
92
92
self .assertEqual (test_case , manual_test_case )
93
93
94
- def test_execute_test_suite_single_test_case (self ):
95
- causal_specification = CausalSpecification (self .scenario , self .causal_dag )
96
-
97
- data_collector = ObservationalDataCollector (self .scenario , self .observational_data_csv_path )
98
- causal_test_engine = CausalTestEngine (causal_specification , data_collector )
94
+ def test_execute_test_suite_single_base_test_case (self ):
95
+ """Check that the test suite can return the correct results from dummy data for a single base_test-case"""
96
+ causal_test_engine = self .create_causal_test_engine ()
99
97
100
98
causal_test_results = causal_test_engine .execute_test_suite (test_suite = self .test_suite )
101
- breakpoint ()
102
- self .assertAlmostEqual (causal_test_results [ self . base_test_case ][ 0 ][0 ].ate , 4 , delta = 1e-10 )
99
+ causal_test_case = causal_test_results [ self . base_test_case ]
100
+ self .assertAlmostEqual (causal_test_case [ 'LinearRegressionEstimator' ][0 ].ate , 4 , delta = 1e-10 )
103
101
104
- def test_execute_test_suite_multiple_test_case (self ):
102
+ def test_execute_test_suite_multiple_base_test_cases (self ):
105
103
pass
106
104
107
105
def test_execute_test_suite_multiple_estimators (self ):
108
106
pass
107
+
108
+ def create_causal_test_engine (self ):
109
+
110
+ causal_specification = CausalSpecification (self .scenario , self .causal_dag )
111
+
112
+ data_collector = ObservationalDataCollector (self .scenario , self .observational_data_csv_path )
113
+ causal_test_engine = CausalTestEngine (causal_specification , data_collector )
114
+ return causal_test_engine
0 commit comments