@@ -124,16 +124,14 @@ def test_ate_adjustment(self):
124
124
logistic_regression_estimator = LogisticRegressionEstimator (
125
125
"length_in" , 65 , 55 , {"large_gauge" }, "completed" , df
126
126
)
127
- ate , _ = logistic_regression_estimator .estimate_ate (adjustment_config = {"large_gauge" : 0 })
127
+ ate , _ = logistic_regression_estimator .estimate_ate (adjustment_config = {"large_gauge" : 0 })
128
128
self .assertEqual (round (ate , 4 ), - 0.3388 )
129
129
130
130
def test_ate_invalid_adjustment (self ):
131
131
df = self .scarf_df .copy ()
132
132
logistic_regression_estimator = LogisticRegressionEstimator ("length_in" , 65 , 55 , {}, "completed" , df )
133
133
with self .assertRaises (ValueError ):
134
- ate , _ = logistic_regression_estimator .estimate_ate (
135
- adjustment_config = {"large_gauge" : 0 }
136
- )
134
+ ate , _ = logistic_regression_estimator .estimate_ate (adjustment_config = {"large_gauge" : 0 })
137
135
138
136
def test_ate_effect_modifiers (self ):
139
137
df = self .scarf_df .copy ()
@@ -215,6 +213,13 @@ def setUpClass(cls) -> None:
215
213
cls .nhefs_df = load_nhefs_df ()
216
214
cls .chapter_11_df = load_chapter_11_df ()
217
215
216
+ def test_query (self ):
217
+ df = self .nhefs_df
218
+ linear_regression_estimator = LinearRegressionEstimator (
219
+ "treatments" , None , None , set (), "outcomes" , df , query = "sex==1"
220
+ )
221
+ self .assertTrue (linear_regression_estimator .df .sex .all ())
222
+
218
223
def test_program_11_2 (self ):
219
224
"""Test whether our linear regression implementation produces the same results as program 11.2 (p. 141)."""
220
225
df = self .chapter_11_df
@@ -394,7 +399,7 @@ def test_program_15_no_interaction_ate_calculated(self):
394
399
# for term_to_square in terms_to_square:
395
400
396
401
ate , [ci_low , ci_high ] = linear_regression_estimator .estimate_ate_calculated (
397
- adjustment_config = {k : self .nhefs_df .mean ()[k ] for k in covariates }
402
+ adjustment_config = {k : self .nhefs_df .mean ()[k ] for k in covariates }
398
403
)
399
404
self .assertEqual (round (ate , 1 ), 3.5 )
400
405
self .assertEqual ([round (ci_low , 1 ), round (ci_high , 1 )], [1.9 , 5 ])
0 commit comments