Skip to content

Commit 45391c3

Browse files
test_get_formulae unittest and fix bug with covariates
1 parent f3608d1 commit 45391c3

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

causal_testing/testing/estimators.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,12 @@ def get_terms_from_formula(self) -> tuple[str, str, list[str]]:
142142
rhs_terms.add(term.factors[0].code)
143143
if self.treatment not in rhs_terms:
144144
raise ValueError(f"Treatment variable '{self.treatment}' not found in formula")
145-
covariates = rhs_terms.remove(self.treatment)
145+
rhs_terms.remove(self.treatment)
146+
covariates = rhs_terms
146147
if covariates is None:
147148
covariates = []
149+
else:
150+
covariates = list(covariates)
148151
return outcome, self.treatment, covariates
149152

150153
def validate_formula(self, causal_dag: CausalDAG):

tests/testing_tests/test_estimators.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
CausalForestEstimator,
88
LogisticRegressionEstimator,
99
InstrumentalVariableEstimator,
10+
RegressionEstimator,
1011
)
1112
from causal_testing.specification.variable import Input
1213
from causal_testing.utils.validation import CausalValidator
@@ -124,15 +125,15 @@ def test_ate_adjustment(self):
124125
logistic_regression_estimator = LogisticRegressionEstimator(
125126
"length_in", 65, 55, {"large_gauge"}, "completed", df
126127
)
127-
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0})
128+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
128129
self.assertEqual(round(ate, 4), -0.3388)
129130

130131
def test_ate_invalid_adjustment(self):
131132
df = self.scarf_df.copy()
132133
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133134
with self.assertRaises(ValueError):
134135
ate, _ = logistic_regression_estimator.estimate_ate(
135-
adjustment_config = {"large_gauge": 0}
136+
adjustment_config={"large_gauge": 0}
136137
)
137138

138139
def test_ate_effect_modifiers(self):
@@ -394,7 +395,7 @@ def test_program_15_no_interaction_ate_calculated(self):
394395
# for term_to_square in terms_to_square:
395396

396397
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
397-
adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates}
398+
adjustment_config={k: self.nhefs_df.mean()[k] for k in covariates}
398399
)
399400
self.assertEqual(round(ate, 1), 3.5)
400401
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])
@@ -491,3 +492,21 @@ def test_X1_effect(self):
491492
test_results = lr_model.estimate_ate()
492493
ate = test_results[0]
493494
self.assertAlmostEqual(ate, 2.0)
495+
496+
497+
class TestRegressionEstimator(unittest.TestCase):
498+
"""Test the extended functionality of the TestRegressionEstimator"""
499+
500+
@classmethod
501+
def setUpClass(cls):
502+
class RegressionEstimatorTesting(RegressionEstimator):
503+
def add_modelling_assumptions(self):
504+
pass
505+
506+
cls.regression_estimator = RegressionEstimatorTesting("X", 1, 0, {"Z"}, "Y", formula="Y ~ X + Z")
507+
508+
def test_get_formulae(self):
509+
outcome, treatment, covariates = self.regression_estimator.get_terms_from_formula()
510+
self.assertEqual(outcome, "Y")
511+
self.assertEqual(treatment, "X")
512+
self.assertEqual(covariates, ["Z"])

0 commit comments

Comments
 (0)