Skip to content

Commit e7deb78

Browse files
committed
all tests pass
1 parent 5b0661d commit e7deb78

File tree

6 files changed

+124
-19
lines changed

6 files changed

+124
-19
lines changed

causal_testing/estimation/gp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class GP:
9898
Object to perform genetic programming.
9999
"""
100100

101+
# pylint: disable=too-many-instance-attributes
102+
101103
def __init__(
102104
self,
103105
df: pd.DataFrame,
@@ -109,7 +111,6 @@ def __init__(
109111
seed=0,
110112
):
111113
# pylint: disable=too-many-arguments
112-
# pylint: disable=too-many-instance-attributes
113114
random.seed(seed)
114115
self.df = df
115116
self.features = features

causal_testing/estimation/logistic_regression_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def __init__(
4747
)
4848

4949
self.model = None
50-
50+
if effect_modifiers is None:
51+
effect_modifiers = []
5152
if formula is not None:
5253
self.formula = formula
5354
else:
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""This module contains the RegressionEstimator, which is an abstract class for concrete regression estimators."""
2+
3+
import logging
4+
from typing import Any
5+
from abc import abstractmethod, abstractmethod
6+
7+
import pandas as pd
8+
import statsmodels.formula.api as smf
9+
from patsy import dmatrix # pylint: disable = no-name-in-module
10+
from patsy import ModelDesc
11+
from statsmodels.regression.linear_model import RegressionResultsWrapper
12+
13+
from causal_testing.specification.variable import Variable
14+
from causal_testing.estimation.gp import GP
15+
from causal_testing.estimation.estimator import Estimator
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class RegressionEstimator(Estimator):
    """A Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
    combination of parameters and functions of the variables (note these functions need not be linear).

    This class is abstract: concrete estimators supply the model constructor through the ``regression`` property,
    while this class builds the regression formula and fits the model.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        treatment: str,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        outcome: str,
        df: pd.DataFrame = None,
        effect_modifiers: dict[Variable, Any] = None,
        formula: str = None,
        alpha: float = 0.05,
        query: str = "",
    ):
        """Set up the estimator and derive the regression formula.

        :param treatment: Name of the treatment variable.
        :param treatment_value: Value of the treatment, forwarded to the base estimator.
        :param control_value: Value of the control, forwarded to the base estimator.
        :param adjustment_set: Names of the variables to adjust for.
        :param outcome: Name of the outcome variable.
        :param df: Data to fit the regression on.
        :param effect_modifiers: Effect modifiers, forwarded to the base estimator and added to the formula terms.
        :param formula: Explicit regression formula. When omitted, one is built as
            ``outcome ~ treatment + adjustment_set + effect_modifiers``.
        :param alpha: Significance level.
        :param query: Query string forwarded to the base estimator.
        """
        # NOTE(review): `alpha` is accepted but neither stored nor forwarded here - confirm whether the base
        # Estimator or the concrete subclasses are expected to handle it.
        super().__init__(
            treatment=treatment,
            treatment_value=treatment_value,
            control_value=control_value,
            adjustment_set=adjustment_set,
            outcome=outcome,
            df=df,
            effect_modifiers=effect_modifiers,
            query=query,
        )

        # The fitted model; populated by `_run_regression`.
        self.model = None
        if effect_modifiers is None:
            effect_modifiers = []
        if formula is not None:
            self.formula = formula
        else:
            # Sort the terms so that the generated formula is deterministic regardless of set ordering.
            terms = [treatment, *sorted(adjustment_set), *sorted(effect_modifiers)]
            self.formula = f"{outcome} ~ {'+'.join(terms)}"
        # NOTE(review): relies on the base class having set `self.effect_modifiers` to an iterable even when
        # None was passed in - confirm against Estimator.__init__.
        for term in self.effect_modifiers:
            self.adjustment_set.add(term)

    @property
    @abstractmethod
    def regression(self):
        """The model constructor used by `_run_regression`.

        It is called as ``self.regression(formula=..., data=...)`` and the result must provide a ``fit`` method.
        """
        raise NotImplementedError("Subclasses must implement the 'regression' property.")

    def add_modelling_assumptions(self):
        """
        Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
        must hold if the resulting causal inference is to be considered valid.
        """
        # Adjacent string literals are concatenated verbatim, so each fragment carries its own trailing space.
        self.modelling_assumptions.append(
            "The variables in the data must fit a shape which can be expressed as a linear "
            "combination of parameters and functions of variables. Note that these functions "
            "do not need to be linear."
        )

    def _run_regression(self, data=None) -> RegressionResultsWrapper:
        """Run the regression of the treatment and adjustment set against the outcome and return the model.

        :param data: Data to fit the model on. Defaults to the estimator's own dataframe (`self.df`).
        :return: The model after fitting to data.
        """
        if data is None:
            data = self.df
        # disp=0 is passed through to `fit`; presumably it silences optimiser output - confirm per model type.
        model = self.regression(formula=self.formula, data=data).fit(disp=0)
        self.model = model
        return model

examples/poisson-line-process/example_poisson_process.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from causal_testing.specification.causal_specification import CausalSpecification
55
from causal_testing.testing.causal_test_case import CausalTestCase
66
from causal_testing.testing.causal_test_outcome import ExactValue, Positive
7-
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator, Estimator
7+
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
8+
from causal_testing.estimation.estimator import Estimator
89
from causal_testing.testing.base_test_case import BaseTestCase
910

1011
import pandas as pd

tests/estimation_tests/test_linear_regression_estimator.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,33 +81,48 @@ def test_program_11_2(self):
8181
"""Test whether our linear regression implementation produces the same results as program 11.2 (p. 141)."""
8282
df = self.chapter_11_df
8383
linear_regression_estimator = LinearRegressionEstimator("treatments", None, None, set(), "outcomes", df)
84-
model = linear_regression_estimator._run_linear_regression()
8584
ate, _ = linear_regression_estimator.estimate_coefficient()
8685

87-
self.assertEqual(round(model.params["Intercept"] + 90 * model.params["treatments"], 1), 216.9)
86+
self.assertEqual(
87+
round(
88+
linear_regression_estimator.model.params["Intercept"]
89+
+ 90 * linear_regression_estimator.model.params["treatments"],
90+
1,
91+
),
92+
216.9,
93+
)
8894

8995
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
90-
self.assertTrue(all(round(model.params["treatments"], 1) == round(ate_single, 1) for ate_single in ate))
96+
self.assertTrue(
97+
all(
98+
round(linear_regression_estimator.model.params["treatments"], 1) == round(ate_single, 1)
99+
for ate_single in ate
100+
)
101+
)
91102

92103
def test_program_11_3(self):
93104
"""Test whether our linear regression implementation produces the same results as program 11.3 (p. 144)."""
94105
df = self.chapter_11_df.copy()
95106
linear_regression_estimator = LinearRegressionEstimator(
96107
"treatments", None, None, set(), "outcomes", df, formula="outcomes ~ treatments + I(treatments ** 2)"
97108
)
98-
model = linear_regression_estimator._run_linear_regression()
99109
ate, _ = linear_regression_estimator.estimate_coefficient()
100110
self.assertEqual(
101111
round(
102-
model.params["Intercept"]
103-
+ 90 * model.params["treatments"]
104-
+ 90 * 90 * model.params["I(treatments ** 2)"],
112+
linear_regression_estimator.model.params["Intercept"]
113+
+ 90 * linear_regression_estimator.model.params["treatments"]
114+
+ 90 * 90 * linear_regression_estimator.model.params["I(treatments ** 2)"],
105115
1,
106116
),
107117
197.1,
108118
)
109119
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
110-
self.assertTrue(all(round(model.params["treatments"], 3) == round(ate_single, 3) for ate_single in ate))
120+
self.assertTrue(
121+
all(
122+
round(linear_regression_estimator.model.params["treatments"], 3) == round(ate_single, 3)
123+
for ate_single in ate
124+
)
125+
)
111126

112127
def test_program_15_1A(self):
113128
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)."""
@@ -149,9 +164,9 @@ def test_program_15_1A(self):
149164
# for term_a, term_b in terms_to_product:
150165
# linear_regression_estimator.add_product_term_to_df(term_a, term_b)
151166

