71 changes: 53 additions & 18 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -3,11 +3,14 @@
import warnings

import torch
+import torch.nn.functional as F

from bitsandbytes.functional import (
    QuantState,
+    create_dynamic_map,
    get_4bit_type,
)
+from bitsandbytes.utils import reverse_4bit_compress_format

try:
    # to support Intel CPU/GPU (XPU) backend
@@ -279,8 +282,9 @@ def mm_dequant_impl(
    0.8333333: 3,  # 0b0011
}

+INT8_QUANT_TABLE = create_dynamic_map().tolist()
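
For context, create_dynamic_map() builds the 256-entry dynamic 8-bit code (ascending values in [-1, 1]) that the new int8 branch below indexes into. A minimal sketch of the table-lookup quantization it enables; not part of the diff, and the inputs are illustrative:

import torch
from bitsandbytes.functional import create_dynamic_map

table = create_dynamic_map().tolist()   # 256 ascending codes in [-1, 1]
x = torch.tensor([-0.73, 0.02, 0.9])    # values already normalized by their block absmax
idx = torch.zeros_like(x, dtype=torch.uint8)
for i, code in enumerate(table):
    idx[x > code] = i                   # same thresholding loop as the int8 branch below
print(torch.tensor(table)[idx.int()])   # decoded values approximate x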


@_maybe_torch_compile
def quantize_4bit_impl(
    A: Tensor,
    absmax: Tensor = None,
@@ -314,7 +318,7 @@ def quantize_4bit_impl(
    tuple(torch.Tensor, torch.Size, torch.dtype, int):
        The quantization state to undo the quantization.
    """
-    if quant_type not in ["nf4", "fp4"]:
+    if quant_type not in ["nf4", "fp4", "int8"]:
        raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
    if quant_type == "fp4":
        warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please use nf4 instead for better performance.")
@@ -355,14 +359,34 @@ def quantize_4bit_impl(
        for key, val in FP4_QUANT_TABLE.items():
            out_uint8[abs_scaled_A > key] = val
        out_uint8 += sign.to(torch.uint8) * 8
-    if out_uint8.size(-1) % 2:
-        out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
-    out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
+    elif quant_type == "int8":
+        for i in range(len(INT8_QUANT_TABLE)):
+            out_uint8[scaled_A > INT8_QUANT_TABLE[i]] = i

-    code = get_4bit_type(quant_type, device=A.device)
+    if quant_type == "int8":
+        out = out_uint8
+        code = torch.Tensor(INT8_QUANT_TABLE).to(A.device)
+    else:
+        if out_uint8.size(-1) % 2:
+            out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
+        out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2])
+        code = get_4bit_type(quant_type, device=A.device)

    if compress_statistics:
-        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
+        offset = absmax.mean()
+        absmax -= offset
+        qabsmax, state2 = quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
+        del absmax
+        state = QuantState(
+            absmax=qabsmax,
+            shape=input_shape,
+            dtype=A.dtype,
+            blocksize=blocksize,
+            code=code,
+            quant_type=quant_type,
+            offset=offset,
+            state2=state2,
+        )
    else:
        state = QuantState(
            absmax=absmax,
@@ -373,7 +397,21 @@ def quantize_4bit_impl(
            quant_type=quant_type,
        )

-    return out.unsqueeze(0), state
+    return out.reshape(-1, 1), state
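
Note the packing order also flips in this diff: the even-indexed element of each pair now lands in the high nibble (out_uint8[::2] << 4 | out_uint8[1::2]), which matches the layout the ipex kernels consume. A tiny worked example of the new order, illustrative only:

import torch

nibbles = torch.tensor([0x3, 0xA, 0x7, 0x1], dtype=torch.uint8)  # 4-bit code indices
packed = nibbles[::2].bitwise_left_shift(4).bitwise_or_(nibbles[1::2])
assert packed.tolist() == [0x3A, 0x71]  # the old order produced [0xA3, 0x17]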


+def dequant_8bit(A, offset, quant_state):
+    assert A.dtype == torch.uint8
+    absmax = quant_state.code[A.reshape(-1).int()]
+    blocks = absmax.shape[-1] // 256
+    res = absmax.shape[-1] % 256
+    if res != 0:
+        absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0)
+    absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1)
+    absmax = absmax[: blocks * 256 + res]
+    absmax = absmax.reshape(A.shape)
+    absmax += offset
+    return absmax
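
dequant_8bit reverses the nested ("double quant") compression added above: it decodes the uint8 codes through the dynamic table, rescales each 256-value block by its second-level absmax, and re-adds the stored mean offset. A hedged round-trip sketch using the two functions from this module; the printed error is a sanity check, not a guaranteed bound:

import torch

absmax = torch.rand(1024) + 0.1          # first-level per-block scales
offset = absmax.mean()                   # stored as QuantState.offset
q8, state2 = quantize_4bit_impl(absmax - offset, blocksize=256, quant_type="int8")
recovered = dequant_8bit(q8.reshape(-1), offset, state2)
print((recovered - absmax).abs().max())  # small quantization error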


@_maybe_torch_compile
@@ -411,12 +449,8 @@ def dequantize_4bit_impl(
    torch.Tensor:
        Dequantized tensor.
    """
-    if A.shape[0] == 1:
-        transpose = False
-        A = A.squeeze(0)
-    elif A.shape[1] == 1:
-        transpose = True
-        A = A.squeeze(1)
+    transpose = True if A.shape[0] == 1 else False
+    A = A.reshape(-1)

    if quant_state is None:
        assert absmax is not None and out is not None
@@ -438,17 +472,18 @@
        )

    if quant_state.nested:
-        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
+        absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)

    if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
-        A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
+        ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
+        A = reverse_4bit_compress_format(ipex_weight)
        quant_state.ipex = False

    # Map nf4 to [-1, 1]
    out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
    n = out_dq.numel()
-    out_dq[::2] = A & 0xF
-    out_dq[1::2] = A >> 4
+    out_dq[1::2] = A & 0xF
+    out_dq[::2] = A >> 4
    # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
    quant_state.code = quant_state.code.to(quant_state.dtype)
    out_dq = quant_state.code[out_dq]
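
The unpack order flips to match the new packing: the high nibble now decodes to the even output positions. A quick consistency check, illustrative only:

import torch

packed = torch.tensor([0x3A, 0x71], dtype=torch.uint8)
out_dq = torch.empty(packed.size(0) * 2, dtype=torch.int32, device=packed.device)
out_dq[1::2] = packed & 0xF  # low nibble -> odd positions
out_dq[::2] = packed >> 4    # high nibble -> even positions
assert out_dq.tolist() == [0x3, 0xA, 0x7, 0x1]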
2 changes: 1 addition & 1 deletion bitsandbytes/backends/xpu.py
@@ -155,7 +155,7 @@ def dequantize_4bit(
        if blocksize is None:
            blocksize = 64
        assert_on_xpu([A, absmax, out])
-        if quant_type == "nf4":
+        if quant_type == "nf4" and getattr(quant_state, "ipex", False):
            output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
        else:
            output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
5 changes: 3 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -20,6 +20,7 @@
    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
    OutlierTracer,
    enable_ipex_fusion,
+    reverse_4bit_compress_format,
)

T = TypeVar("T", bound="torch.nn.Module")
@@ -460,9 +461,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
                    self.weight, "nf4", self.weight.quant_state.shape, 2
                )
-                self.weight.data = original_weight.data
+                self.weight.data = reverse_4bit_compress_format(original_weight.data)
            elif self.weight.device.type == "xpu":
-                self.weight.data = self.weight.data.reshape(1, -1)
+                self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))

            self.weight.quant_state.ipex = False
31 changes: 26 additions & 5 deletions bitsandbytes/utils.py
@@ -200,18 +200,35 @@ def unpack_tensor_to_dict(tensor_data):
    return unpacked_dict


+def reverse_4bit_compress_format(weight):
+    out_1 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
+    out_2 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
+    out_1 = (weight & 0xF0) >> 4
+    out_2 = (weight & 0xF) << 4
+    out = out_1 | out_2
+    return out
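
This helper swaps the two nibbles of every byte, converting packed 4-bit weights between the bitsandbytes order and the order the ipex kernels expect. Since swapping twice is the identity, the same function converts in both directions. A quick demo, illustrative only:

import torch
from bitsandbytes.utils import reverse_4bit_compress_format

w = torch.tensor([0x3A, 0x71], dtype=torch.uint8)
swapped = reverse_4bit_compress_format(w)
assert swapped.tolist() == [0xA3, 0x17]
assert reverse_4bit_compress_format(swapped).tolist() == w.tolist()  # involution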


def enable_ipex_fusion(linear, x):
    from bitsandbytes.backends.cpu_xpu_common import (
        _ipex_cpu_version_prereq,
        _ipex_xpu_version_prereq,
+        dequant_8bit,
        ipex_cpu,
        ipex_xpu,
    )

+    quant_state = linear.weight.quant_state
+
+    if quant_state.nested:
+        quant_state.absmax = dequant_8bit(quant_state.absmax, quant_state.offset, quant_state.state2)
+        quant_state.nested = False
+        delattr(quant_state, "state2")
+
    if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
-        quant_state = linear.weight.quant_state
+        converted_weight = reverse_4bit_compress_format(linear.weight.data)
        new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
-            linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+            converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
            "nf4",
            quant_state.shape,  # weight shape
            quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
@@ -222,12 +239,16 @@ def enable_ipex_fusion(linear, x):
            2,
        )
    elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
-        quant_state = linear.weight.quant_state
-        new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
-
+        converted_weight = reverse_4bit_compress_format(linear.weight.data)
+        new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
        new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
        new_zeros = None
        compensation = None
+    else:
+        raise ValueError(
+            "Please check the device and ipex version. The device should be cpu or xpu and the ipex version should be >= 2.5."
+        )

    linear.weight.data = new_weight.data
    linear.weight.quant_state.ipex = True
    linear.weight.quant_state.new_scales = new_scales
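
Putting the pieces together, a hedged usage sketch of the fusion path; the Linear4bit setup and the quantize-on-.to() behavior are assumptions for illustration, and only enable_ipex_fusion itself is from this diff:

import torch
import bitsandbytes as bnb
from bitsandbytes.utils import enable_ipex_fusion

linear = bnb.nn.Linear4bit(4096, 4096, quant_type="nf4")  # assumed setup
linear = linear.to("cpu")              # assumed to quantize the 4-bit weight on CPU
x = torch.randn(8, 4096)
enable_ipex_fusion(linear, x)          # one-time repack: nibble swap + ipex pack
assert linear.weight.quant_state.ipex  # later forwards can use the fused ipex kernel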