Skip to content

Commit cff3a71

Browse files
committed
cast device
1 parent 32a9a88 commit cff3a71

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
569569
(o1 * grad_proj).sum().backward()
570570
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
571571
scale = grad_ref.abs().mean()
572-
assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.01 * scale)
572+
assert torch.allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
573573

574574

575575

0 commit comments

Comments
 (0)