7
7
import numpy as np
8
8
import pandas as pd
9
9
import statsmodels .api as sm
10
+ import statsmodels .formula .api as smf
10
11
from econml .dml import CausalForestDML
12
+
11
13
from sklearn .ensemble import GradientBoostingRegressor
12
14
from statsmodels .regression .linear_model import RegressionResultsWrapper
13
15
from statsmodels .tools .sm_exceptions import PerfectSeparationError
@@ -36,11 +38,11 @@ class Estimator(ABC):
36
38
37
39
def __init__ (
38
40
self ,
39
- treatment : tuple ,
41
+ treatment : str ,
40
42
treatment_value : float ,
41
43
control_value : float ,
42
44
adjustment_set : set ,
43
- outcome : tuple ,
45
+ outcome : str ,
44
46
df : pd .DataFrame = None ,
45
47
effect_modifiers : dict [Variable :Any ] = None ,
46
48
):
@@ -93,11 +95,11 @@ class LogisticRegressionEstimator(Estimator):
93
95
94
96
def __init__ (
95
97
self ,
96
- treatment : tuple ,
98
+ treatment : str ,
97
99
treatment_value : float ,
98
100
control_value : float ,
99
101
adjustment_set : set ,
100
- outcome : tuple ,
102
+ outcome : str ,
101
103
df : pd .DataFrame = None ,
102
104
effect_modifiers : dict [Variable :Any ] = None ,
103
105
intercept : int = 1 ,
@@ -292,27 +294,28 @@ class LinearRegressionEstimator(Estimator):
292
294
293
295
def __init__ (
294
296
self ,
295
- treatment : tuple ,
297
+ treatment : str ,
296
298
treatment_value : float ,
297
299
control_value : float ,
298
300
adjustment_set : set ,
299
- outcome : tuple ,
301
+ outcome : str ,
300
302
df : pd .DataFrame = None ,
301
303
effect_modifiers : dict [Variable :Any ] = None ,
302
- product_terms : list [tuple [Variable , Variable ]] = None ,
303
- intercept : int = 1 ,
304
+ formula : str = None
304
305
):
305
306
super ().__init__ (treatment , treatment_value , control_value , adjustment_set , outcome , df , effect_modifiers )
306
307
307
- self .product_terms = []
308
- self .square_terms = []
309
- self .inverse_terms = []
310
- self .intercept = intercept
311
308
self .model = None
309
+ if effect_modifiers is None :
310
+ effect_modifiers = []
311
+
312
+ if formula is not None :
313
+ # TODO: validate it
314
+ self .formula = formula
315
+ else :
316
+ terms = [treatment ] + sorted (list (adjustment_set )) + sorted (list (effect_modifiers ))
317
+ self .formula = f"{ outcome } ~ { '+' .join (((terms )))} + Intercept"
312
318
313
- if product_terms :
314
- for term_a , term_b in product_terms :
315
- self .add_product_term_to_df (term_a , term_b )
316
319
for term in self .effect_modifiers :
317
320
self .adjustment_set .add (term )
318
321
@@ -399,10 +402,10 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
399
402
individuals = pd .DataFrame (1 , index = ["control" , "treated" ], columns = model .params .index )
400
403
401
404
# This is a temporary hack
402
- for t in self .square_terms :
403
- individuals [t + "^2" ] = individuals [t ] ** 2
404
- for a , b in self .product_terms :
405
- individuals [f"{ a } *{ b } " ] = individuals [a ] * individuals [b ]
405
+ # for t in self.square_terms:
406
+ # individuals[t + "^2"] = individuals[t] ** 2
407
+ # for a, b in self.product_terms:
408
+ # individuals[f"{a}*{b}"] = individuals[a] * individuals[b]
406
409
407
410
# It is ABSOLUTELY CRITICAL that these go last, otherwise we can't index
408
411
# the effect with "ate = t_test_results.effect[0]"
@@ -429,7 +432,7 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
429
432
430
433
x = pd .DataFrame ()
431
434
x [self .treatment [0 ]] = [self .treatment_value , self .control_value ]
432
- x ["Intercept" ] = self .intercept
435
+ x ["Intercept" ] = 1 # self.intercept
433
436
for k , v in adjustment_config .items ():
434
437
x [k ] = v
435
438
for k , v in self .effect_modifiers .items ():
@@ -485,7 +488,7 @@ def estimate_cates(self) -> tuple[float, list[float, float]]:
485
488
), f"Must have at least one effect modifier to compute CATE - { self .effect_modifiers } ."
486
489
x = pd .DataFrame ()
487
490
x [self .treatment [0 ]] = [self .treatment_value , self .control_value ]
488
- x ["Intercept" ] = self .intercept
491
+ x ["Intercept" ] = 1 # self.intercept
489
492
for k , v in self .effect_modifiers .items ():
490
493
self .adjustment_set .add (k )
491
494
x [k ] = v
@@ -517,7 +520,7 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
517
520
logger .debug (reduced_df [necessary_cols ])
518
521
519
522
# 2. Add intercept
520
- reduced_df ["Intercept" ] = self .intercept
523
+ reduced_df ["Intercept" ] = 1 # self.intercept
521
524
522
525
# 3. Estimate the unit difference in outcome caused by unit difference in treatment
523
526
cols = list (self .treatment )
@@ -529,8 +532,8 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
529
532
treatment_and_adjustments_cols = pd .get_dummies (
530
533
treatment_and_adjustments_cols , columns = [col ], drop_first = True
531
534
)
532
- regression = sm .OLS (outcome_col , treatment_and_adjustments_cols )
533
- model = regression .fit ()
535
+ # model = sm.OLS(outcome_col, treatment_and_adjustments_cols).fit( )
536
+ model = smf . ols ( formula = self . formula , data = self . df ) .fit ()
534
537
return model
535
538
536
539
def _get_confidence_intervals (self , model ):
0 commit comments