@@ -563,6 +563,28 @@ def backward(ctx, grad_output):
         return grad_A, grad_B, None, grad_bias, None
 
 
+class MatMul8bitFp(torch.autograd.Function):
+    # For Intel CPU and XPU, the double quant path has many unsafe operations which break finetuning.
+    # We'd like to use dequant + matmul to run finetuning for now.
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
+        CB = B.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)).t()
+        output = torch.matmul(A, CB).to(A.dtype)
+        ctx.state = state
+        ctx.dtype_A = A.dtype
+        ctx.grad_shape = A.shape
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        state = ctx.state
+        CB = state.CB.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+        grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+
+        return grad_A, None, None, None, None
+
+
 def matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -574,6 +596,8 @@ def matmul(
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
+    if A.device.type in ("cpu", "xpu") and state.is_training:
+        return MatMul8bitFp.apply(A, B, out, bias, state)
     return MatMul8bitLt.apply(A, B, out, bias, state)