
Commit a5f2dba

Balandat authored and facebook-github-bot committed
Fix unit test for initialize_q_batch (#169)
Summary: The `test_initialize_q_batch_largeZ` test didn't actually test the new codepath (see code coverage results), since with `n=5` the whole logic gets short-circuited and `X` is returned. This fixes that with a proper test, getting back to full coverage.

Pull Request resolved: #169
Reviewed By: bkarrer
Differential Revision: D15700751
Pulled By: Balandat
fbshipit-source-id: 708bc6309fbc8a42ebd65f2ec5b77b644dfc1065
1 parent: 5d8e818

File tree

1 file changed: +3, -6 lines


test/optim/test_initializers.py

Lines changed: 3 additions & 6 deletions
@@ -90,12 +90,9 @@ def test_initialize_q_batch_largeZ(self, cuda=False):
         for dtype in (torch.float, torch.double):
             # testing large eta*Z
             X = torch.rand(5, 3, 4, device=device, dtype=dtype)
-            Y = torch.rand(5, device=device, dtype=dtype)
-            Ystd = Y.std()
-            Z = (Y - Y.mean()) / Ystd
-            eta = (1e6 / (torch.abs(Z) + 1e-7).min()).item()
-            ics = initialize_q_batch(X=X, Y=Y, n=5, eta=eta)
-            self.assertTrue(torch.equal(X, ics))
+            Y = torch.tensor([-1e12, 0, 0, 0, 1e12], device=device, dtype=dtype)
+            ics = initialize_q_batch(X=X, Y=Y, n=2, eta=100)
+            self.assertEqual(ics.shape[0], 2)
 
     def test_initialize_q_batch_largeZ_cuda(self):
         if torch.cuda.is_available():
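
To make the reasoning in the summary concrete, here is a rough standalone sketch of the two scenarios. It assumes initialize_q_batch is importable from botorch.optim.initializers (the module exercised by this test file; the exact import path and return type may differ across botorch versions), and it mirrors the behavior the summary describes rather than the library's internals.

import torch
from botorch.optim.initializers import initialize_q_batch

# 5 candidate q-batches of shape 3 x 4, with extreme Y values so that
# eta * Z (the standardized acquisition values) becomes very large.
X = torch.rand(5, 3, 4)
Y = torch.tensor([-1e12, 0.0, 0.0, 0.0, 1e12])

# Old test: n equals the number of candidates, so (per the summary) the
# heuristic short-circuits and simply returns X -- the large eta*Z branch
# is never executed.
ics_all = initialize_q_batch(X=X, Y=Y, n=5, eta=100)
assert torch.equal(ics_all, X)

# New test: requesting fewer candidates than provided forces the
# eta-weighted selection, so the large-Z codepath is actually exercised.
ics = initialize_q_batch(X=X, Y=Y, n=2, eta=100)
assert ics.shape[0] == 2

The key difference is that n=2 < 5 prevents the early return, so the standardization of Y and the eta-scaled weighting (where the numerical robustness under test lives) actually run.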
