Skip to content

Commit cd7ae1f

Browse files
committed
Want to merve instrumental-variables branch
1 parent fc70635 commit cd7ae1f

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

causal_testing/testing/estimators.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import numpy as np
88
import pandas as pd
99
import statsmodels.api as sm
10+
import statsmodels.formula.api as smf
1011
from econml.dml import CausalForestDML
12+
1113
from sklearn.ensemble import GradientBoostingRegressor
1214
from statsmodels.regression.linear_model import RegressionResultsWrapper
1315
from statsmodels.tools.sm_exceptions import PerfectSeparationError
@@ -36,11 +38,11 @@ class Estimator(ABC):
3638

3739
def __init__(
3840
self,
39-
treatment: tuple,
41+
treatment: str,
4042
treatment_value: float,
4143
control_value: float,
4244
adjustment_set: set,
43-
outcome: tuple,
45+
outcome: str,
4446
df: pd.DataFrame = None,
4547
effect_modifiers: dict[Variable:Any] = None,
4648
):
@@ -93,11 +95,11 @@ class LogisticRegressionEstimator(Estimator):
9395

9496
def __init__(
9597
self,
96-
treatment: tuple,
98+
treatment: str,
9799
treatment_value: float,
98100
control_value: float,
99101
adjustment_set: set,
100-
outcome: tuple,
102+
outcome: str,
101103
df: pd.DataFrame = None,
102104
effect_modifiers: dict[Variable:Any] = None,
103105
intercept: int = 1,
@@ -292,27 +294,28 @@ class LinearRegressionEstimator(Estimator):
292294

293295
def __init__(
294296
self,
295-
treatment: tuple,
297+
treatment: str,
296298
treatment_value: float,
297299
control_value: float,
298300
adjustment_set: set,
299-
outcome: tuple,
301+
outcome: str,
300302
df: pd.DataFrame = None,
301303
effect_modifiers: dict[Variable:Any] = None,
302-
product_terms: list[tuple[Variable, Variable]] = None,
303-
intercept: int = 1,
304+
formula: str = None
304305
):
305306
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
306307

307-
self.product_terms = []
308-
self.square_terms = []
309-
self.inverse_terms = []
310-
self.intercept = intercept
311308
self.model = None
309+
if effect_modifiers is None:
310+
effect_modifiers = []
311+
312+
if formula is not None:
313+
# TODO: validate it
314+
self.formula = formula
315+
else:
316+
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
317+
self.formula = f"{outcome} ~ {'+'.join(((terms)))} + Intercept"
312318

313-
if product_terms:
314-
for term_a, term_b in product_terms:
315-
self.add_product_term_to_df(term_a, term_b)
316319
for term in self.effect_modifiers:
317320
self.adjustment_set.add(term)
318321

@@ -399,10 +402,10 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
399402
individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)
400403

401404
# This is a temporary hack
402-
for t in self.square_terms:
403-
individuals[t + "^2"] = individuals[t] ** 2
404-
for a, b in self.product_terms:
405-
individuals[f"{a}*{b}"] = individuals[a] * individuals[b]
405+
# for t in self.square_terms:
406+
# individuals[t + "^2"] = individuals[t] ** 2
407+
# for a, b in self.product_terms:
408+
# individuals[f"{a}*{b}"] = individuals[a] * individuals[b]
406409

407410
# It is ABSOLUTELY CRITICAL that these go last, otherwise we can't index
408411
# the effect with "ate = t_test_results.effect[0]"
@@ -429,7 +432,7 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
429432

430433
x = pd.DataFrame()
431434
x[self.treatment[0]] = [self.treatment_value, self.control_value]
432-
x["Intercept"] = self.intercept
435+
x["Intercept"] = 1#self.intercept
433436
for k, v in adjustment_config.items():
434437
x[k] = v
435438
for k, v in self.effect_modifiers.items():
@@ -485,7 +488,7 @@ def estimate_cates(self) -> tuple[float, list[float, float]]:
485488
), f"Must have at least one effect modifier to compute CATE - {self.effect_modifiers}."
486489
x = pd.DataFrame()
487490
x[self.treatment[0]] = [self.treatment_value, self.control_value]
488-
x["Intercept"] = self.intercept
491+
x["Intercept"] = 1#self.intercept
489492
for k, v in self.effect_modifiers.items():
490493
self.adjustment_set.add(k)
491494
x[k] = v
@@ -517,7 +520,7 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
517520
logger.debug(reduced_df[necessary_cols])
518521

519522
# 2. Add intercept
520-
reduced_df["Intercept"] = self.intercept
523+
reduced_df["Intercept"] = 1#self.intercept
521524

522525
# 3. Estimate the unit difference in outcome caused by unit difference in treatment
523526
cols = list(self.treatment)
@@ -529,8 +532,8 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
529532
treatment_and_adjustments_cols = pd.get_dummies(
530533
treatment_and_adjustments_cols, columns=[col], drop_first=True
531534
)
532-
regression = sm.OLS(outcome_col, treatment_and_adjustments_cols)
533-
model = regression.fit()
535+
# model = sm.OLS(outcome_col, treatment_and_adjustments_cols).fit()
536+
model = smf.ols(formula=self.formula, data=self.df).fit()
534537
return model
535538

536539
def _get_confidence_intervals(self, model):

tests/testing_tests/test_estimators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def setUpClass(cls) -> None:
125125
def test_program_11_2(self):
126126
"""Test whether our linear regression implementation produces the same results as program 11.2 (p. 141)."""
127127
df = self.chapter_11_df
128-
linear_regression_estimator = LinearRegressionEstimator(("treatments",), 100, 90, set(), ("outcomes",), df)
128+
linear_regression_estimator = LinearRegressionEstimator("treatments", 100, 90, set(), "outcomes", df)
129129
model = linear_regression_estimator._run_linear_regression()
130130
ate, _ = linear_regression_estimator.estimate_unit_ate()
131131

@@ -348,7 +348,7 @@ def test_X1_effect(self):
348348
"""When we fix the value of X2 to 0, the effect of X1 on Y should become ~2 (because X2 terms are cancelled)."""
349349
x2 = Input("X2", float)
350350
lr_model = LinearRegressionEstimator(
351-
("X1",), 1, 0, {"X2"}, ("Y",), effect_modifiers={x2: 0}, product_terms=[("X1", "X2")], df=self.df
351+
("X1",), 1, 0, {"X2"}, ("Y",), effect_modifiers={x2: 0}, formula="Y ~ X1 + X2 + (X1 * X2)", df=self.df
352352
)
353353
test_results = lr_model.estimate_ate()
354354
ate = test_results[0]

0 commit comments

Comments
 (0)