Skip to content

Commit 92092b7

Browse files
committed
add: unit tests for the estimate_params parameter
1 parent 3c8d7e1 commit 92092b7

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/testing_tests/test_causal_test_case.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,34 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
203203
)
204204
causal_test_result = self.causal_test_case.execute_test(estimation_model)
205205
pd.testing.assert_series_equal(causal_test_result.test_value.value, pd.Series(4.0), atol=1)
206+
207+
def test_estimate_params_none(self):
208+
"""Check that estimate_params defaults to empty dict when None is passed into the estimator object"""
209+
causal_test_case = CausalTestCase(
210+
base_test_case=self.base_test_case_A_C,
211+
expected_causal_effect=self.expected_causal_effect,
212+
estimate_params=None,
213+
estimator=LinearRegressionEstimator(
214+
base_test_case=self.base_test_case_A_C,
215+
adjustment_set=set(),
216+
control_value=0,
217+
treatment_value=1,
218+
),
219+
)
220+
self.assertEqual(causal_test_case.estimate_params, {})
221+
222+
def test_estimate_params_with_formula(self):
223+
"""Ensure estimate params is handled correctly when a formula is passed into the estimator object"""
224+
estimate_params = {"formula": "C ~ A + D"}
225+
causal_test_case = CausalTestCase(
226+
base_test_case=self.base_test_case_A_C,
227+
expected_causal_effect=self.expected_causal_effect,
228+
estimate_params=estimate_params,
229+
estimator=LinearRegressionEstimator(
230+
base_test_case=self.base_test_case_A_C,
231+
adjustment_set=set(),
232+
control_value=0,
233+
treatment_value=1,
234+
),
235+
)
236+
self.assertEqual(causal_test_case.estimate_params, estimate_params)

0 commit comments

Comments
 (0)