11
11
from causal_testing .testing .causal_test_engine import CausalTestEngine
12
12
from causal_testing .testing .causal_test_outcome import ExactValue
13
13
from causal_testing .testing .estimators import CausalForestEstimator , LinearRegressionEstimator
14
-
14
+ from causal_testing . testing . base_test_case import BaseTestCase
15
15
16
16
class TestCausalTestEngineObservational (unittest .TestCase ):
17
17
""" Test the CausalTestEngine workflow using observational data.
@@ -41,11 +41,12 @@ def setUp(self) -> None:
41
41
42
42
# 3. Create a causal test case
43
43
self .expected_causal_effect = ExactValue (4 )
44
+ self .base_test_case = BaseTestCase (A , C )
44
45
self .causal_test_case = CausalTestCase (
45
- control_input_configuration = { A : 0 } ,
46
+ base_test_case = self . base_test_case ,
46
47
expected_causal_effect = self .expected_causal_effect ,
47
- treatment_input_configuration = { A : 1 } ,
48
- outcome_variables = { C } )
48
+ control_value = 0 ,
49
+ treatment_value = 1 )
49
50
50
51
# 4. Create dummy test data and write to csv
51
52
np .random .seed (1 )
@@ -64,9 +65,7 @@ def setUp(self) -> None:
64
65
self .causal_specification ,
65
66
self .data_collector
66
67
)
67
-
68
- self .causal_test_engine .identification (self .causal_test_case )
69
- self .minimal_adjustment_set = self .causal_test_engine .minimal_adjustment_set
68
+ self .minimal_adjustment_set = self .causal_dag .identification (self .base_test_case )
70
69
# 6. Easier to access treatment and outcome values
71
70
self .treatment_value = 1
72
71
self .control_value = 0
@@ -114,9 +113,8 @@ def test_check_positivity_violation_missing_outcome(self):
114
113
115
114
def test_check_minimum_adjustment_set (self ):
116
115
""" Check that the minimum adjustment set is correctly made"""
117
- self .causal_test_engine .identification (self .causal_test_case )
118
- minimum_adjustment_set = self .causal_test_engine .minimal_adjustment_set
119
- self .assertEqual (minimum_adjustment_set , {'D' })
116
+ minimal_adjustment_set = self .causal_dag .identification (self .base_test_case )
117
+ self .assertEqual (minimal_adjustment_set , {'D' })
120
118
121
119
def test_execute_test_observational_causal_forest_estimator (self ):
122
120
""" Check that executing the causal test case returns the correct results for the dummy data using a causal
@@ -134,19 +132,14 @@ def test_execute_test_observational_causal_forest_estimator(self):
134
132
def test_invalid_causal_effect (self ):
135
133
""" Check that executing the causal test case returns the correct results for dummy data using a linear
136
134
regression estimator. """
137
- causal_test_case = CausalTestCase (
138
- control_input_configuration = {self .A : 0 },
139
- expected_causal_effect = self .expected_causal_effect ,
140
- treatment_input_configuration = {self .A : 1 },
141
- outcome_variables = {self .C },
142
- effect = "error" )
143
- # 5. Create causal test engine
144
- causal_test_engine = CausalTestEngine (
145
- self .causal_specification ,
146
- self .data_collector
135
+ base_test_case = BaseTestCase (
136
+ treatment_variable = self .A ,
137
+ outcome_variable = self .C ,
138
+ effect = "error"
147
139
)
140
+
148
141
with self .assertRaises (Exception ):
149
- causal_test_engine . identification ()
142
+ self . causal_dag . identification (base_test_case )
150
143
151
144
152
145
def test_execute_test_observational_linear_regression_estimator (self ):
@@ -165,27 +158,30 @@ def test_execute_test_observational_linear_regression_estimator(self):
165
158
def test_execute_test_observational_linear_regression_estimator_direct_effect (self ):
166
159
""" Check that executing the causal test case returns the correct results for dummy data using a linear
167
160
regression estimator. """
161
+ base_test_case = BaseTestCase (
162
+ treatment_variable = self .A ,
163
+ outcome_variable = self .C ,
164
+ effect = "direct" )
165
+
168
166
causal_test_case = CausalTestCase (
169
- control_input_configuration = { self . A : 0 } ,
167
+ base_test_case = base_test_case ,
170
168
expected_causal_effect = self .expected_causal_effect ,
171
- treatment_input_configuration = {self .A : 1 },
172
- outcome_variables = {self .C },
173
- effect = "direct" )
169
+ control_value = 0 ,
170
+ treatment_value = 1 )
174
171
175
172
# 5. Create causal test engine
176
173
causal_test_engine = CausalTestEngine (
177
174
self .causal_specification ,
178
175
self .data_collector
179
176
)
180
- causal_test_engine .identification (causal_test_case )
181
- self .minimal_adjustment_set = causal_test_engine .minimal_adjustment_set
177
+ minimal_adjustment_set = self .causal_dag .identification (base_test_case )
182
178
# 6. Easier to access treatment and outcome values
183
179
self .treatment_value = 1
184
180
self .control_value = 0
185
181
estimation_model = LinearRegressionEstimator (('A' ,),
186
182
self .treatment_value ,
187
183
self .control_value ,
188
- self . minimal_adjustment_set ,
184
+ minimal_adjustment_set ,
189
185
('C' ,),
190
186
causal_test_engine .scenario_execution_data_df )
191
187
causal_test_result = causal_test_engine .execute_test (estimation_model , causal_test_case )
0 commit comments