Skip to content

Commit 5f78858

Browse files
committed
fix 4bit format
Signed-off-by: jiqing-feng <[email protected]>
1 parent 96f4ac8 commit 5f78858

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import warnings
44

55
import torch
6+
import torch.nn.functional as F
67

78
from bitsandbytes.functional import (
89
QuantState,
910
create_dynamic_map,
1011
get_4bit_type,
1112
)
13+
from bitsandbytes.utils import reverse_4bit_compress_format
1214

1315
try:
1416
# to support Intel CPU/GPU (XPU) backend
@@ -367,7 +369,7 @@ def quantize_4bit_impl(
367369
else:
368370
if out_uint8.size(-1) % 2:
369371
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
370-
out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
372+
out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2])
371373
code = get_4bit_type(quant_type, device=A.device)
372374

373375
if compress_statistics:
@@ -401,7 +403,13 @@ def quantize_4bit_impl(
401403
def dequant_8bit(A, offset, quant_state):
    """Dequantize a uint8 tensor that was quantized with a dynamic 8-bit code.

    Looks each byte of ``A`` up in ``quant_state.code``, rescales the values
    blockwise (256 elements per block) by ``quant_state.absmax``, and adds
    ``offset``.

    Args:
        A: uint8 tensor of quantized values.
        offset: scalar (or broadcastable tensor) added to the result.
        quant_state: state object providing ``code`` (256-entry lookup table),
            ``absmax`` (one scale per 256-element block — assumes one entry
            per block including a possible partial last block; TODO confirm
            against caller), and ``dtype`` (target dtype of the result).

    Returns:
        Dequantized tensor with the same shape as ``A``.
    """
    assert A.dtype == torch.uint8
    decoded = quant_state.code[A.reshape(-1).int()]
    n = decoded.shape[-1]
    full_blocks, remainder = divmod(n, 256)
    # Pad the tail so the flat tensor splits evenly into 256-wide blocks.
    if remainder:
        decoded = F.pad(decoded, (0, 256 - remainder), mode="constant", value=0)
    scaled = (decoded.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1)
    # Drop the padding (full_blocks * 256 + remainder == n) and restore shape.
    absmax = scaled[:n].reshape(A.shape)
    absmax += offset
    return absmax
407415

@@ -471,14 +479,15 @@ def dequantize_4bit_impl(
471479
absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)
472480

473481
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
474-
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
482+
ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
483+
A = reverse_4bit_compress_format(ipex_weight)
475484
quant_state.ipex = False
476485

477486
# Map nf4 to [-1, 1]
478487
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
479488
n = out_dq.numel()
480-
out_dq[::2] = A & 0xF
481-
out_dq[1::2] = A >> 4
489+
out_dq[1::2] = A & 0xF
490+
out_dq[::2] = A >> 4
482491
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
483492
quant_state.code = quant_state.code.to(quant_state.dtype)
484493
out_dq = quant_state.code[out_dq]

bitsandbytes/nn/modules.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
2121
OutlierTracer,
2222
enable_ipex_fusion,
23+
reverse_4bit_compress_format,
2324
)
2425

2526
T = TypeVar("T", bound="torch.nn.Module")
@@ -460,9 +461,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
460461
original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
461462
self.weight, "nf4", self.weight.quant_state.shape, 2
462463
)
463-
self.weight.data = original_weight.data
464+
self.weight.data = reverse_4bit_compress_format(original_weight.data)
464465
elif self.weight.device.type == "xpu":
465-
self.weight.data = self.weight.data.reshape(1, -1)
466+
self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
466467

467468
self.weight.quant_state.ipex = False
468469

bitsandbytes/utils.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,22 @@ def unpack_tensor_to_dict(tensor_data):
200200
return unpacked_dict
201201

202202

203+
def reverse_4bit_compress_format(weight):
    """Swap the high and low 4-bit nibbles in every byte of ``weight``.

    Used to convert packed 4-bit weights between the two nibble orderings
    (bitsandbytes packing vs. the ipex woq_linear layout — see callers).
    The operation is its own inverse.

    Args:
        weight: tensor of packed bytes, two 4-bit values per element.

    Returns:
        Tensor of the same shape and dtype as ``weight`` with the nibbles of
        each element swapped.

    Note: the original body allocated two throwaway ``torch.empty`` int32
    tensors that were immediately rebound and never used (dead stores);
    they are removed here. The computed result is unchanged.
    """
    high_nibble = (weight & 0xF0) >> 4
    low_nibble = (weight & 0xF) << 4
    return high_nibble | low_nibble
210+
211+
203212
def enable_ipex_fusion(linear, x):
204213
from bitsandbytes.backends.cpu_xpu_common import (
205214
_ipex_cpu_version_prereq,
206215
_ipex_xpu_version_prereq,
216+
dequant_8bit,
207217
ipex_cpu,
208218
ipex_xpu,
209-
dequant_8bit,
210219
)
211220

212221
quant_state = linear.weight.quant_state
@@ -217,8 +226,9 @@ def enable_ipex_fusion(linear, x):
217226
delattr(quant_state, "state2")
218227

219228
if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
229+
converted_weight = reverse_4bit_compress_format(linear.weight.data)
220230
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
221-
linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
231+
converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
222232
"nf4",
223233
quant_state.shape, # weight shape
224234
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
@@ -229,11 +239,16 @@ def enable_ipex_fusion(linear, x):
229239
2,
230240
)
231241
elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
232-
new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
233-
242+
converted_weight = reverse_4bit_compress_format(linear.weight.data)
243+
new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
234244
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
235245
new_zeros = None
236246
compensation = None
247+
else:
248+
raise ValueError(
249+
"Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.5"
250+
)
251+
237252
linear.weight.data = new_weight.data
238253
linear.weight.quant_state.ipex = True
239254
linear.weight.quant_state.new_scales = new_scales

0 commit comments

Comments
 (0)