Commit 4bd1151

Fixed gradient accumulation test.

1 parent 675baa7 · commit 4bd1151
File tree: 2 files changed, +11 -10 lines

bitsandbytes/autograd/_functions.py

Lines changed: 0 additions & 1 deletion

@@ -456,7 +456,6 @@ def backward(ctx, grad_output):
 
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
-            #grad_B = torch.matmul(grad_output.t(), A)
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
             gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
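The deleted line was a commented-out fp16 reference for the weight gradient that the int8 path right below it computes, so this hunk is pure cleanup. As a minimal, hedged sketch (plain PyTorch in fp32 with illustrative shapes, not the library's code), the reference product and the chain rule it encodes look like this:

```python
import torch

# Sketch of the reference the deleted comment pointed at:
# for out = A @ B.t(), the weight gradient is grad_B = grad_output.t() @ A.
# Shapes and dtype (fp32, CPU) are illustrative assumptions.
batch, in_features, out_features = 8, 32, 32
A = torch.randn(batch, in_features)
B = torch.randn(out_features, in_features, requires_grad=True)

out = A @ B.t()                      # forward: (batch, out_features)
grad_output = torch.randn_like(out)  # incoming gradient from the next layer

# Reference weight gradient; the quantized path in the diff
# (double_quant -> transform -> igemmlt) computes this same product
# with int8 operands and then dequantizes the result.
grad_B = torch.matmul(grad_output.t(), A)  # (out_features, in_features)

# Cross-check against autograd.
out.backward(grad_output)
torch.testing.assert_close(B.grad, grad_B, atol=1e-5, rtol=1e-5)
```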

tests/test_modules.py

Lines changed: 11 additions & 9 deletions
@@ -332,12 +332,13 @@ def test_linear8bitlt_inference(threshold):
 def test_linear8bitlt_accumulated_gradient():
     l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
     l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
-    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
-    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
-    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
-    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
-    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
-    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+    l1[0].weight.data.copy_(l2[0].weight.data)
+    l1[1].weight.data.copy_(l2[1].weight.data)
+    l1[0].bias.data.copy_(l2[0].bias.data)
+    l1[1].bias.data.copy_(l2[1].bias.data)
+
+    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
+    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
 
     acc_steps = 10
 
@@ -353,7 +354,6 @@ def test_linear8bitlt_accumulated_gradient():
         assert l1[0].state.CxB is not None
         assert l1[1].state.CxB is not None
 
-        print(i)
         if i > 0 and i % acc_steps == 0:
             opt1.step()
             opt1.zero_grad(True)
@@ -368,9 +368,11 @@ def test_linear8bitlt_accumulated_gradient():
             # we do this copy because otherwise we have small divergences over time that add up
             l1[0].weight.data.copy_(l2[0].weight.data)
             l1[1].weight.data.copy_(l2[1].weight.data)
+            l1[0].bias.data.copy_(l2[0].bias.data)
+            l1[1].bias.data.copy_(l2[1].bias.data)
         else:
-            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
-            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)
+            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
+            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
 
 
 @pytest.mark.parametrize("threshold", [0.0, 2.0])
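For orientation, the pattern the updated test relies on can be sketched outside the diff. This is a condensed, hypothetical reconstruction: the forward/backward body, input shapes, and loss are assumptions (the diff only shows the setup, the accumulation condition, and the gradient checks), and plain fp32 nn.Linear stands in for both module stacks so the sketch runs without a GPU or bitsandbytes:

```python
import torch

torch.manual_seed(0)
l1 = torch.nn.Sequential(*[torch.nn.Linear(32, 32) for _ in range(2)])
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32) for _ in range(2)])

# Start both stacks from identical parameters, mirroring the copy_ calls in the diff.
for p1, p2 in zip(l1.parameters(), l2.parameters()):
    p1.data.copy_(p2.data)

opt1 = torch.optim.Adam(l1.parameters(), lr=0.001)
opt2 = torch.optim.Adam(l2.parameters(), lr=0.001)

acc_steps = 10
for i in range(20):
    x = torch.randn(16, 32)
    l1(x).mean().backward()   # gradients accumulate in .grad across iterations
    l2(x).mean().backward()

    if i > 0 and i % acc_steps == 0:
        # Apply the accumulated gradients, then reset them.
        opt1.step()
        opt1.zero_grad(set_to_none=True)
        opt2.step()
        opt2.zero_grad(set_to_none=True)
        # Re-sync parameters so small numerical divergences do not add up.
        for p1, p2 in zip(l1.parameters(), l2.parameters()):
            p1.data.copy_(p2.data)
    else:
        # Between optimizer steps the accumulated gradients should agree
        # within tolerances appropriate for the precision in use.
        for p1, p2 in zip(l1.parameters(), l2.parameters()):
            torch.testing.assert_close(p1.grad, p2.grad, atol=1e-3, rtol=1e-3)
```

In the actual test the left-hand stack is bnb.nn.Linear8bitLt in fp16, which is presumably why this commit compares gradients with atol=1e-3 / rtol=1e-3 rather than exact equality.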
