10 | 10 | import statsmodels.api as sm
11 | 11 | import statsmodels.formula.api as smf
12 | 12 | from econml.dml import CausalForestDML
13 | | -from patsy import dmatrix
| 13 | +from patsy import dmatrix, ModelDesc
14 | 14 |
15 | 15 | from sklearn.ensemble import GradientBoostingRegressor
16 | 16 | from statsmodels.regression.linear_model import RegressionResultsWrapper
@@ -40,16 +40,16 @@ class Estimator(ABC):
40 | 40 | """
41 | 41 |
42 | 42 | def __init__(
43 | | - # pylint: disable=too-many-arguments
44 | | - self,
45 | | - treatment: str,
46 | | - treatment_value: float,
47 | | - control_value: float,
48 | | - adjustment_set: set,
49 | | - outcome: str,
50 | | - df: pd.DataFrame = None,
51 | | - effect_modifiers: dict[str:Any] = None,
52 | | - alpha: float = 0.05,
| 43 | + # pylint: disable=too-many-arguments
| 44 | + self,
| 45 | + treatment: str,
| 46 | + treatment_value: float,
| 47 | + control_value: float,
| 48 | + adjustment_set: set,
| 49 | + outcome: str,
| 50 | + df: pd.DataFrame = None,
| 51 | + effect_modifiers: dict[str:Any] = None,
| 52 | + alpha: float = 0.05,
53 | 53 | ):
54 | 54 | self.treatment = treatment
55 | 55 | self.treatment_value = treatment_value
@@ -83,23 +83,62 @@ def compute_confidence_intervals(self) -> list[float, float]:
83 | 83 | """
84 | 84 |
85 | 85 |
| 86 | +class RegressionEstimator(Estimator):
| 87 | + """A Regression Estimator fits a regression model, specified by a patsy formula, of the outcome on the
| 88 | + treatment, adjustment set and effect modifiers. If no formula is supplied, one is built from those terms.
| 89 | + """
| 90 | +
| 91 | + def __init__(
| 92 | + # pylint: disable=too-many-arguments
| 93 | + self,
| 94 | + treatment: str,
| 95 | + treatment_value: float,
| 96 | + control_value: float,
| 97 | + adjustment_set: set,
| 98 | + outcome: str,
| 99 | + df: pd.DataFrame = None,
| 100 | + effect_modifiers: dict[str:Any] = None,
| 101 | + formula: str = None,
| 102 | + ):
| 103 | + super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
| 104 | +
| 105 | + if formula is not None:
| 106 | + self.formula = formula
| 107 | + else:
| 108 | + terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers or {}))
| 109 | + self.formula = f"{outcome} ~ {'+'.join(terms)}"
| 110 | +
| 111 | + def get_terms_from_formula(self):
| 112 | + desc = ModelDesc.from_formula(self.formula)
| 113 | + if len(desc.lhs_termlist) > 1:
| 114 | + raise ValueError("More than one left-hand side term provided in formula; only a single term is accepted")
| 115 | + outcome = desc.lhs_termlist[0].factors[0].code
| 116 | + rhs_terms = set()
| 117 | + for term in desc.rhs_termlist:
| 118 | + if term.factors:
| 119 | + rhs_terms.add(term.factors[0].code)
| 120 | + if self.treatment not in rhs_terms:
| 121 | + raise ValueError(f"Treatment variable '{self.treatment}' not found in formula")
| 122 | + covariates = rhs_terms - {self.treatment}
| 123 | + return outcome, self.treatment, covariates
| 124 | +
86 | 125 | class LogisticRegressionEstimator(Estimator):
87 | 126 | """A Logistic Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
88 | 127 | combination of parameters and functions of the variables (note these functions need not be linear). It is designed
89 | 128 | for estimating categorical outcomes.
90 | 129 | """
91 | 130 |
92 | 131 | def __init__(
93 | | - # pylint: disable=too-many-arguments
94 | | - self,
95 | | - treatment: str,
96 | | - treatment_value: float,
97 | | - control_value: float,
98 | | - adjustment_set: set,
99 | | - outcome: str,
100 | | - df: pd.DataFrame = None,
101 | | - effect_modifiers: dict[str:Any] = None,
102 | | - formula: str = None,
| 132 | + # pylint: disable=too-many-arguments
| 133 | + self,
| 134 | + treatment: str,
| 135 | + treatment_value: float,
| 136 | + control_value: float,
| 137 | + adjustment_set: set,
| 138 | + outcome: str,
| 139 | + df: pd.DataFrame = None,
| 140 | + effect_modifiers: dict[str:Any] = None,
| 141 | + formula: str = None,
103 | 142 | ):
104 | 143 | super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
105 | 144 |
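For readers unfamiliar with patsy, the new `get_terms_from_formula` above relies on `ModelDesc.from_formula` to split a formula string back into its outcome, treatment and covariate names. A minimal round-trip sketch of that parsing, using made-up column names, looks like this:

```python
from patsy import ModelDesc

formula = "y ~ x + z1 + z2"  # hypothetical outcome y, treatment x, covariates z1 and z2
desc = ModelDesc.from_formula(formula)

outcome = desc.lhs_termlist[0].factors[0].code  # "y"
rhs_terms = {
    term.factors[0].code          # {"x", "z1", "z2"}
    for term in desc.rhs_termlist
    if term.factors               # the intercept term has no factors, so it is skipped
}
covariates = rhs_terms - {"x"}    # drop the treatment, leaving {"z1", "z2"}
print(outcome, covariates)
```

As in the method above, only the first factor of each right-hand side term is inspected, so an interaction term such as `x:z1` would contribute only its first factor.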
@@ -162,7 +201,7 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
162 | 201 | return model.predict(x)
163 | 202 |
164 | 203 | def estimate_control_treatment(
165 | | - self, adjustment_config: dict = None, bootstrap_size: int = 100
| 204 | + self, adjustment_config: dict = None, bootstrap_size: int = 100
166 | 205 | ) -> tuple[pd.Series, pd.Series]:
167 | 206 | """Estimate the outcomes under control and treatment.
168 | 207 |
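The hunk above only re-indents the `estimate_control_treatment` signature, but the signature itself (a `bootstrap_size` parameter and a pair of returned `pd.Series`) hints at the general shape of the computation: resample the data, refit, and predict at the control and treatment values. The sketch below is a generic illustration of that pattern under an assumed helper name (`fit_model`), not the project's actual implementation.

```python
import pandas as pd

def bootstrap_control_treatment(df, fit_model, treatment, control_value, treatment_value,
                                bootstrap_size=100):
    """Illustrative only: bootstrap model predictions at the control and treatment values."""
    control, treated = [], []
    for i in range(bootstrap_size):
        sample = df.sample(len(df), replace=True, random_state=i)  # resample rows with replacement
        model = fit_model(sample)                                   # refit on the bootstrap sample
        x = sample.copy()
        x[treatment] = control_value
        control.append(model.predict(x).mean())
        x[treatment] = treatment_value
        treated.append(model.predict(x).mean())
    return pd.Series(control), pd.Series(treated)
```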
@@ -280,17 +319,18 @@ class LinearRegressionEstimator(Estimator):
280 | 319 | """
281 | 320 |
282 | 321 | def __init__(
283 | | - # pylint: disable=too-many-arguments
284 | | - self,
285 | | - treatment: str,
286 | | - treatment_value: float,
287 | | - control_value: float,
288 | | - adjustment_set: set,
289 | | - outcome: str,
290 | | - df: pd.DataFrame = None,
291 | | - effect_modifiers: dict[Variable:Any] = None,
292 | | - formula: str = None,
293 | | - alpha: float = 0.05,
| 322 | + # pylint: disable=too-many-arguments
| 323 | + self,
| 324 | + treatment: str,
| 325 | + treatment_value: float,
| 326 | + control_value: float,
| 327 | + adjustment_set: set,
| 328 | + outcome: str,
| 329 | + df: pd.DataFrame = None,
| 330 | + effect_modifiers: dict[Variable:Any] = None,
| 331 | + formula: str = None,
| 332 | + alpha: float = 0.05,
| 333 | +
294 | 334 | ):
295 | 335 | super().__init__(
296 | 336 | treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha
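Given the signature re-indented above, constructing a `LinearRegressionEstimator` with an explicit patsy formula might look like the sketch below. The data frame, column names and import path are assumptions for illustration; the keyword arguments simply mirror the parameters shown in this hunk.

```python
import pandas as pd
# assumed import path for this module
from causal_testing.testing.estimators import LinearRegressionEstimator

df = pd.DataFrame({"x": [0, 0, 1, 1], "z": [1, 0, 1, 0], "y": [0.1, 0.2, 1.1, 0.9]})

estimator = LinearRegressionEstimator(
    treatment="x",
    treatment_value=1,
    control_value=0,
    adjustment_set={"z"},
    outcome="y",
    df=df,
    formula="y ~ x + z",  # optional patsy formula over outcome, treatment and adjustment set
    alpha=0.05,
)
```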
@@ -446,17 +486,17 @@ class InstrumentalVariableEstimator(Estimator):
446 | 486 | """
447 | 487 |
448 | 488 | def __init__(
449 | | - # pylint: disable=too-many-arguments
450 | | - self,
451 | | - treatment: str,
452 | | - treatment_value: float,
453 | | - control_value: float,
454 | | - adjustment_set: set,
455 | | - outcome: str,
456 | | - instrument: str,
457 | | - df: pd.DataFrame = None,
458 | | - intercept: int = 1,
459 | | - effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
| 489 | + # pylint: disable=too-many-arguments
| 490 | + self,
| 491 | + treatment: str,
| 492 | + treatment_value: float,
| 493 | + control_value: float,
| 494 | + adjustment_set: set,
| 495 | + outcome: str,
| 496 | + instrument: str,
| 497 | + df: pd.DataFrame = None,
| 498 | + intercept: int = 1,
| 499 | + effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
460 | 500 | ):
461 | 501 | super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, None)
462 | 502 | self.intercept = intercept
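As background to the instrumental-variable signature above: in the simple linear setting, the causal effect of the treatment on the outcome can be recovered as the ratio of two regressions on the instrument (the Wald estimand), effect = cov(Z, Y) / cov(Z, X). The sketch below illustrates that idea with statsmodels on simulated data; it is a generic illustration of the technique, not this class's implementation.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(0)
n = 10_000
z = rng.binomial(1, 0.5, n)              # instrument
u = rng.normal(size=n)                   # unobserved confounder
x = 0.8 * z + u + rng.normal(size=n)     # treatment, driven by Z and U
y = 2.0 * x + u + rng.normal(size=n)     # outcome; true causal effect of X is 2.0
df = pd.DataFrame({"X": x, "Y": y, "Z": z})

# Wald / ratio estimator: slope of Y on Z divided by slope of X on Z
beta_zy = smf.ols("Y ~ Z", df).fit().params["Z"]
beta_zx = smf.ols("X ~ Z", df).fit().params["Z"]
print(beta_zy / beta_zx)  # close to 2.0; naive OLS of Y on X would be biased by U
```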