Skip to content

Commit d18c407

Browse files
committed
codecov
1 parent 868ea5d commit d18c407

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

tests/testing_tests/test_estimators.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,14 @@ def test_ate_adjustment(self):
124124
logistic_regression_estimator = LogisticRegressionEstimator(
125125
"length_in", 65, 55, {"large_gauge"}, "completed", df
126126
)
127-
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0})
127+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
128128
self.assertEqual(round(ate, 4), -0.3388)
129129

130130
def test_ate_invalid_adjustment(self):
131131
df = self.scarf_df.copy()
132132
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
133133
with self.assertRaises(ValueError):
134-
ate, _ = logistic_regression_estimator.estimate_ate(
135-
adjustment_config = {"large_gauge": 0}
136-
)
134+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
137135

138136
def test_ate_effect_modifiers(self):
139137
df = self.scarf_df.copy()
@@ -215,6 +213,13 @@ def setUpClass(cls) -> None:
215213
cls.nhefs_df = load_nhefs_df()
216214
cls.chapter_11_df = load_chapter_11_df()
217215

216+
def test_query(self):
217+
df = self.nhefs_df
218+
linear_regression_estimator = LinearRegressionEstimator(
219+
"treatments", None, None, set(), "outcomes", df, query="sex==1"
220+
)
221+
self.assertTrue(linear_regression_estimator.df.sex.all())
222+
218223
def test_program_11_2(self):
219224
"""Test whether our linear regression implementation produces the same results as program 11.2 (p. 141)."""
220225
df = self.chapter_11_df
@@ -394,7 +399,7 @@ def test_program_15_no_interaction_ate_calculated(self):
394399
# for term_to_square in terms_to_square:
395400

396401
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
397-
adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates}
402+
adjustment_config={k: self.nhefs_df.mean()[k] for k in covariates}
398403
)
399404
self.assertEqual(round(ate, 1), 3.5)
400405
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])

0 commit comments

Comments
 (0)