
Commit 591f603

add memory efficient backward
1 parent 579b8c7 commit 591f603

File tree

2 files changed: +17 -8 lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 0 additions & 1 deletion
@@ -381,7 +381,6 @@ def backward(ctx, grad_output):
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
 
         elif state.CB is not None:
-            raise NotImplementedError("WIP")
             CB = state.CB.to(ctx.dtype_B)
             CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype))
             grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
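With the NotImplementedError removed, the backward pass no longer needs an fp16 copy of the weights when state.CB is present: the int8 weight matrix is rescaled in place by its per-row scales (SCB / 127) and multiplied with grad_output. A minimal standalone sketch of that idea, assuming row-wise int8 weights with per-row absmax scales; the helper name is hypothetical and not bitsandbytes API:

# Illustrative sketch only -- helper name, shapes, and dtypes are assumptions, not library code.
import torch

def grad_input_from_int8_weight(grad_output, CB_int8, SCB, dtype=torch.float16):
    # Dequantize the row-wise int8 weight: W ~= CB * (SCB / 127)
    W = CB_int8.to(dtype)
    W.mul_(SCB.unsqueeze(1).div(127.0).to(dtype))
    # For out = A @ W^T, the gradient w.r.t. the input A is grad_output @ W
    return torch.matmul(grad_output, W)

The trade-off is a dequantize-and-matmul in the backward pass instead of keeping a second, full-precision weight tensor around, which is what the memory_efficient_backward option exercised in the tests below opts into.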

tests/test_modules.py

Lines changed: 17 additions & 7 deletions
@@ -14,13 +14,15 @@ def __init__(self, initial_data):
 
 
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
 
     def forward(self, x):

@@ -451,9 +453,12 @@ def test_linear8bitlt_accumulated_gradient():
 
 
 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .cuda()
         .half()
     )

@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
 
     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )

@@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.device.type == "cuda"
 
     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .to(torch.float16)
         .to("cuda")
     )

@@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.device.type == "cuda"
 
 
+
 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
     l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
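The test now parametrizes over memory_efficient_backward in [True, False], so the original path stays covered while the new backward branch gets exercised. A hedged sketch of the equivalent standalone usage, assuming a CUDA device; the layer sizes, loss, and surrounding snippet are illustrative only, while the constructor arguments mirror the diff above:

# Illustrative sketch; assumes a CUDA device and the constructor arguments shown in the diff.
import torch
import bitsandbytes as bnb

layer = (
    bnb.nn.Linear8bitLt(
        32, 64, has_fp16_weights=False, memory_efficient_backward=True, threshold=6.0
    )
    .cuda()
    .half()
)

x = torch.randn(4, 32, dtype=torch.float16, device="cuda", requires_grad=True)
out = layer(x)
out.float().pow(2).mean().backward()  # grad w.r.t. x flows through the int8 weights
print(x.grad.shape)  # torch.Size([4, 32])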
