
Commit b02b757

new matmul8bit
Signed-off-by: jiqing-feng <[email protected]>
1 parent: f6025bc

File tree: 1 file changed (+24, -0 lines)


bitsandbytes/autograd/_functions.py

Lines changed: 24 additions & 0 deletions
@@ -563,6 +563,28 @@ def backward(ctx, grad_output):
         return grad_A, grad_B, None, grad_bias, None


+class MatMul8bitFp(torch.autograd.Function):
+    # For Intel CPU and XPU, double quant has many unsafe operations which can break finetuning.
+    # For now we use dequant + matmul to run finetuning instead.
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
+        CB = B.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)).t()
+        output = torch.matmul(A, CB).to(A.dtype)
+        ctx.state = state
+        ctx.dtype_A = A.dtype
+        ctx.grad_shape = A.shape
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        state = ctx.state
+        CB = state.CB.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+        grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+
+        return grad_A, None, None, None, None
+
+
 def matmul(
     A: torch.Tensor,
     B: torch.Tensor,
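For reference, here is a minimal runnable sketch (not part of the commit) of the dequant + matmul idea that MatMul8bitFp implements. The names below (W_fp, W_int8, SCB) are illustrative stand-ins for the int8 weight and per-row scale that bitsandbytes keeps in MatmulLtState:

    # Toy demonstration: symmetric per-row int8 quantization, then
    # dequantize and run a plain fp matmul, as MatMul8bitFp does.
    import torch

    torch.manual_seed(0)
    A = torch.randn(4, 8)        # activations: (batch, in_features)
    W_fp = torch.randn(16, 8)    # fp weight:   (out_features, in_features)

    # SCB holds each row's absmax, so W_int8 * SCB / 127 ~= W_fp.
    SCB = W_fp.abs().amax(dim=1)                                       # (out_features,)
    W_int8 = torch.round(W_fp / SCB.unsqueeze(1) * 127).to(torch.int8)

    # Forward as in MatMul8bitFp: dequantize the weight, then matmul.
    CB = W_int8.to(A.dtype).mul_(SCB.unsqueeze(1).mul(1.0 / 127.0)).t()  # (in, out)
    out = torch.matmul(A, CB)

    # The result tracks the fp reference up to int8 quantization error.
    ref = torch.matmul(A, W_fp.t())
    print((out - ref).abs().max())

Backward dequantizes state.CB the same way and computes grad_A = grad_output @ dequant(CB); no gradient is returned for B, so only the activations receive gradients.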
@@ -574,6 +596,8 @@ def matmul(
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
+    if A.device.type in ("cpu", "xpu") and state.is_training:
+        return MatMul8bitFp.apply(A, B, out, bias, state)
     return MatMul8bitLt.apply(A, B, out, bias, state)
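And a hypothetical usage sketch of the new dispatch path. The CB/SCB fields of MatmulLtState are normally populated by bitsandbytes' Linear8bitLt layer; here they are filled in by hand, reusing the toy quantization from the sketch above:

    import torch
    import bitsandbytes as bnb
    from bitsandbytes.autograd._functions import MatmulLtState

    W_fp = torch.randn(16, 8)
    SCB = W_fp.abs().amax(dim=1)
    W_int8 = torch.round(W_fp / SCB.unsqueeze(1) * 127).to(torch.int8)

    # Hand-built state for illustration only; real code gets this
    # from a quantized Linear8bitLt module.
    state = MatmulLtState()
    state.CB, state.SCB = W_int8, SCB
    state.is_training = True

    A = torch.randn(4, 8, requires_grad=True)
    # On a cpu/xpu tensor with is_training set, matmul now takes the
    # MatMul8bitFp path (dequant + fp matmul) instead of MatMul8bitLt.
    out = bnb.matmul(A, W_int8, state=state)
    out.sum().backward()   # A.grad is computed from the dequantized weight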