Skip to content

Commit f0b2dec

Browse files
committed
Stopped storing regression estimator models
1 parent cdef513 commit f0b2dec

File tree

7 files changed

+19
-74
lines changed

7 files changed

+19
-74
lines changed

causal_testing/estimation/abstract_regression_estimator.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@ def __init__(
4545
query=query,
4646
)
4747

48-
self.model = None
4948
if effect_modifiers is None:
5049
effect_modifiers = []
5150
if adjustment_set is None:
@@ -79,15 +78,14 @@ def add_modelling_assumptions(self):
7978
"do not need to be linear."
8079
)
8180

82-
def _run_regression(self, data=None) -> RegressionResultsWrapper:
81+
def fit_model(self, data=None) -> RegressionResultsWrapper:
8382
"""Run logistic regression of the treatment and adjustment set against the outcome and return the model.
8483
8584
:return: The model after fitting to data.
8685
"""
8786
if data is None:
8887
data = self.df
8988
model = self.regressor(formula=self.formula, data=data).fit(disp=0)
90-
self.model = model
9189
return model
9290

9391
def _predict(self, data=None, adjustment_config: dict = None) -> pd.DataFrame:
@@ -102,7 +100,7 @@ def _predict(self, data=None, adjustment_config: dict = None) -> pd.DataFrame:
102100
if adjustment_config is None:
103101
adjustment_config = {}
104102

105-
model = self._run_regression(data)
103+
model = self.fit_model(data)
106104

107105
x = pd.DataFrame(columns=self.df.columns)
108106
x["Intercept"] = 1 # self.intercept

causal_testing/estimation/cubic_spline_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
5959
6060
:return: The average treatment effect.
6161
"""
62-
model = self._run_regression()
62+
model = self.fit_model()
6363

6464
x = {"Intercept": 1, self.base_test_case.treatment_variable.name: self.treatment_value}
6565
if adjustment_config is not None:

causal_testing/estimation/linear_regression_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
9898
9999
:return: The unit average treatment effect and the 95% Wald confidence intervals.
100100
"""
101-
model = self._run_regression()
101+
model = self.fit_model()
102102
newline = "\n"
103103
patsy_md = ModelDesc.from_formula(self.base_test_case.treatment_variable.name)
104104

