|
7 | 7 | CausalForestEstimator,
|
8 | 8 | LogisticRegressionEstimator,
|
9 | 9 | InstrumentalVariableEstimator,
|
| 10 | + RegressionEstimator, |
10 | 11 | )
|
11 | 12 | from causal_testing.specification.variable import Input
|
12 | 13 | from causal_testing.utils.validation import CausalValidator
|
@@ -124,15 +125,15 @@ def test_ate_adjustment(self):
|
124 | 125 | logistic_regression_estimator = LogisticRegressionEstimator(
|
125 | 126 | "length_in", 65, 55, {"large_gauge"}, "completed", df
|
126 | 127 | )
|
127 |
| - ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0}) |
| 128 | + ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0}) |
128 | 129 | self.assertEqual(round(ate, 4), -0.3388)
|
129 | 130 |
|
130 | 131 | def test_ate_invalid_adjustment(self):
|
131 | 132 | df = self.scarf_df.copy()
|
132 | 133 | logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
|
133 | 134 | with self.assertRaises(ValueError):
|
134 | 135 | ate, _ = logistic_regression_estimator.estimate_ate(
|
135 |
| - adjustment_config = {"large_gauge": 0} |
| 136 | + adjustment_config={"large_gauge": 0} |
136 | 137 | )
|
137 | 138 |
|
138 | 139 | def test_ate_effect_modifiers(self):
|
@@ -394,7 +395,7 @@ def test_program_15_no_interaction_ate_calculated(self):
|
394 | 395 | # for term_to_square in terms_to_square:
|
395 | 396 |
|
396 | 397 | ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
|
397 |
| - adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates} |
| 398 | + adjustment_config={k: self.nhefs_df.mean()[k] for k in covariates} |
398 | 399 | )
|
399 | 400 | self.assertEqual(round(ate, 1), 3.5)
|
400 | 401 | self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])
|
@@ -491,3 +492,21 @@ def test_X1_effect(self):
|
491 | 492 | test_results = lr_model.estimate_ate()
|
492 | 493 | ate = test_results[0]
|
493 | 494 | self.assertAlmostEqual(ate, 2.0)
|
| 495 | + |
| 496 | + |
| 497 | +class TestRegressionEstimator(unittest.TestCase): |
| 498 | + """Test the extended functionality of the TestRegressionEstimator""" |
| 499 | + |
| 500 | + @classmethod |
| 501 | + def setUpClass(cls): |
| 502 | + class RegressionEstimatorTesting(RegressionEstimator): |
| 503 | + def add_modelling_assumptions(self): |
| 504 | + pass |
| 505 | + |
| 506 | + cls.regression_estimator = RegressionEstimatorTesting("X", 1, 0, {"Z"}, "Y", formula="Y ~ X + Z") |
| 507 | + |
| 508 | + def test_get_formulae(self): |
| 509 | + outcome, treatment, covariates = self.regression_estimator.get_terms_from_formula() |
| 510 | + self.assertEqual(outcome, "Y") |
| 511 | + self.assertEqual(treatment, "X") |
| 512 | + self.assertEqual(covariates, ["Z"]) |
0 commit comments