10
10
import statsmodels .api as sm
11
11
import statsmodels .formula .api as smf
12
12
from econml .dml import CausalForestDML
13
- from patsy import dmatrix
13
+ from patsy import dmatrix # pylint: disable = no-name-in-module
14
14
15
15
from sklearn .ensemble import GradientBoostingRegressor
16
16
from statsmodels .regression .linear_model import RegressionResultsWrapper
@@ -50,21 +50,25 @@ def __init__(
50
50
df : pd .DataFrame = None ,
51
51
effect_modifiers : dict [str :Any ] = None ,
52
52
alpha : float = 0.05 ,
53
+ query : str = "" ,
53
54
):
54
55
self .treatment = treatment
55
56
self .treatment_value = treatment_value
56
57
self .control_value = control_value
57
58
self .adjustment_set = adjustment_set
58
59
self .outcome = outcome
59
- self .df = df
60
60
self .alpha = alpha
61
+ self .df = df .query (query ) if query else df
62
+
61
63
if effect_modifiers is None :
62
64
self .effect_modifiers = {}
63
65
elif isinstance (effect_modifiers , dict ):
64
66
self .effect_modifiers = effect_modifiers
65
67
else :
66
68
raise ValueError (f"Unsupported type for effect_modifiers { effect_modifiers } . Expected iterable" )
67
69
self .modelling_assumptions = []
70
+ if query :
71
+ self .modelling_assumptions .append (query )
68
72
self .add_modelling_assumptions ()
69
73
logger .debug ("Effect Modifiers: %s" , self .effect_modifiers )
70
74
@@ -100,8 +104,18 @@ def __init__(
100
104
df : pd .DataFrame = None ,
101
105
effect_modifiers : dict [str :Any ] = None ,
102
106
formula : str = None ,
107
+ query : str = "" ,
103
108
):
104
- super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers )
109
+ super ().__init__ (
110
+ treatment = treatment ,
111
+ treatment_value = treatment_value ,
112
+ control_value = control_value ,
113
+ adjustment_set = adjustment_set ,
114
+ outcome = outcome ,
115
+ df = df ,
116
+ effect_modifiers = effect_modifiers ,
117
+ query = query ,
118
+ )
105
119
106
120
self .model = None
107
121
@@ -116,13 +130,13 @@ def add_modelling_assumptions(self):
116
130
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
117
131
must hold if the resulting causal inference is to be considered valid.
118
132
"""
119
- self .modelling_assumptions += (
133
+ self .modelling_assumptions . append (
120
134
"The variables in the data must fit a shape which can be expressed as a linear"
121
135
"combination of parameters and functions of variables. Note that these functions"
122
136
"do not need to be linear."
123
137
)
124
- self .modelling_assumptions += "The outcome must be binary."
125
- self .modelling_assumptions += "Independently and identically distributed errors."
138
+ self .modelling_assumptions . append ( "The outcome must be binary." )
139
+ self .modelling_assumptions . append ( "Independently and identically distributed errors." )
126
140
127
141
def _run_logistic_regression (self , data ) -> RegressionResultsWrapper :
128
142
"""Run logistic regression of the treatment and adjustment set against the outcome and return the model.
@@ -291,9 +305,18 @@ def __init__(
291
305
effect_modifiers : dict [Variable :Any ] = None ,
292
306
formula : str = None ,
293
307
alpha : float = 0.05 ,
308
+ query : str = "" ,
294
309
):
295
310
super ().__init__ (
296
- treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers , alpha = alpha
311
+ treatment ,
312
+ treatment_value ,
313
+ control_value ,
314
+ adjustment_set ,
315
+ outcome ,
316
+ df ,
317
+ effect_modifiers ,
318
+ alpha = alpha ,
319
+ query = query ,
297
320
)
298
321
299
322
self .model = None
@@ -314,7 +337,7 @@ def add_modelling_assumptions(self):
314
337
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
315
338
must hold if the resulting causal inference is to be considered valid.
316
339
"""
317
- self .modelling_assumptions += (
340
+ self .modelling_assumptions . append (
318
341
"The variables in the data must fit a shape which can be expressed as a linear"
319
342
"combination of parameters and functions of variables. Note that these functions"
320
343
"do not need to be linear."
@@ -509,8 +532,20 @@ def __init__(
509
532
df : pd .DataFrame = None ,
510
533
intercept : int = 1 ,
511
534
effect_modifiers : dict = None , # Not used (yet?). Needed for compatibility
535
+ alpha : float = 0.05 ,
536
+ query : str = "" ,
512
537
):
513
- super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , None )
538
+ super ().__init__ (
539
+ treatment = treatment ,
540
+ treatment_value = treatment_value ,
541
+ control_value = control_value ,
542
+ adjustment_set = adjustment_set ,
543
+ outcome = outcome ,
544
+ df = df ,
545
+ effect_modifiers = None ,
546
+ alpha = alpha ,
547
+ query = query ,
548
+ )
514
549
self .intercept = intercept
515
550
self .model = None
516
551
self .instrument = instrument
@@ -520,13 +555,17 @@ def add_modelling_assumptions(self):
520
555
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
521
556
must hold if the resulting causal inference is to be considered valid.
522
557
"""
523
- self .modelling_assumptions += """The instrument and the treatment, and the treatment and the outcome must be
558
+ self .modelling_assumptions .append (
559
+ """The instrument and the treatment, and the treatment and the outcome must be
524
560
related linearly in the form Y = aX + b."""
525
- self .modelling_assumptions += """The three IV conditions must hold
561
+ )
562
+ self .modelling_assumptions .append (
563
+ """The three IV conditions must hold
526
564
(i) Instrument is associated with treatment
527
565
(ii) Instrument does not affect outcome except through its potential effect on treatment
528
566
(iii) Instrument and outcome do not share causes
529
567
"""
568
+ )
530
569
531
570
def estimate_iv_coefficient (self , df ):
532
571
"""
@@ -569,7 +608,7 @@ def add_modelling_assumptions(self):
569
608
570
609
:return self: Update self.modelling_assumptions
571
610
"""
572
- self .modelling_assumptions += "Non-parametric estimator: no restrictions imposed on the data."
611
+ self .modelling_assumptions . append ( "Non-parametric estimator: no restrictions imposed on the data." )
573
612
574
613
def estimate_ate (self ) -> float :
575
614
"""Estimate the average treatment effect.
0 commit comments