Skip to content

Commit e4086a2

Browse files
committed
cast device
1 parent 725cc72 commit e4086a2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
567567

568568
mlp.zero_grad()
569569
(o1 * grad_proj).sum().backward()
570-
grad_ref = grad_proj.flatten(2) @ w2 @ w1
570+
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
571571
assert torch.allclose(b1.grad, grad_ref)
572572

573573

0 commit comments

Comments
 (0)