10
10
import statsmodels .api as sm
11
11
import statsmodels .formula .api as smf
12
12
from econml .dml import CausalForestDML
13
- from patsy import dmatrix
13
+ from patsy import dmatrix , ModelDesc
14
14
15
15
from sklearn .ensemble import GradientBoostingRegressor
16
16
from statsmodels .regression .linear_model import RegressionResultsWrapper
17
17
from statsmodels .tools .sm_exceptions import PerfectSeparationError
18
18
19
19
from causal_testing .specification .variable import Variable
20
+ from causal_testing .specification .causal_dag import CausalDAG
20
21
21
22
logger = logging .getLogger (__name__ )
22
23
@@ -83,10 +84,10 @@ def compute_confidence_intervals(self) -> list[float, float]:
83
84
"""
84
85
85
86
86
- class LogisticRegressionEstimator (Estimator ):
87
- """A Logistic Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
88
- combination of parameters and functions of the variables (note these functions need not be linear). It is designed
89
- for estimating categorical outcomes.
87
+ class RegressionEstimator (Estimator ):
88
+ """An abstract class extending the Estimator functionality to add support for formulae, which are used in
89
+ regression based estimators.
90
+
90
91
"""
91
92
92
93
def __init__ (
@@ -100,16 +101,97 @@ def __init__(
100
101
df : pd .DataFrame = None ,
101
102
effect_modifiers : dict [str :Any ] = None ,
102
103
formula : str = None ,
104
+ alpha : float = 0.05 ,
103
105
):
104
- super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers )
106
+ super ().__init__ (
107
+ treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , alpha = alpha
108
+ )
105
109
106
- self .model = None
110
+ if effect_modifiers is None :
111
+ effect_modifiers = []
107
112
108
113
if formula is not None :
109
114
self .formula = formula
115
+
116
+ else :
117
+ terms = [treatment ] + sorted (list (adjustment_set )) + sorted (list (effect_modifiers ))
118
+ self .formula = f"{ outcome } ~ { '+' .join (terms )} "
119
+
120
@abstractmethod
def add_modelling_assumptions(self):
    """
    Record the modelling assumptions made by this estimator.

    Concrete regression estimators must override this method. The assumptions are a list of strings, each naming a
    condition that has to hold for the resulting causal inference to be considered valid.
    """
126
+
127
def get_terms_from_formula(self) -> tuple[str, str, list[str]]:
    """
    Parse the Patsy formula held in ``self.formula`` into outcome, treatment and covariate names.

    The formula must have exactly one left hand side term. Every right hand side term contributes the code of its
    first factor; terms without factors (such as the intercept) are ignored.

    :return: a tuple containing the outcome, the treatment and a sorted list of the covariate names, all as strings
    :raises ValueError: if the formula has more than one left hand side term, has no usable left hand side term, or
                        does not contain the treatment variable on its right hand side
    """
    desc = ModelDesc.from_formula(self.formula)
    if len(desc.lhs_termlist) > 1:
        raise ValueError("More than 1 left hand side term provided in formula, only single term is accepted")
    # Guard the indexing below: a missing or factor-less left hand side would otherwise surface as a bare
    # IndexError rather than as a formula error.
    if not desc.lhs_termlist or not desc.lhs_termlist[0].factors:
        raise ValueError("No outcome variable found on the left hand side of the formula")
    outcome = desc.lhs_termlist[0].factors[0].code

    # NOTE(review): only the first factor of each term is recorded, so for an interaction term the remaining
    # factors are not reported as covariates - confirm this is intended.
    rhs_terms = {term.factors[0].code for term in desc.rhs_termlist if term.factors}
    if self.treatment not in rhs_terms:
        raise ValueError(f"Treatment variable '{self.treatment}' not found in formula")

    # Sorted so that the covariate order does not depend on set iteration order.
    covariates = sorted(rhs_terms - {self.treatment})
    return outcome, self.treatment, covariates
