Skip to content

Commit f072403

Browse files
committed
fix cxb
Signed-off-by: jiqing-feng <[email protected]>
1 parent b02b757 commit f072403

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
579579
@staticmethod
580580
def backward(ctx, grad_output):
581581
state = ctx.state
582-
CB = state.CB.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
582+
B = state.CxB if state.CxB is not None else state.CB
583+
CB = B.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
583584
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
584585

585586
return grad_A, None, None, None, None

0 commit comments

Comments
 (0)