Skip to content

Commit 558d3e2

Browse files
Make use of class inheritance to reduce code in inits
1 parent 9b0c478 commit 558d3e2

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

causal_testing/testing/estimators.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ def __init__(
109109
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
110110
self.formula = f"{outcome} ~ {'+'.join(terms)}"
111111

112+
@abstractmethod
113+
def add_modelling_assumptions(self):
114+
"""
115+
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
116+
must hold if the resulting causal inference is to be considered valid.
117+
"""
118+
112119
def get_terms_from_formula(self):
113120
desc = ModelDesc.from_formula(self.formula)
114121
if len(desc.lhs_termlist > 1):
@@ -131,7 +138,7 @@ def validate_formula(self, causal_dag: CausalDAG):
131138
covariates=list(covariates))
132139

133140

134-
class LogisticRegressionEstimator(Estimator):
141+
class LogisticRegressionEstimator(RegressionEstimator):
135142
"""A Logistic Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
136143
combination of parameters and functions of the variables (note these functions need not be linear). It is designed
137144
for estimating categorical outcomes.
@@ -149,16 +156,11 @@ def __init__(
149156
effect_modifiers: dict[str:Any] = None,
150157
formula: str = None,
151158
):
152-
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
159+
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers,
160+
formula)
153161

154162
self.model = None
155163

156-
if formula is not None:
157-
self.formula = formula
158-
else:
159-
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(self.effect_modifiers))
160-
self.formula = f"{outcome} ~ {'+'.join(((terms)))}"
161-
162164
def add_modelling_assumptions(self):
163165
"""
164166
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
@@ -322,7 +324,7 @@ def estimate_unit_odds_ratio(self) -> float:
322324
return np.exp(model.params[self.treatment])
323325

324326

325-
class LinearRegressionEstimator(Estimator):
327+
class LinearRegressionEstimator(RegressionEstimator):
326328
"""A Linear Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
327329
combination of parameters and functions of the variables (note these functions need not be linear).
328330
"""
@@ -342,18 +344,11 @@ def __init__(
342344

343345
):
344346
super().__init__(
345-
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha
347+
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha,
348+
formula=formula
346349
)
347350

348351
self.model = None
349-
if effect_modifiers is None:
350-
effect_modifiers = []
351-
352-
if formula is not None:
353-
self.formula = formula
354-
else:
355-
terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
356-
self.formula = f"{outcome} ~ {'+'.join(terms)}"
357352

358353
for term in self.effect_modifiers:
359354
self.adjustment_set.add(term)

0 commit comments

Comments
 (0)