@@ -76,19 +76,21 @@ class TestLogisticRegressionEstimator(unittest.TestCase):
76
76
77
77
@classmethod
78
78
def setUpClass (cls ) -> None :
79
- cls .scarf_df = pd .DataFrame ([
80
- { 'length_in' : 55 , 'large_gauge' : 1 , 'color' : 'orange' , 'completed' : 1 },
81
- { 'length_in' : 55 , 'large_gauge' : 0 , 'color' : 'orange' , 'completed' : 1 },
82
- { 'length_in' : 55 , 'large_gauge' : 0 , 'color' : 'brown' , 'completed' : 1 },
83
- { 'length_in' : 60 , 'large_gauge' : 0 , 'color' : 'brown' , 'completed' : 1 },
84
- { 'length_in' : 60 , 'large_gauge' : 0 , 'color' : 'grey' , 'completed' : 0 },
85
- { 'length_in' : 70 , 'large_gauge' : 0 , 'color' : 'grey' , 'completed' : 1 },
86
- { 'length_in' : 70 , 'large_gauge' : 0 , 'color' : 'orange' , 'completed' : 0 },
87
- { 'length_in' : 82 , 'large_gauge' : 1 , 'color' : 'grey' , 'completed' : 1 },
88
- { 'length_in' : 82 , 'large_gauge' : 0 , 'color' : 'brown' , 'completed' : 0 },
89
- { 'length_in' : 82 , 'large_gauge' : 0 , 'color' : 'orange' , 'completed' : 0 },
90
- { 'length_in' : 82 , 'large_gauge' : 1 , 'color' : 'brown' , 'completed' : 0 },
91
- ])
79
+ cls .scarf_df = pd .DataFrame (
80
+ [
81
+ {"length_in" : 55 , "large_gauge" : 1 , "color" : "orange" , "completed" : 1 },
82
+ {"length_in" : 55 , "large_gauge" : 0 , "color" : "orange" , "completed" : 1 },
83
+ {"length_in" : 55 , "large_gauge" : 0 , "color" : "brown" , "completed" : 1 },
84
+ {"length_in" : 60 , "large_gauge" : 0 , "color" : "brown" , "completed" : 1 },
85
+ {"length_in" : 60 , "large_gauge" : 0 , "color" : "grey" , "completed" : 0 },
86
+ {"length_in" : 70 , "large_gauge" : 0 , "color" : "grey" , "completed" : 1 },
87
+ {"length_in" : 70 , "large_gauge" : 0 , "color" : "orange" , "completed" : 0 },
88
+ {"length_in" : 82 , "large_gauge" : 1 , "color" : "grey" , "completed" : 1 },
89
+ {"length_in" : 82 , "large_gauge" : 0 , "color" : "brown" , "completed" : 0 },
90
+ {"length_in" : 82 , "large_gauge" : 0 , "color" : "orange" , "completed" : 0 },
91
+ {"length_in" : 82 , "large_gauge" : 1 , "color" : "brown" , "completed" : 0 },
92
+ ]
93
+ )
92
94
93
95
def test_ate (self ):
94
96
df = self .scarf_df .copy ()
@@ -110,7 +112,9 @@ def test_odds_ratio(self):
110
112
111
113
def test_ate_effect_modifiers (self ):
112
114
df = self .scarf_df .copy ()
113
- logistic_regression_estimator = LogisticRegressionEstimator ("length_in" , 65 , 55 , set (), "completed" , df , effect_modifiers = {"large_gauge" : 0 })
115
+ logistic_regression_estimator = LogisticRegressionEstimator (
116
+ "length_in" , 65 , 55 , set (), "completed" , df , effect_modifiers = {"large_gauge" : 0 }
117
+ )
114
118
ate , _ = logistic_regression_estimator .estimate_ate ()
115
119
self .assertEqual (round (ate , 4 ), - 0.3388 )
116
120
@@ -371,9 +375,7 @@ def test_program_15_ate(self):
371
375
"smokeintensity" ,
372
376
"smokeyrs" ,
373
377
}
374
- causal_forest = CausalForestEstimator (
375
- "qsmk" , 1 , 0 , covariates , "wt82_71" , df , {"smokeintensity" : 40 }
376
- )
378
+ causal_forest = CausalForestEstimator ("qsmk" , 1 , 0 , covariates , "wt82_71" , df , {"smokeintensity" : 40 })
377
379
ate , _ = causal_forest .estimate_ate ()
378
380
self .assertGreater (round (ate , 1 ), 2.5 )
379
381
self .assertLess (round (ate , 1 ), 4.5 )
0 commit comments