Skip to content

Commit 123d4db

Browse files
tests represent the logic of returning Series better
1 parent 33e7e53 commit 123d4db

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/testing_tests/test_estimators.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_program_11_2(self):
217217
self.assertEqual(round(model.params["Intercept"] + 90 * model.params["treatments"], 1), 216.9)
218218

219219
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
220-
self.assertEqual(round(model.params["treatments"], 1), round(ate[0], 1))
220+
self.assertTrue(all(round(model.params["treatments"], 1) == round(ate_single, 1) for ate_single in ate))
221221

222222
def test_program_11_3(self):
223223
"""Test whether our linear regression implementation produces the same results as program 11.3 (p. 144)."""
@@ -237,7 +237,7 @@ def test_program_11_3(self):
237237
197.1,
238238
)
239239
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
240-
self.assertEqual(round(model.params["treatments"], 3), round(ate[0], 3))
240+
self.assertTrue(all(round(model.params["treatments"], 3) == round(ate_single, 3) for ate_single in ate))
241241

242242
def test_program_15_1A(self):
243243
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)."""
@@ -315,6 +315,7 @@ def test_program_15_no_interaction(self):
315315
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
316316
# for term_to_square in terms_to_square:
317317
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_coefficient()
318+
318319
self.assertEqual(round(ate[0], 1), 3.5)
319320
self.assertEqual([round(ci_low[0], 1), round(ci_high[0], 1)], [2.6, 4.3])
320321

0 commit comments

Comments
 (0)