@@ -129,7 +129,7 @@ def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
129129
130130
:return: The average treatment effect and the 95% Wald confidence intervals.
131131
"""
132-
model = self._run_regression()
132+
model = self.fit_model()
133133

134134
# Create an empty individual for the control and treated
135135
individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)

causal_testing/estimation/logistic_regression_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ def estimate_unit_odds_ratio(self) -> tuple[pd.Series, list[pd.Series, pd.Series
3838
3939
:return: The odds ratio. Confidence intervals are not yet supported.
4040
"""
41-
model = self._run_regression(self.df)
41+
model = self.fit_model(self.df)
4242
ci_low, ci_high = np.exp(model.conf_int(self.alpha).loc[self.base_test_case.treatment_variable.name])
4343
return pd.Series(np.exp(model.params[self.base_test_case.treatment_variable.name])), [
4444
pd.Series(ci_low),

causal_testing/main.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -361,9 +361,6 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
361361
for test_case in current_batch:
362362
try:
363363
batch_results.append(test_case.execute_test())
364-
# Need to remove the model so we don't take up all the memory
365-
# Would be good to profile the execute_test() method a bit further so we don't need to do this
366-
test_case.estimator.model = None
367364
# pylint: disable=broad-exception-caught
368365
except Exception as e:
369366
if not silent:

tests/estimation_tests/test_cubic_spline_estimator.py

Lines changed: 4 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,13 @@
11
import unittest
2-
import pandas as pd
3-
import numpy as np
4-
import matplotlib.pyplot as plt
5-
from causal_testing.specification.variable import Input
6-
from causal_testing.utils.validation import CausalValidator
72

83
from causal_testing.estimation.cubic_spline_estimator import CubicSplineRegressionEstimator
9-
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
10-
11-
from tests.estimation_tests.test_linear_regression_estimator import TestLinearRegressionEstimator
124
from causal_testing.testing.base_test_case import BaseTestCase
135
from causal_testing.specification.variable import Input, Output
146

7+
from tests.estimation_tests.test_linear_regression_estimator import load_chapter_11_df
8+
159

16-
class TestCubicSplineRegressionEstimator(TestLinearRegressionEstimator):
10+
class TestCubicSplineRegressionEstimator(unittest.TestCase):
1711
@classmethod
1812
def setUpClass(cls):
1913
super().setUpClass()
@@ -24,22 +18,14 @@ def test_program_11_3_cublic_spline(self):
2418
Slightly modified as Hernan et al. use linear regression for this example.
2519
"""
2620

27-
df = self.chapter_11_df.copy()
21+
df = load_chapter_11_df()
2822

2923
base_test_case = BaseTestCase(Input("treatments", float), Output("outcomes", float))
3024

3125
cublic_spline_estimator = CubicSplineRegressionEstimator(base_test_case, 1, 0, set(), 3, df)
3226

3327
ate_1 = cublic_spline_estimator.estimate_ate_calculated()
3428

35-
self.assertEqual(
36-
round(
37-
cublic_spline_estimator.model.predict({"Intercept": 1, "treatments": 90}).iloc[0],
38-
1,
39-
),
40-
195.6,
41-
)
42-
4329
cublic_spline_estimator.treatment_value = 2
4430
ate_2 = cublic_spline_estimator.estimate_ate_calculated()
4531

tests/estimation_tests/test_linear_regression_estimator.py

Lines changed: 9 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,11 @@
11
import unittest
22
import pandas as pd
33
import numpy as np
4-
import matplotlib.pyplot as plt
5-
from causal_testing.specification.variable import Input
4+
from causal_testing.specification.variable import Input, Output
65
from causal_testing.utils.validation import CausalValidator
76

87
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
98
from causal_testing.testing.base_test_case import BaseTestCase
10-
from causal_testing.specification.variable import Input, Output
119

1210

1311
def load_nhefs_df():
@@ -77,7 +75,7 @@ def test_linear_regression_categorical_ate(self):
7775
df = self.scarf_df.copy()
7876
base_test_case = BaseTestCase(Input("color", float), Output("completed", float))
7977
logistic_regression_estimator = LinearRegressionEstimator(base_test_case, None, None, set(), df)
80-
ate, confidence = logistic_regression_estimator.estimate_coefficient()
78+
_, confidence = logistic_regression_estimator.estimate_coefficient()
8179
self.assertTrue(all([ci_low < 0 < ci_high for ci_low, ci_high in zip(confidence[0], confidence[1])]))
8280

8381
def test_program_11_2(self):
@@ -86,22 +84,8 @@ def test_program_11_2(self):
8684
linear_regression_estimator = LinearRegressionEstimator(self.base_test_case, None, None, set(), df)
8785
ate, _ = linear_regression_estimator.estimate_coefficient()
8886

89-
self.assertEqual(
90-
round(
91-
linear_regression_estimator.model.params["Intercept"]
92-
+ 90 * linear_regression_estimator.model.params["treatments"],
93-
1,
94-
),
95-
216.9,
96-
)
97-
9887
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
99-
self.assertTrue(
100-
all(
101-
round(linear_regression_estimator.model.params["treatments"], 1) == round(ate_single, 1)
102-
for ate_single in ate
103-
)
104-
)
88+
self.assertTrue(all(round(ate["treatments"], 1) == round(ate_single, 1) for ate_single in ate))
10589

10690
def test_program_11_3(self):
10791
"""Test whether our linear regression implementation produces the same results as program 11.3 (p. 144)."""
@@ -110,23 +94,8 @@ def test_program_11_3(self):
11094
self.base_test_case, None, None, set(), df, formula="outcomes ~ treatments + I(treatments ** 2)"
11195
)
11296
ate, _ = linear_regression_estimator.estimate_coefficient()
113-
print(linear_regression_estimator.model.summary())
114-
self.assertEqual(
115-
round(
116-
linear_regression_estimator.model.params["Intercept"]
117-
+ 90 * linear_regression_estimator.model.params["treatments"]
118-
+ 90 * 90 * linear_regression_estimator.model.params["I(treatments ** 2)"],
119-
1,
120-
),
121-
197.1,
122-
)
12397
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
124-
self.assertTrue(
125-
all(
126-
round(linear_regression_estimator.model.params["treatments"], 3) == round(ate_single, 3)
127-
for ate_single in ate
128-
)
129-
)
98+
self.assertTrue(all(round(ate["treatments"], 3) == round(ate_single, 3) for ate_single in ate))
13099

131100
def test_program_15_1A(self):
132101
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)."""
@@ -161,15 +130,9 @@ def test_program_15_1A(self):
161130
I(smokeyrs ** 2) +
162131
(qsmk * smokeintensity)""",
163132
)
164-
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
165-
# terms_to_product = [("qsmk", "smokeintensity")]
166-
# for term_to_square in terms_to_square:
167-
# for term_a, term_b in terms_to_product:
168-
# linear_regression_estimator.add_product_term_to_df(term_a, term_b)
169133

170-
linear_regression_estimator.estimate_coefficient()
171-
self.assertEqual(round(linear_regression_estimator.model.params["qsmk"], 1), 2.6)
172-
self.assertEqual(round(linear_regression_estimator.model.params["qsmk:smokeintensity"], 2), 0.05)
134+
coefficient, _ = linear_regression_estimator.estimate_coefficient()
135+
self.assertEqual(round(coefficient["qsmk"], 1), 2.6)
173136

174137
def test_program_15_no_interaction(self):
175138
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)
@@ -281,10 +244,11 @@ def test_program_11_2_with_robustness_validation(self):
281244
"""Test whether our linear regression estimator, as used in test_program_11_2 can correctly estimate robustness."""
282245
df = self.chapter_11_df.copy()
283246
linear_regression_estimator = LinearRegressionEstimator(self.base_test_case, 100, 90, set(), df)
284-
linear_regression_estimator.estimate_coefficient()
285247

286248
cv = CausalValidator()
287-
self.assertEqual(round(cv.estimate_robustness(linear_regression_estimator.model)["treatments"], 4), 0.7353)
249+
self.assertEqual(
250+
round(cv.estimate_robustness(linear_regression_estimator.fit_model())["treatments"], 4), 0.7353
251+
)
288252

289253
def test_gp(self):
290254
df = pd.DataFrame()

0 commit comments

Comments
 (0)