Skip to content

Commit a07825a

Browse files
committed
review
1 parent 9b7d307 commit a07825a

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

tests/test_modules.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -569,12 +569,10 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
569569
(o1 * grad_proj).sum().backward()
570570
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
571571
scale = grad_ref.abs().mean()
572-
assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
573-
574-
575-
576-
577572

573+
torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
574+
idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
575+
assert (idx == 0).sum().item() <= b1.numel() * 0.0
578576

579577

580578
def test_linear8bitlt_fp32_bias():

0 commit comments

Comments
 (0)