12
12
from causal_testing .testing .causal_test_outcome import ExactValue
13
13
from causal_testing .testing .estimators import CausalForestEstimator , LinearRegressionEstimator
14
14
from causal_testing .testing .base_test_case import BaseTestCase
15
-
15
+ from causal_testing . testing . causal_test_result import CausalTestResult , TestValue
16
16
17
17
class TestCausalTestEngineObservational (unittest .TestCase ):
18
18
"""Test the CausalTestEngine workflow using observational data.
@@ -60,27 +60,27 @@ def setUp(self) -> None:
60
60
# 5. Create observational data collector
61
61
# Obsolete?
62
62
self .data_collector = ObservationalDataCollector (self .scenario , df )
63
-
63
+ self . df = self . data_collector . collect_data ()
64
64
# 5. Create causal test engine
65
- self .causal_test_engine = CausalTestEngine (self .causal_specification , self .data_collector )
65
+ # self.causal_test_engine = CausalTestEngine(self.causal_specification, self.data_collector)
66
66
self .minimal_adjustment_set = self .causal_dag .identification (self .base_test_case )
67
67
# 6. Easier to access treatment and outcome values
68
68
self .treatment_value = 1
69
69
self .control_value = 0
70
70
71
71
def test_positivity_violation_throws_exception (self ):
72
- causal_test_engine = CausalTestEngine ( self .causal_specification , self . data_collector )
73
- causal_test_engine . scenario_execution_data_df .drop ("A" , axis = 1 , inplace = True )
72
+ data_collector = self .data_collector
73
+ data_collector . data .drop ("A" , axis = 1 , inplace = True )
74
74
estimation_model = LinearRegressionEstimator (
75
75
"A" ,
76
76
self .treatment_value ,
77
77
self .control_value ,
78
78
self .minimal_adjustment_set ,
79
79
"C" ,
80
- self .causal_test_engine . scenario_execution_data_df ,
80
+ self .df
81
81
)
82
- with self .assertRaises (Exception ):
83
- causal_test_engine . execute_test (estimation_model )
82
+ with self .assertRaises (ValueError ):
83
+ self . causal_test_case . execute_test (estimation_model , self . data_collector , self . causal_specification )
84
84
85
85
def test_check_no_positivity_violation (self ):
86
86
"""Check that no positivity violation is identified when there is no positivity violation."""
@@ -124,7 +124,7 @@ def test_execute_test_observational_causal_forest_estimator(self):
124
124
"C" ,
125
125
self .causal_test_engine .scenario_execution_data_df ,
126
126
)
127
- causal_test_result = self .causal_test_engine .execute_test (estimation_model , self .causal_test_case )
127
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .causal_test_case )
128
128
self .assertAlmostEqual (causal_test_result .test_value .value , 4 , delta = 1 )
129
129
130
130
def test_invalid_causal_effect (self ):
@@ -146,7 +146,7 @@ def test_execute_test_observational_linear_regression_estimator(self):
146
146
"C" ,
147
147
self .causal_test_engine .scenario_execution_data_df ,
148
148
)
149
- causal_test_result = self .causal_test_engine .execute_test (estimation_model , self .causal_test_case )
149
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .causal_test_case )
150
150
self .assertAlmostEqual (causal_test_result .test_value .value , 4 , delta = 1e-10 )
151
151
152
152
def test_execute_test_observational_linear_regression_estimator_direct_effect (self ):
@@ -175,7 +175,7 @@ def test_execute_test_observational_linear_regression_estimator_direct_effect(se
175
175
"C" ,
176
176
causal_test_engine .scenario_execution_data_df ,
177
177
)
178
- causal_test_result = causal_test_engine .execute_test (estimation_model , causal_test_case )
178
+ causal_test_result = causal_test_case .execute_test (estimation_model , causal_test_case )
179
179
self .assertAlmostEqual (causal_test_result .test_value .value , 4 , delta = 1e-10 )
180
180
181
181
def test_execute_test_observational_linear_regression_estimator_coefficient (self ):
@@ -190,7 +190,7 @@ def test_execute_test_observational_linear_regression_estimator_coefficient(self
190
190
self .causal_test_engine .scenario_execution_data_df ,
191
191
)
192
192
self .causal_test_case .estimate_type = "coefficient"
193
- causal_test_result = self .causal_test_engine .execute_test (estimation_model , self .causal_test_case )
193
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .causal_test_case )
194
194
self .assertEqual (int (causal_test_result .test_value .value ), 0 )
195
195
196
196
def test_execute_test_observational_linear_regression_estimator_risk_ratio (self ):
@@ -205,7 +205,7 @@ def test_execute_test_observational_linear_regression_estimator_risk_ratio(self)
205
205
self .causal_test_engine .scenario_execution_data_df ,
206
206
)
207
207
self .causal_test_case .estimate_type = "risk_ratio"
208
- causal_test_result = self .causal_test_engine .execute_test (estimation_model , self .causal_test_case )
208
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .causal_test_case )
209
209
self .assertEqual (int (causal_test_result .test_value .value ), 0 )
210
210
211
211
def test_invalid_estimate_type (self ):
@@ -221,7 +221,7 @@ def test_invalid_estimate_type(self):
221
221
)
222
222
self .causal_test_case .estimate_type = "invalid"
223
223
with self .assertRaises (AttributeError ):
224
- self .causal_test_engine .execute_test (estimation_model , self .causal_test_case )
224
+ self .causal_test_case .execute_test (estimation_model , self .causal_test_case )
225
225
226
226
def test_execute_test_observational_linear_regression_estimator_squared_term (self ):
227
227
"""Check that executing the causal test case returns the correct results for dummy data with a squared term
@@ -235,7 +235,7 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
235
235
self .causal_test_engine .scenario_execution_data_df ,
236
236
formula = f"C ~ A + { '+' .join (self .minimal_adjustment_set )} + (D ** 2)" ,
237
237
)
238
- causal_test_result = self .causal_test_engine .execute_test (estimation_model , self .causal_test_case )
238
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .causal_test_case )
239
239
self .assertAlmostEqual (round (causal_test_result .test_value .value , 1 ), 4 , delta = 1 )
240
240
241
241
def test_execute_observational_causal_forest_estimator_cates (self ):
@@ -258,7 +258,7 @@ def test_execute_observational_causal_forest_estimator_cates(self):
258
258
effect_modifiers = {"M" : None },
259
259
)
260
260
self .causal_test_case .estimate_type = "cates"
261
- causal_test_result = self .causal_test_engine .execute_test (estimation_model , self .causal_test_case )
261
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .causal_test_case )
262
262
causal_test_result = causal_test_result .test_value .value
263
263
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)
264
264
causal_test_result_m1 = causal_test_result .loc [causal_test_result ["M" ] == 1 ]
0 commit comments