Commit 6a826c4

pre-cast
1 parent d9b8789 commit 6a826c4

1 file changed: +2 -5 lines changed

tests/test_modules.py

Lines changed: 2 additions & 5 deletions
@@ -538,14 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
 
-    mlp = (
-        MLP8bit(
+    mlp = MLP8bit(
             32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
         )
-        .to(torch.float16)
-        .to("cuda")
-    )
     w1, w2 = mlp.fc1.weight.clone(), mlp.fc2.weight.clone()
+    mlp = mlp.cuda().half()
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()

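For context, the hunk moves the fc1/fc2 weight clones to before the .cuda().half() cast, so w1 and w2 now hold the freshly initialized fp32 CPU weights rather than already-casted ones, which is presumably what the "pre-cast" commit message refers to. A minimal sketch of the same clone-then-cast ordering, using a plain torch.nn.Linear instead of MLP8bit (MLP8bit itself is not reproduced here; a CUDA device is assumed):

import torch

# Build the module on CPU in fp32, as the updated test now does with MLP8bit.
layer = torch.nn.Linear(32, 64)

# Clone the weights before any device/dtype cast, so the copies keep the
# original fp32 CPU values instead of whatever the cast produces.
w_ref = layer.weight.clone()

# Cast afterwards; for MLP8bit with has_fp16_weights=False this is the step
# where the 8-bit conversion would kick in, hence cloning "pre-cast".
layer = layer.cuda().half()

assert w_ref.dtype == torch.float32 and w_ref.device.type == "cpu"
assert layer.weight.dtype == torch.float16 and layer.weight.device.type == "cuda"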