Skip to content

Commit 32a9a88

Browse files
committed
cast device
1 parent 01b4c6a commit 32a9a88

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
569569
(o1 * grad_proj).sum().backward()
570570
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
571571
scale = grad_ref.abs().mean()
572-
assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.1 * scale)
572+
assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.01 * scale)
573573

574574

575575

0 commit comments

Comments
 (0)