Skip to content

Commit 4b4a9ef

Browse files
committed
debugprint
1 parent 7906dc4 commit 4b4a9ef

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,9 +366,8 @@ def backward(ctx, grad_output):
366366
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
367367
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
368368
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
369-
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.dtype_B)
369+
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
370370
if state.threshold > 0.0 and subA is not None:
371-
assert False, idx
372371
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
373372

374373
if req_gradA:
@@ -382,8 +381,7 @@ def backward(ctx, grad_output):
382381
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
383382

384383
elif state.CB is not None:
385-
CB = state.CB.to(ctx.dtype_B)
386-
CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype))
384+
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).div(127.0))
387385
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
388386
else:
389387
raise Exception('State must contain either CBt or CB matrix for backward')

0 commit comments

Comments
 (0)