152
+
153
def validate_formula(self, causal_dag: CausalDAG):
    """
    Check the estimator's Patsy formula against a causal DAG using the constructive backdoor criterion
    implemented by the CausalDAG class.

    The outcome, treatment and covariates are taken from ``get_terms_from_formula``.

    :param causal_dag: the CausalDAG object for the current test scenario
    :return: the result of ``causal_dag.constructive_backdoor_criterion``: True for a formula that does not violate
             the criterion and False for one that does
    """
    outcome_name, treatment_name, covariate_names = self.get_terms_from_formula()
    backdoor_graph = causal_dag.get_proper_backdoor_graph(treatments=[treatment_name], outcomes=[outcome_name])
    criterion_arguments = {
        "proper_backdoor_graph": backdoor_graph,
        "treatments": [treatment_name],
        "outcomes": [outcome_name],
        "covariates": list(covariate_names),
    }
    return causal_dag.constructive_backdoor_criterion(**criterion_arguments)
170
+
171
+
172
+ class LogisticRegressionEstimator (RegressionEstimator ):
173
+ """A Logistic Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
174
+ combination of parameters and functions of the variables (note these functions need not be linear). It is designed
175
+ for estimating categorical outcomes.
176
+ """
177
+
178
def __init__(
    # pylint: disable=too-many-arguments
    self,
    treatment: str,
    treatment_value: float,
    control_value: float,
    adjustment_set: set,
    outcome: str,
    df: pd.DataFrame = None,
    effect_modifiers: dict[str, Any] = None,
    formula: str = None,
):
    """
    Initialise the logistic regression estimator.

    All arguments are forwarded unchanged to the parent initialiser; the fitted model is left unset.

    :param treatment: name of the treatment variable
    :param treatment_value: value of the treatment variable in the treated case
    :param control_value: value of the treatment variable in the control case
    :param adjustment_set: names of the variables to adjust for
    :param outcome: name of the outcome variable
    :param df: the data to estimate from
    :param effect_modifiers: effect modifiers, keyed by name
    :param formula: optional Patsy formula, passed through to the parent initialiser
    """
    # The formula is passed by keyword so that it cannot be bound to the wrong parameter if the parent
    # initialiser's optional parameters are reordered.
    super().__init__(
        treatment,
        treatment_value,
        control_value,
        adjustment_set,
        outcome,
        df,
        effect_modifiers,
        formula=formula,
    )

    # No model exists until one is fitted.
    self.model = None
113
195
114
196
def add_modelling_assumptions (self ):
115
197
"""
@@ -274,7 +356,7 @@ def estimate_unit_odds_ratio(self) -> float:
274
356
return np .exp (model .params [self .treatment ])
275
357
276
358
277
- class LinearRegressionEstimator (Estimator ):
359
+ class LinearRegressionEstimator (RegressionEstimator ):
278
360
"""A Linear Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
279
361
combination of parameters and functions of the variables (note these functions need not be linear).
280
362
"""
@@ -293,18 +375,18 @@ def __init__(
293
375
alpha : float = 0.05 ,
294
376
):
295
377
super ().__init__ (
296
- treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , alpha = alpha
378
+ treatment ,
379
+ treatment_value ,
380
+ control_value ,
381
+ adjustment_set ,
382
+ outcome ,
383
+ df ,
384
+ effect_modifiers ,
385
+ alpha = alpha ,
386
+ formula = formula ,
297
387
)
298
388
299
389
self .model = None
300
- if effect_modifiers is None :
301
- effect_modifiers = []
302
-
303
- if formula is not None :
304
- self .formula = formula
305
- else :
306
- terms = [treatment ] + sorted (list (adjustment_set )) + sorted (list (effect_modifiers ))
307
- self .formula = f"{ outcome } ~ { '+' .join (terms )} "
308
390
309
391
for term in self .effect_modifiers :
310
392
self .adjustment_set .add (term )
0 commit comments