Skip to content

Commit a940927

Browse files
committed
Added extra assertions that run the test case
1 parent 92092b7 commit a940927

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

tests/testing_tests/test_causal_test_case.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -206,31 +206,42 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
206206

207207
def test_estimate_params_none(self):
208208
"""Check that estimate_params defaults to empty dict when None is passed into the estimator object"""
209+
estimator = LinearRegressionEstimator(
210+
base_test_case=self.base_test_case_A_C,
211+
adjustment_set=set(),
212+
control_value=0,
213+
treatment_value=1,
214+
formula="C ~ A + D",
215+
df=self.df,
216+
)
209217
causal_test_case = CausalTestCase(
210218
base_test_case=self.base_test_case_A_C,
211219
expected_causal_effect=self.expected_causal_effect,
212220
estimate_params=None,
213-
estimator=LinearRegressionEstimator(
214-
base_test_case=self.base_test_case_A_C,
215-
adjustment_set=set(),
216-
control_value=0,
217-
treatment_value=1,
218-
),
221+
estimator=estimator,
222+
estimate_type="risk_ratio",
219223
)
220224
self.assertEqual(causal_test_case.estimate_params, {})
225+
with self.assertRaises(ValueError):
226+
causal_test_case.execute_test()
221227

222228
def test_estimate_params_with_formula(self):
223229
"""Ensure estimate params is handled correctly when a formula is passed into the estimator object"""
224-
estimate_params = {"formula": "C ~ A + D"}
230+
estimate_params = {"adjustment_config": {"D": 1}}
231+
estimator = LinearRegressionEstimator(
232+
base_test_case=self.base_test_case_A_C,
233+
adjustment_set=set(),
234+
control_value=0,
235+
treatment_value=1,
236+
formula="C ~ A + D",
237+
df=self.df,
238+
)
225239
causal_test_case = CausalTestCase(
226240
base_test_case=self.base_test_case_A_C,
227241
expected_causal_effect=self.expected_causal_effect,
228242
estimate_params=estimate_params,
229-
estimator=LinearRegressionEstimator(
230-
base_test_case=self.base_test_case_A_C,
231-
adjustment_set=set(),
232-
control_value=0,
233-
treatment_value=1,
234-
),
243+
estimate_type="risk_ratio",
244+
estimator=estimator,
235245
)
236-
self.assertEqual(causal_test_case.estimate_params, estimate_params)
246+
self.assertEqual(causal_test_case.estimate_params, estimate_params)
247+
self.assertEqual(round(causal_test_case.execute_test().test_value.value[0], 3), 1.444)

0 commit comments

Comments
 (0)