10
10
import statsmodels .api as sm
11
11
import statsmodels .formula .api as smf
12
12
from econml .dml import CausalForestDML
13
- from patsy import dmatrix , ModelDesc
13
+ from patsy import dmatrix
14
14
15
15
from sklearn .ensemble import GradientBoostingRegressor
16
16
from statsmodels .regression .linear_model import RegressionResultsWrapper
17
17
from statsmodels .tools .sm_exceptions import PerfectSeparationError
18
18
19
19
from causal_testing .specification .variable import Variable
20
- from causal_testing .specification .causal_dag import CausalDAG
21
20
22
21
logger = logging .getLogger (__name__ )
23
22
@@ -84,92 +83,7 @@ def compute_confidence_intervals(self) -> list[float, float]:
84
83
"""
85
84
86
85
87
- class RegressionEstimator (Estimator ):
88
- """An abstract class extending the Estimator functionality to add support for formulae, which are used in
89
- regression based estimators.
90
-
91
- """
92
-
93
- def __init__ (
94
- # pylint: disable=too-many-arguments
95
- self ,
96
- treatment : str ,
97
- treatment_value : float ,
98
- control_value : float ,
99
- adjustment_set : set ,
100
- outcome : str ,
101
- df : pd .DataFrame = None ,
102
- effect_modifiers : dict [str :Any ] = None ,
103
- formula : str = None ,
104
- alpha : float = 0.05 ,
105
- ):
106
- super ().__init__ (
107
- treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , alpha = alpha
108
- )
109
-
110
- if effect_modifiers is None :
111
- effect_modifiers = []
112
-
113
- if formula is not None :
114
- self .formula = formula
115
-
116
- else :
117
- terms = [treatment ] + sorted (list (adjustment_set )) + sorted (list (effect_modifiers ))
118
- self .formula = f"{ outcome } ~ { '+' .join (terms )} "
119
-
120
- @abstractmethod
121
- def add_modelling_assumptions (self ):
122
- """
123
- Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
124
- must hold if the resulting causal inference is to be considered valid.
125
- """
126
-
127
- def get_terms_from_formula (self ) -> tuple [str , str , list [str ]]:
128
- """
129
- Parse all the terms from a Patsy formula string into outcome, treatment and covariate variables.
130
-
131
- Formulae are expected to only have a single left hand side term.
132
-
133
- :return: a tuple containing the outcome, treatment and covariate variable names in string format
134
- """
135
- desc = ModelDesc .from_formula (self .formula )
136
- if len (desc .lhs_termlist ) > 1 :
137
- raise ValueError ("More than 1 left hand side term provided in formula, only single term is accepted" )
138
- outcome = desc .lhs_termlist [0 ].factors [0 ].code
139
- rhs_terms = set ()
140
- for term in desc .rhs_termlist :
141
- if term .factors :
142
- rhs_terms .add (term .factors [0 ].code )
143
- if self .treatment not in rhs_terms :
144
- raise ValueError (f"Treatment variable '{ self .treatment } ' not found in formula" )
145
- rhs_terms .remove (self .treatment )
146
- covariates = rhs_terms
147
- if covariates is None :
148
- covariates = []
149
- else :
150
- covariates = list (covariates )
151
- return outcome , self .treatment , covariates
152
-
153
- def validate_formula (self , causal_dag : CausalDAG ):
154
- """
155
- Validate the provided Patsy formula string using the constructive backdoor criterion method found in the
156
- CausalDAG class
157
-
158
- :param causal_dag: A CausalDAG object for the current test scenario
159
- :return: True for a formula that does not violate the criteria and False if the formula does violate the
160
- criteria
161
- """
162
- outcome , treatment , covariates = self .get_terms_from_formula ()
163
- proper_backdoor_graph = causal_dag .get_proper_backdoor_graph (treatments = [treatment ], outcomes = [outcome ])
164
- return causal_dag .constructive_backdoor_criterion (
165
- proper_backdoor_graph = proper_backdoor_graph ,
166
- treatments = [treatment ],
167
- outcomes = [outcome ],
168
- covariates = list (covariates ),
169
- )
170
-
171
-
172
- class LogisticRegressionEstimator (RegressionEstimator ):
86
+ class LogisticRegressionEstimator (Estimator ):
173
87
"""A Logistic Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
174
88
combination of parameters and functions of the variables (note these functions need not be linear). It is designed
175
89
for estimating categorical outcomes.
@@ -187,12 +101,16 @@ def __init__(
187
101
effect_modifiers : dict [str :Any ] = None ,
188
102
formula : str = None ,
189
103
):
190
- super ().__init__ (
191
- treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , formula
192
- )
104
+ super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers )
193
105
194
106
self .model = None
195
107
108
+ if formula is not None :
109
+ self .formula = formula
110
+ else :
111
+ terms = [treatment ] + sorted (list (adjustment_set )) + sorted (list (self .effect_modifiers ))
112
+ self .formula = f"{ outcome } ~ { '+' .join (((terms )))} "
113
+
196
114
def add_modelling_assumptions (self ):
197
115
"""
198
116
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
@@ -356,7 +274,7 @@ def estimate_unit_odds_ratio(self) -> float:
356
274
return np .exp (model .params [self .treatment ])
357
275
358
276
359
- class LinearRegressionEstimator (RegressionEstimator ):
277
+ class LinearRegressionEstimator (Estimator ):
360
278
"""A Linear Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
361
279
combination of parameters and functions of the variables (note these functions need not be linear).
362
280
"""
@@ -375,18 +293,18 @@ def __init__(
375
293
alpha : float = 0.05 ,
376
294
):
377
295
super ().__init__ (
378
- treatment ,
379
- treatment_value ,
380
- control_value ,
381
- adjustment_set ,
382
- outcome ,
383
- df ,
384
- effect_modifiers ,
385
- alpha = alpha ,
386
- formula = formula ,
296
+ treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , alpha = alpha
387
297
)
388
298
389
299
self .model = None
300
+ if effect_modifiers is None :
301
+ effect_modifiers = []
302
+
303
+ if formula is not None :
304
+ self .formula = formula
305
+ else :
306
+ terms = [treatment ] + sorted (list (adjustment_set )) + sorted (list (effect_modifiers ))
307
+ self .formula = f"{ outcome } ~ { '+' .join (terms )} "
390
308
391
309
for term in self .effect_modifiers :
392
310
self .adjustment_set .add (term )
0 commit comments