@@ -206,31 +206,42 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
206
206
207
207
def test_estimate_params_none (self ):
208
208
"""Check that estimate_params defaults to empty dict when None is passed into the estimator object"""
209
+ estimator = LinearRegressionEstimator (
210
+ base_test_case = self .base_test_case_A_C ,
211
+ adjustment_set = set (),
212
+ control_value = 0 ,
213
+ treatment_value = 1 ,
214
+ formula = "C ~ A + D" ,
215
+ df = self .df ,
216
+ )
209
217
causal_test_case = CausalTestCase (
210
218
base_test_case = self .base_test_case_A_C ,
211
219
expected_causal_effect = self .expected_causal_effect ,
212
220
estimate_params = None ,
213
- estimator = LinearRegressionEstimator (
214
- base_test_case = self .base_test_case_A_C ,
215
- adjustment_set = set (),
216
- control_value = 0 ,
217
- treatment_value = 1 ,
218
- ),
221
+ estimator = estimator ,
222
+ estimate_type = "risk_ratio" ,
219
223
)
220
224
self .assertEqual (causal_test_case .estimate_params , {})
225
+ with self .assertRaises (ValueError ):
226
+ causal_test_case .execute_test ()
221
227
222
228
def test_estimate_params_with_formula (self ):
223
229
"""Ensure estimate params is handled correctly when a formula is passed into the estimator object"""
224
- estimate_params = {"formula" : "C ~ A + D" }
230
+ estimate_params = {"adjustment_config" : {"D" : 1 }}
231
+ estimator = LinearRegressionEstimator (
232
+ base_test_case = self .base_test_case_A_C ,
233
+ adjustment_set = set (),
234
+ control_value = 0 ,
235
+ treatment_value = 1 ,
236
+ formula = "C ~ A + D" ,
237
+ df = self .df ,
238
+ )
225
239
causal_test_case = CausalTestCase (
226
240
base_test_case = self .base_test_case_A_C ,
227
241
expected_causal_effect = self .expected_causal_effect ,
228
242
estimate_params = estimate_params ,
229
- estimator = LinearRegressionEstimator (
230
- base_test_case = self .base_test_case_A_C ,
231
- adjustment_set = set (),
232
- control_value = 0 ,
233
- treatment_value = 1 ,
234
- ),
243
+ estimate_type = "risk_ratio" ,
244
+ estimator = estimator ,
235
245
)
236
- self .assertEqual (causal_test_case .estimate_params , estimate_params )
246
+ self .assertEqual (causal_test_case .estimate_params , estimate_params )
247
+ self .assertEqual (round (causal_test_case .execute_test ().test_value .value [0 ], 3 ), 1.444 )
0 commit comments