@@ -108,7 +108,7 @@ def setUp(self) -> None:
108
108
# Obsolete?
109
109
self .data_collector = ObservationalDataCollector (self .scenario , df )
110
110
self .data_collector .collect_data ()
111
- self .df = self .data_collector .data
111
+ self .df = self .data_collector .collect_data ()
112
112
self .minimal_adjustment_set = self .causal_dag .identification (self .base_test_case )
113
113
# 6. Easier to access treatment and outcome values
114
114
self .treatment_value = 1
@@ -130,7 +130,7 @@ def test_execute_test_observational_causal_forest_estimator(self):
130
130
"C" ,
131
131
self .df ,
132
132
)
133
- causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self . causal_specification )
133
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector )
134
134
self .assertAlmostEqual (causal_test_result .test_value .value , 4 , delta = 1 )
135
135
136
136
def test_invalid_causal_effect (self ):
@@ -152,7 +152,7 @@ def test_execute_test_observational_linear_regression_estimator(self):
152
152
"C" ,
153
153
self .df ,
154
154
)
155
- causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self . causal_specification )
155
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector )
156
156
self .assertAlmostEqual (causal_test_result .test_value .value , 4 , delta = 1e-10 )
157
157
158
158
def test_execute_test_observational_linear_regression_estimator_direct_effect (self ):
@@ -179,7 +179,7 @@ def test_execute_test_observational_linear_regression_estimator_direct_effect(se
179
179
"C" ,
180
180
self .df ,
181
181
)
182
- causal_test_result = causal_test_case .execute_test (estimation_model , self .data_collector , self . causal_specification )
182
+ causal_test_result = causal_test_case .execute_test (estimation_model , self .data_collector )
183
183
self .assertAlmostEqual (causal_test_result .test_value .value , 4 , delta = 1e-10 )
184
184
185
185
def test_execute_test_observational_linear_regression_estimator_coefficient (self ):
@@ -194,7 +194,7 @@ def test_execute_test_observational_linear_regression_estimator_coefficient(self
194
194
self .df ,
195
195
)
196
196
self .causal_test_case .estimate_type = "coefficient"
197
- causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self . causal_specification )
197
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector )
198
198
self .assertEqual (int (causal_test_result .test_value .value ), 0 )
199
199
200
200
def test_execute_test_observational_linear_regression_estimator_risk_ratio (self ):
@@ -209,7 +209,7 @@ def test_execute_test_observational_linear_regression_estimator_risk_ratio(self)
209
209
self .df ,
210
210
)
211
211
self .causal_test_case .estimate_type = "risk_ratio"
212
- causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self . causal_specification )
212
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector )
213
213
self .assertEqual (int (causal_test_result .test_value .value ), 0 )
214
214
215
215
def test_invalid_estimate_type (self ):
@@ -225,7 +225,7 @@ def test_invalid_estimate_type(self):
225
225
)
226
226
self .causal_test_case .estimate_type = "invalid"
227
227
with self .assertRaises (AttributeError ):
228
- self .causal_test_case .execute_test (estimation_model , self .data_collector , self . causal_specification )
228
+ self .causal_test_case .execute_test (estimation_model , self .data_collector )
229
229
230
230
def test_execute_test_observational_linear_regression_estimator_squared_term (self ):
231
231
"""Check that executing the causal test case returns the correct results for dummy data with a squared term
@@ -239,7 +239,7 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
239
239
self .df ,
240
240
formula = f"C ~ A + { '+' .join (self .minimal_adjustment_set )} + (D ** 2)" ,
241
241
)
242
- causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self . causal_specification )
242
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector )
243
243
self .assertAlmostEqual (round (causal_test_result .test_value .value , 1 ), 4 , delta = 1 )
244
244
245
245
def test_execute_observational_causal_forest_estimator_cates (self ):
@@ -262,7 +262,7 @@ def test_execute_observational_causal_forest_estimator_cates(self):
262
262
effect_modifiers = {"M" : None },
263
263
)
264
264
self .causal_test_case .estimate_type = "cates"
265
- causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector , self . causal_specification )
265
+ causal_test_result = self .causal_test_case .execute_test (estimation_model , self .data_collector )
266
266
causal_test_result = causal_test_result .test_value .value
267
267
# Check that each effect modifier's strata has a greater ATE than the last (ascending order)
268
268
causal_test_result_m1 = causal_test_result .loc [causal_test_result ["M" ] == 1 ]
0 commit comments