Skip to content

Commit a2abf7b

Browse files
Get terms from formula method
1 parent 979e005 commit a2abf7b

File tree

1 file changed

+84
-44
lines changed

1 file changed

+84
-44
lines changed

causal_testing/testing/estimators.py

Lines changed: 84 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import statsmodels.api as sm
1111
import statsmodels.formula.api as smf
1212
from econml.dml import CausalForestDML
13-
from patsy import dmatrix
13+
from patsy import dmatrix, ModelDesc
1414

1515
from sklearn.ensemble import GradientBoostingRegressor
1616
from statsmodels.regression.linear_model import RegressionResultsWrapper
@@ -40,16 +40,16 @@ class Estimator(ABC):
4040
"""
4141

4242
def __init__(
43-
# pylint: disable=too-many-arguments
44-
self,
45-
treatment: str,
46-
treatment_value: float,
47-
control_value: float,
48-
adjustment_set: set,
49-
outcome: str,
50-
df: pd.DataFrame = None,
51-
effect_modifiers: dict[str:Any] = None,
52-
alpha: float = 0.05,
43+
# pylint: disable=too-many-arguments
44+
self,
45+
treatment: str,
46+
treatment_value: float,
47+
control_value: float,
48+
adjustment_set: set,
49+
outcome: str,
50+
df: pd.DataFrame = None,
51+
effect_modifiers: dict[str:Any] = None,
52+
alpha: float = 0.05,
5353
):
5454
self.treatment = treatment
5555
self.treatment_value = treatment_value
@@ -83,23 +83,62 @@ def compute_confidence_intervals(self) -> list[float, float]:
8383
"""
8484

8585

86+
class RegressionEstimator(Estimator):
    """Base class for estimators whose model is described by a patsy-style formula.

    The formula is either supplied by the caller or built as
    ``outcome ~ treatment + <sorted adjustment set> + <sorted effect modifiers>``.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        treatment: str,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        outcome: str,
        df: pd.DataFrame = None,
        effect_modifiers: dict[str:Any] = None,
        formula: str = None,
    ):
        """Store the estimator configuration and settle on a regression formula.

        :param treatment: Name of the treatment variable.
        :param treatment_value: Value of the treatment variable under treatment.
        :param control_value: Value of the treatment variable under control.
        :param adjustment_set: Names of the variables to adjust for.
        :param outcome: Name of the outcome variable.
        :param df: Data to fit the model to (forwarded to the base class).
        :param effect_modifiers: Mapping whose keys are added as formula terms.
        :param formula: Explicit formula; when given it is used verbatim.
        """
        super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)

        if formula is not None:
            self.formula = formula
        else:
            # Both collections may be left unset (None); treat that as "no extra terms".
            # Sorting keeps the generated formula deterministic regardless of set order.
            terms = [treatment] + sorted(adjustment_set or ()) + sorted(effect_modifiers or ())
            self.formula = f"{outcome} ~ {'+'.join(terms)}"

    def get_terms_from_formula(self):
        """Split ``self.formula`` into its outcome, treatment and covariate names.

        :return: A tuple ``(outcome, treatment, covariates)`` where ``covariates`` is the
                 set of right-hand-side term names excluding the treatment.
        :raises ValueError: If the formula does not have exactly one left-hand-side term,
                            or if the treatment does not appear on the right-hand side.
        """
        desc = ModelDesc.from_formula(self.formula)
        if len(desc.lhs_termlist) > 1:
            raise ValueError("More than 1 left hand side term provided in formula, only single term is accepted")
        if not desc.lhs_termlist or not desc.lhs_termlist[0].factors:
            raise ValueError("No left hand side term provided in formula, a single outcome term is required")
        outcome = desc.lhs_termlist[0].factors[0].code

        # Terms with no factors (how patsy represents the intercept) carry no variable name.
        # NOTE(review): only the first factor of each term is recorded, so an interaction
        # such as ``a:b`` contributes just ``a`` - confirm this is the intended behaviour.
        rhs_terms = {term.factors[0].code for term in desc.rhs_termlist if term.factors}
        if self.treatment not in rhs_terms:
            raise ValueError(f"Treatment variable '{self.treatment}' not found in formula")

        # set.remove() returns None, so build the covariate set explicitly.
        covariates = rhs_terms - {self.treatment}
        return outcome, self.treatment, covariates
124+
86125
class LogisticRegressionEstimator(Estimator):
87126
"""A Logistic Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
88127
combination of parameters and functions of the variables (note these functions need not be linear). It is designed
89128
for estimating categorical outcomes.
90129
"""
91130

92131
def __init__(
93-
# pylint: disable=too-many-arguments
94-
self,
95-
treatment: str,
96-
treatment_value: float,
97-
control_value: float,
98-
adjustment_set: set,
99-
outcome: str,
100-
df: pd.DataFrame = None,
101-
effect_modifiers: dict[str:Any] = None,
102-
formula: str = None,
132+
# pylint: disable=too-many-arguments
133+
self,
134+
treatment: str,
135+
treatment_value: float,
136+
control_value: float,
137+
adjustment_set: set,
138+
outcome: str,
139+
df: pd.DataFrame = None,
140+
effect_modifiers: dict[str:Any] = None,
141+
formula: str = None,
103142
):
104143
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
105144

@@ -162,7 +201,7 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
162201
return model.predict(x)
163202

164203
def estimate_control_treatment(
165-
self, adjustment_config: dict = None, bootstrap_size: int = 100
204+
self, adjustment_config: dict = None, bootstrap_size: int = 100
166205
) -> tuple[pd.Series, pd.Series]:
167206
"""Estimate the outcomes under control and treatment.
168207
@@ -280,17 +319,18 @@ class LinearRegressionEstimator(Estimator):
280319
"""
281320

282321
def __init__(
283-
# pylint: disable=too-many-arguments
284-
self,
285-
treatment: str,
286-
treatment_value: float,
287-
control_value: float,
288-
adjustment_set: set,
289-
outcome: str,
290-
df: pd.DataFrame = None,
291-
effect_modifiers: dict[Variable:Any] = None,
292-
formula: str = None,
293-
alpha: float = 0.05,
322+
# pylint: disable=too-many-arguments
323+
self,
324+
treatment: str,
325+
treatment_value: float,
326+
control_value: float,
327+
adjustment_set: set,
328+
outcome: str,
329+
df: pd.DataFrame = None,
330+
effect_modifiers: dict[Variable:Any] = None,
331+
formula: str = None,
332+
alpha: float = 0.05,
333+
294334
):
295335
super().__init__(
296336
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha
@@ -446,17 +486,17 @@ class InstrumentalVariableEstimator(Estimator):
446486
"""
447487

448488
def __init__(
449-
# pylint: disable=too-many-arguments
450-
self,
451-
treatment: str,
452-
treatment_value: float,
453-
control_value: float,
454-
adjustment_set: set,
455-
outcome: str,
456-
instrument: str,
457-
df: pd.DataFrame = None,
458-
intercept: int = 1,
459-
effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
489+
# pylint: disable=too-many-arguments
490+
self,
491+
treatment: str,
492+
treatment_value: float,
493+
control_value: float,
494+
adjustment_set: set,
495+
outcome: str,
496+
instrument: str,
497+
df: pd.DataFrame = None,
498+
intercept: int = 1,
499+
effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
460500
):
461501
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, None)
462502
self.intercept = intercept

0 commit comments

Comments
 (0)