@@ -41,16 +41,16 @@ class Estimator(ABC):
41
41
"""
42
42
43
43
def __init__ (
44
- # pylint: disable=too-many-arguments
45
- self ,
46
- treatment : str ,
47
- treatment_value : float ,
48
- control_value : float ,
49
- adjustment_set : set ,
50
- outcome : str ,
51
- df : pd .DataFrame = None ,
52
- effect_modifiers : dict [str :Any ] = None ,
53
- alpha : float = 0.05 ,
44
+ # pylint: disable=too-many-arguments
45
+ self ,
46
+ treatment : str ,
47
+ treatment_value : float ,
48
+ control_value : float ,
49
+ adjustment_set : set ,
50
+ outcome : str ,
51
+ df : pd .DataFrame = None ,
52
+ effect_modifiers : dict [str :Any ] = None ,
53
+ alpha : float = 0.05 ,
54
54
):
55
55
self .treatment = treatment
56
56
self .treatment_value = treatment_value
@@ -85,25 +85,24 @@ def compute_confidence_intervals(self) -> list[float, float]:
85
85
86
86
87
87
class RegressionEstimator (Estimator ):
88
- """
89
-
90
- """
88
+ """ """
91
89
92
90
def __init__ (
93
- # pylint: disable=too-many-arguments
94
- self ,
95
- treatment : str ,
96
- treatment_value : float ,
97
- control_value : float ,
98
- adjustment_set : set ,
99
- outcome : str ,
100
- df : pd .DataFrame = None ,
101
- effect_modifiers : dict [str :Any ] = None ,
102
- formula : str = None ,
103
- alpha : float = 0.05 ,
91
+ # pylint: disable=too-many-arguments
92
+ self ,
93
+ treatment : str ,
94
+ treatment_value : float ,
95
+ control_value : float ,
96
+ adjustment_set : set ,
97
+ outcome : str ,
98
+ df : pd .DataFrame = None ,
99
+ effect_modifiers : dict [str :Any ] = None ,
100
+ formula : str = None ,
101
+ alpha : float = 0.05 ,
104
102
):
105
- super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers ,
106
- alpha = alpha )
103
+ super ().__init__ (
104
+ treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , alpha = alpha
105
+ )
107
106
108
107
if effect_modifiers is None :
109
108
effect_modifiers = []
@@ -134,14 +133,19 @@ def get_terms_from_formula(self):
134
133
if self .treatment not in rhs_terms :
135
134
raise ValueError (f"Treatment variable '{ self .treatment } ' not found in formula" )
136
135
covariates = rhs_terms .remove (self .treatment )
136
+ if covariates is None :
137
+ covariates = []
137
138
return outcome , self .treatment , covariates
138
139
139
140
def validate_formula (self , causal_dag : CausalDAG ):
140
141
outcome , treatment , covariates = self .get_terms_from_formula ()
141
142
proper_backdoor_graph = causal_dag .get_proper_backdoor_graph (treatments = [treatment ], outcomes = [outcome ])
142
- return CausalDAG .constructive_backdoor_criterion (proper_backdoor_graph = proper_backdoor_graph ,
143
- treatments = [treatment ], outcomes = [outcome ],
144
- covariates = list (covariates ))
143
+ return causal_dag .constructive_backdoor_criterion (
144
+ proper_backdoor_graph = proper_backdoor_graph ,
145
+ treatments = [treatment ],
146
+ outcomes = [outcome ],
147
+ covariates = list (covariates ),
148
+ )
145
149
146
150
147
151
class LogisticRegressionEstimator (RegressionEstimator ):
@@ -151,19 +155,20 @@ class LogisticRegressionEstimator(RegressionEstimator):
151
155
"""
152
156
153
157
def __init__ (
154
- # pylint: disable=too-many-arguments
155
- self ,
156
- treatment : str ,
157
- treatment_value : float ,
158
- control_value : float ,
159
- adjustment_set : set ,
160
- outcome : str ,
161
- df : pd .DataFrame = None ,
162
- effect_modifiers : dict [str :Any ] = None ,
163
- formula : str = None ,
158
+ # pylint: disable=too-many-arguments
159
+ self ,
160
+ treatment : str ,
161
+ treatment_value : float ,
162
+ control_value : float ,
163
+ adjustment_set : set ,
164
+ outcome : str ,
165
+ df : pd .DataFrame = None ,
166
+ effect_modifiers : dict [str :Any ] = None ,
167
+ formula : str = None ,
164
168
):
165
- super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers ,
166
- formula )
169
+ super ().__init__ (
170
+ treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , formula
171
+ )
167
172
168
173
self .model = None
169
174
@@ -218,7 +223,7 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
218
223
return model .predict (x )
219
224
220
225
def estimate_control_treatment (
221
- self , adjustment_config : dict = None , bootstrap_size : int = 100
226
+ self , adjustment_config : dict = None , bootstrap_size : int = 100
222
227
) -> tuple [pd .Series , pd .Series ]:
223
228
"""Estimate the outcomes under control and treatment.
224
229
@@ -336,23 +341,28 @@ class LinearRegressionEstimator(RegressionEstimator):
336
341
"""
337
342
338
343
def __init__ (
339
- # pylint: disable=too-many-arguments
340
- self ,
341
- treatment : str ,
342
- treatment_value : float ,
343
- control_value : float ,
344
- adjustment_set : set ,
345
- outcome : str ,
346
- df : pd .DataFrame = None ,
347
- effect_modifiers : dict [Variable :Any ] = None ,
348
- formula : str = None ,
349
- alpha : float = 0.05 ,
350
-
344
+ # pylint: disable=too-many-arguments
345
+ self ,
346
+ treatment : str ,
347
+ treatment_value : float ,
348
+ control_value : float ,
349
+ adjustment_set : set ,
350
+ outcome : str ,
351
+ df : pd .DataFrame = None ,
352
+ effect_modifiers : dict [Variable :Any ] = None ,
353
+ formula : str = None ,
354
+ alpha : float = 0.05 ,
351
355
):
352
-
353
356
super ().__init__ (
354
- treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , alpha = alpha ,
355
- formula = formula
357
+ treatment ,
358
+ treatment_value ,
359
+ control_value ,
360
+ adjustment_set ,
361
+ outcome ,
362
+ df ,
363
+ effect_modifiers ,
364
+ alpha = alpha ,
365
+ formula = formula ,
356
366
)
357
367
358
368
self .model = None
@@ -497,17 +507,17 @@ class InstrumentalVariableEstimator(Estimator):
497
507
"""
498
508
499
509
def __init__ (
500
- # pylint: disable=too-many-arguments
501
- self ,
502
- treatment : str ,
503
- treatment_value : float ,
504
- control_value : float ,
505
- adjustment_set : set ,
506
- outcome : str ,
507
- instrument : str ,
508
- df : pd .DataFrame = None ,
509
- intercept : int = 1 ,
510
- effect_modifiers : dict = None , # Not used (yet?). Needed for compatibility
510
+ # pylint: disable=too-many-arguments
511
+ self ,
512
+ treatment : str ,
513
+ treatment_value : float ,
514
+ control_value : float ,
515
+ adjustment_set : set ,
516
+ outcome : str ,
517
+ instrument : str ,
518
+ df : pd .DataFrame = None ,
519
+ intercept : int = 1 ,
520
+ effect_modifiers : dict = None , # Not used (yet?). Needed for compatibility
511
521
):
512
522
super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , None )
513
523
self .intercept = intercept
0 commit comments