@@ -217,7 +217,7 @@ def test_program_11_2(self):
217
217
self .assertEqual (round (model .params ["Intercept" ] + 90 * model .params ["treatments" ], 1 ), 216.9 )
218
218
219
219
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
220
- self .assertEqual ( round (model .params ["treatments" ], 1 ), round (ate [ 0 ] , 1 ))
220
+ self .assertTrue ( all ( round (model .params ["treatments" ], 1 ) == round (ate_single , 1 ) for ate_single in ate ))
221
221
222
222
def test_program_11_3 (self ):
223
223
"""Test whether our linear regression implementation produces the same results as program 11.3 (p. 144)."""
@@ -237,7 +237,7 @@ def test_program_11_3(self):
237
237
197.1 ,
238
238
)
239
239
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
240
- self .assertEqual ( round (model .params ["treatments" ], 3 ), round (ate [ 0 ] , 3 ))
240
+ self .assertTrue ( all ( round (model .params ["treatments" ], 3 ) == round (ate_single , 3 ) for ate_single in ate ))
241
241
242
242
def test_program_15_1A (self ):
243
243
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)."""
@@ -315,6 +315,7 @@ def test_program_15_no_interaction(self):
315
315
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
316
316
# for term_to_square in terms_to_square:
317
317
ate , [ci_low , ci_high ] = linear_regression_estimator .estimate_coefficient ()
318
+
318
319
self .assertEqual (round (ate [0 ], 1 ), 3.5 )
319
320
self .assertEqual ([round (ci_low [0 ], 1 ), round (ci_high [0 ], 1 )], [2.6 , 4.3 ])
320
321
0 commit comments