Skip to content

Commit 2cd047e

Browse files
committed
run backward
1 parent 591f603 commit 2cd047e

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/test_modules.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,11 +554,22 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
554554
assert mlp.fc1.state.idx is not None
555555
if threshold > 0:
556556
assert mlp.fc2.state.idx is not None
557+
557558
assert mlp.fc1.weight.dtype == torch.int8
558559
assert mlp.fc2.weight.dtype == torch.int8
559560
assert mlp.fc1.weight.device.type == "cuda"
560561
assert mlp.fc2.weight.device.type == "cuda"
561562

563+
if memory_efficient_backward:
564+
b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
565+
o1 = mlp(b1)
566+
assert o1.dtype == torch.float16
567+
assert o1.requires_grad
568+
grad_proj = torch.randn_like(o1)
569+
570+
(o1 * grad_proj).sum().backward()
571+
572+
562573

563574

564575
def test_linear8bitlt_fp32_bias():

0 commit comments

Comments
 (0)