@@ -109,6 +109,13 @@ def __init__(
109
109
terms = [treatment ] + sorted (list (adjustment_set )) + sorted (list (effect_modifiers ))
110
110
self .formula = f"{ outcome } ~ { '+' .join (terms )} "
111
111
112
+ @abstractmethod
113
+ def add_modelling_assumptions (self ):
114
+ """
115
+ Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
116
+ must hold if the resulting causal inference is to be considered valid.
117
+ """
118
+
112
119
def get_terms_from_formula (self ):
113
120
desc = ModelDesc .from_formula (self .formula )
114
121
if len (desc .lhs_termlist > 1 ):
@@ -131,7 +138,7 @@ def validate_formula(self, causal_dag: CausalDAG):
131
138
covariates = list (covariates ))
132
139
133
140
134
- class LogisticRegressionEstimator (Estimator ):
141
+ class LogisticRegressionEstimator (RegressionEstimator ):
135
142
"""A Logistic Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
136
143
combination of parameters and functions of the variables (note these functions need not be linear). It is designed
137
144
for estimating categorical outcomes.
@@ -149,16 +156,11 @@ def __init__(
149
156
effect_modifiers : dict [str :Any ] = None ,
150
157
formula : str = None ,
151
158
):
152
- super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers )
159
+ super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers ,
160
+ formula )
153
161
154
162
self .model = None
155
163
156
- if formula is not None :
157
- self .formula = formula
158
- else :
159
- terms = [treatment ] + sorted (list (adjustment_set )) + sorted (list (self .effect_modifiers ))
160
- self .formula = f"{ outcome } ~ { '+' .join (((terms )))} "
161
-
162
164
def add_modelling_assumptions (self ):
163
165
"""
164
166
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
@@ -322,7 +324,7 @@ def estimate_unit_odds_ratio(self) -> float:
322
324
return np .exp (model .params [self .treatment ])
323
325
324
326
325
- class LinearRegressionEstimator (Estimator ):
327
+ class LinearRegressionEstimator (RegressionEstimator ):
326
328
"""A Linear Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
327
329
combination of parameters and functions of the variables (note these functions need not be linear).
328
330
"""
@@ -342,18 +344,11 @@ def __init__(
342
344
343
345
):
344
346
super ().__init__ (
345
- treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , alpha = alpha
347
+ treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , alpha = alpha ,
348
+ formula = formula
346
349
)
347
350
348
351
self .model = None
349
- if effect_modifiers is None :
350
- effect_modifiers = []
351
-
352
- if formula is not None :
353
- self .formula = formula
354
- else :
355
- terms = [treatment ] + sorted (list (adjustment_set )) + sorted (list (effect_modifiers ))
356
- self .formula = f"{ outcome } ~ { '+' .join (terms )} "
357
352
358
353
for term in self .effect_modifiers :
359
354
self .adjustment_set .add (term )
0 commit comments