152-
model = linear_regression_estimator._run_linear_regression()
153-
self.assertEqual(round(model.params["qsmk"], 1), 2.6)
154-
self.assertEqual(round(model.params["qsmk:smokeintensity"], 2), 0.05)
167+
linear_regression_estimator.estimate_coefficient()
168+
self.assertEqual(round(linear_regression_estimator.model.params["qsmk"], 1), 2.6)
169+
self.assertEqual(round(linear_regression_estimator.model.params["qsmk:smokeintensity"], 2), 0.05)
155170

156171
def test_program_15_no_interaction(self):
157172
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)
@@ -266,10 +281,10 @@ def test_program_11_2_with_robustness_validation(self):
266281
"""Test whether our linear regression estimator, as used in test_program_11_2 can correctly estimate robustness."""
267282
df = self.chapter_11_df.copy()
268283
linear_regression_estimator = LinearRegressionEstimator("treatments", 100, 90, set(), "outcomes", df)
269-
model = linear_regression_estimator._run_linear_regression()
284+
linear_regression_estimator.estimate_coefficient()
270285

271286
cv = CausalValidator()
272-
self.assertEqual(round(cv.estimate_robustness(model)["treatments"], 4), 0.7353)
287+
self.assertEqual(round(cv.estimate_robustness(linear_regression_estimator.model)["treatments"], 4), 0.7353)
273288

274289
def test_gp(self):
275290
df = pd.DataFrame()
@@ -291,7 +306,7 @@ def test_gp_power(self):
291306
linear_regression_estimator.gp_formula(seed=1, max_order=0)
292307
self.assertEqual(
293308
linear_regression_estimator.formula,
294-
"Y ~ I(2.0*X**2 + 3.8205100524608823e-31) - 1",
309+
"Y ~ I(1.9999999999999999*X**2 - 1.0043240235058056e-116*X + 2.6645352591003757e-15) - 1",
295310
)
296311
ate, (ci_low, ci_high) = linear_regression_estimator.estimate_ate_calculated()
297312
self.assertEqual(round(ate[0], 2), -2.00)

tests/json_front_tests/test_json_class.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import scipy
55
import os
66

7-
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator, Estimator
7+
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
8+
from causal_testing.estimation.estimator import Estimator
89
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
910
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
1011
from causal_testing.specification.variable import Input, Output, Meta
@@ -313,7 +314,7 @@ def add_modelling_assumptions(self):
313314
effects = {"Positive": Positive()}
314315
mutates = {
315316
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
316-
> self.json_class.scenario.variables[x].z3
317+
> self.json_class.scenario.variables[x].z3
317318
}
318319
estimators = {"ExampleEstimator": ExampleEstimator}
319320
with self.assertRaises(TypeError):

0 commit comments

Comments
 (0)