@@ -203,3 +203,34 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
203
203
)
204
204
causal_test_result = self .causal_test_case .execute_test (estimation_model )
205
205
pd .testing .assert_series_equal (causal_test_result .test_value .value , pd .Series (4.0 ), atol = 1 )
206
+
207
+ def test_estimate_params_none (self ):
208
+ """Check that estimate_params defaults to empty dict when None is passed into the estimator object"""
209
+ causal_test_case = CausalTestCase (
210
+ base_test_case = self .base_test_case_A_C ,
211
+ expected_causal_effect = self .expected_causal_effect ,
212
+ estimate_params = None ,
213
+ estimator = LinearRegressionEstimator (
214
+ base_test_case = self .base_test_case_A_C ,
215
+ adjustment_set = set (),
216
+ control_value = 0 ,
217
+ treatment_value = 1 ,
218
+ ),
219
+ )
220
+ self .assertEqual (causal_test_case .estimate_params , {})
221
+
222
+ def test_estimate_params_with_formula (self ):
223
+ """Ensure estimate params is handled correctly when a formula is passed into the estimator object"""
224
+ estimate_params = {"formula" : "C ~ A + D" }
225
+ causal_test_case = CausalTestCase (
226
+ base_test_case = self .base_test_case_A_C ,
227
+ expected_causal_effect = self .expected_causal_effect ,
228
+ estimate_params = estimate_params ,
229
+ estimator = LinearRegressionEstimator (
230
+ base_test_case = self .base_test_case_A_C ,
231
+ adjustment_set = set (),
232
+ control_value = 0 ,
233
+ treatment_value = 1 ,
234
+ ),
235
+ )
236
+ self .assertEqual (causal_test_case .estimate_params , estimate_params )
0 commit comments