diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index 8fdf7569d..75f647939 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -3,11 +3,14 @@
 import warnings
 
 import torch
+import torch.nn.functional as F
 
 from bitsandbytes.functional import (
     QuantState,
+    create_dynamic_map,
     get_4bit_type,
 )
+from bitsandbytes.utils import reverse_4bit_compress_format
 
 try:
     # to support Intel CPU/GPU (XPU) backend
@@ -279,8 +282,9 @@ def mm_dequant_impl(
     0.8333333: 3,  # 0b0011
 }
 
+INT8_QUANT_TABLE = create_dynamic_map().tolist()  # 256-entry dynamic 8-bit code, used for nested absmax quantization
+
 
-@_maybe_torch_compile
 def quantize_4bit_impl(
     A: Tensor,
     absmax: Tensor = None,
@@ -314,7 +318,7 @@
     tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
     """
-    if quant_type not in ["nf4", "fp4"]:
+    if quant_type not in ["nf4", "fp4", "int8"]:
         raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
     if quant_type == "fp4":
         warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.")
@@ -355,14 +359,34 @@
         for key, val in FP4_QUANT_TABLE.items():
             out_uint8[abs_scaled_A > key] = val
         out_uint8 += sign.to(torch.uint8) * 8
-    if out_uint8.size(-1) % 2:
-        out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
-    out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
+    elif quant_type == "int8":
+        for i in range(len(INT8_QUANT_TABLE)):
+            out_uint8[scaled_A > INT8_QUANT_TABLE[i]] = i  # each value keeps the largest code index it exceeds
 
-    code = get_4bit_type(quant_type, device=A.device)
+    if quant_type == "int8":
+        out = out_uint8
+        code = torch.tensor(INT8_QUANT_TABLE, device=A.device)
+    else:
+        if out_uint8.size(-1) % 2:
+            out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
+        out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2])
+        code = get_4bit_type(quant_type, device=A.device)
 
     if compress_statistics:
-        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
+        offset = absmax.mean()  # center absmax around zero before quantizing it with the signed dynamic map
+        absmax -= offset
+        qabsmax, state2 = quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
+        del absmax
+        state = QuantState(
+            absmax=qabsmax,
+            shape=input_shape,
+            dtype=A.dtype,
+            blocksize=blocksize,
+            code=code,
+            quant_type=quant_type,
+            offset=offset,
+            state2=state2,
+        )
     else:
         state = QuantState(
             absmax=absmax,
@@ -373,7 +397,21 @@
             quant_type=quant_type,
         )
 
-    return out.unsqueeze(0), state
+    return out.reshape(-1, 1), state
+
+
+def dequant_8bit(A, offset, quant_state):
+    assert A.dtype == torch.uint8
+    absmax = quant_state.code[A.reshape(-1).int()]  # look up the dynamic-map value for each 8-bit code
+    blocks = absmax.shape[-1] // 256
+    res = absmax.shape[-1] % 256
+    if res != 0:
+        absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0)
+    absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1)
+    absmax = absmax[: blocks * 256 + res]  # drop the padding again
+    absmax = absmax.reshape(A.shape)
+    absmax += offset
+    return absmax
 
 
 @_maybe_torch_compile
@@ -411,12 +449,8 @@
     torch.Tensor:
         Dequantized tensor.
""" - if A.shape[0] == 1: - transpose = False - A = A.squeeze(0) - elif A.shape[1] == 1: - transpose = True - A = A.squeeze(1) + transpose = True if A.shape[0] == 1 else False + A = A.reshape(-1) if quant_state is None: assert absmax is not None and out is not None @@ -438,17 +472,18 @@ def dequantize_4bit_impl( ) if quant_state.nested: - raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") + absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2) if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False): - A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2) + ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2) + A = reverse_4bit_compress_format(ipex_weight) quant_state.ipex = False # Map nf4 to [-1, 1] out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) n = out_dq.numel() - out_dq[::2] = A & 0xF - out_dq[1::2] = A >> 4 + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue quant_state.code = quant_state.code.to(quant_state.dtype) out_dq = quant_state.code[out_dq] diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py index bc13963e6..aca0a0103 100644 --- a/bitsandbytes/backends/xpu.py +++ b/bitsandbytes/backends/xpu.py @@ -155,7 +155,7 @@ def dequantize_4bit( if blocksize is None: blocksize = 64 assert_on_xpu([A, absmax, out]) - if quant_type == "nf4": + if quant_type == "nf4" and getattr(quant_state, "ipex", False): output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t() else: output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ad5a7d443..2320ffd39 100755 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -20,6 +20,7 @@ LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer, enable_ipex_fusion, + reverse_4bit_compress_format, ) T = TypeVar("T", bound="torch.nn.Module") @@ -460,9 +461,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight( self.weight, "nf4", self.weight.quant_state.shape, 2 ) - self.weight.data = original_weight.data + self.weight.data = reverse_4bit_compress_format(original_weight.data) elif self.weight.device.type == "xpu": - self.weight.data = self.weight.data.reshape(1, -1) + self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1)) self.weight.quant_state.ipex = False diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 02c9ac2ca..e3748685e 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -200,18 +200,35 @@ def unpack_tensor_to_dict(tensor_data): return unpacked_dict +def reverse_4bit_compress_format(weight): + out_1 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device) + out_2 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device) + out_1 = (weight & 0xF0) >> 4 + out_2 = (weight & 0xF) << 4 + out = out_1 | out_2 + return out + + def enable_ipex_fusion(linear, x): from bitsandbytes.backends.cpu_xpu_common import ( _ipex_cpu_version_prereq, _ipex_xpu_version_prereq, + dequant_8bit, ipex_cpu, ipex_xpu, ) + quant_state = linear.weight.quant_state + + if quant_state.nested: + quant_state.absmax = dequant_8bit(quant_state.absmax, quant_state.offset, quant_state.state2) + 
+        quant_state.nested = False
+        delattr(quant_state, "state2")
+
     if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
-        quant_state = linear.weight.quant_state
+        converted_weight = reverse_4bit_compress_format(linear.weight.data)
         new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
-            linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+            converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
             "nf4",
             quant_state.shape,  # weight shape
             quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
@@ -222,12 +239,16 @@ def enable_ipex_fusion(linear, x):
             2,
         )
     elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
-        quant_state = linear.weight.quant_state
-        new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
-
+        converted_weight = reverse_4bit_compress_format(linear.weight.data)
+        new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
         new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
         new_zeros = None
         compensation = None
+    else:
+        raise ValueError(
+            "IPEX fusion requires the weight to be on a cpu or xpu device and IPEX >= 2.5 to be installed."
+        )
+
     linear.weight.data = new_weight.data
     linear.weight.quant_state.ipex = True
     linear.weight.quant_state.new_scales = new_scales
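
For reference, a minimal sketch of the nibble swap performed by reverse_4bit_compress_format; the tensor names below are illustrative, not part of the patch:

import torch

def reverse_4bit_compress_format(weight):
    # Swap the high and low nibbles of each packed uint8 byte.
    return ((weight & 0xF0) >> 4) | ((weight & 0xF) << 4)

# Pack eight 4-bit codes the way quantize_4bit_impl now does:
# even elements land in the high nibble, odd elements in the low nibble.
vals = torch.randint(0, 16, (8,), dtype=torch.uint8)
packed = vals[::2].bitwise_left_shift(4).bitwise_or_(vals[1::2])

# Swapping nibbles yields the opposite layout (odd elements in the high
# nibble), which is the order used on the ipex pack/unpack side here.
swapped = reverse_4bit_compress_format(packed)
assert torch.equal(swapped, vals[1::2].bitwise_left_shift(4).bitwise_or_(vals[::2]))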
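
Similarly, a small round-trip sketch of the nested absmax ("double quant") scheme this patch enables, loosely mirroring the int8 path of quantize_4bit_impl and dequant_8bit; it assumes bitsandbytes is importable, and the shapes and names are illustrative:

import torch
from bitsandbytes.functional import create_dynamic_map

code = create_dynamic_map()   # 256-entry signed dynamic 8-bit map in [-1, 1]

absmax = torch.rand(512) * 3  # per-block scales produced by the first 4-bit pass
offset = absmax.mean()
centered = absmax - offset    # center around zero, as quantize_4bit_impl does

# Blockwise 8-bit quantization with blocksize 256 (the nested "int8" path):
# each value keeps the largest code index whose level it exceeds.
blocks = centered.view(-1, 256)
scales = blocks.abs().max(dim=1, keepdim=True).values
scaled = blocks / scales
q = torch.zeros(blocks.shape, dtype=torch.uint8)
for i, level in enumerate(code.tolist()):
    q[scaled > level] = i

# Dequantize: look up the code, rescale per block, re-add the offset
# (dequant_8bit does the same, plus padding for non-multiple-of-256 sizes).
restored = code[q.int()] * scales + offset
print((restored.reshape(-1) - absmax).abs().max())  # small quantization error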