
Commit f6025bc

Enable double quant on Intel CPU and XPU (#1472)
* fix dequant 8bit
* support double quant on intel cpu and xpu
* fix format
* fix shape
* fix 4bit format
* fix device error for xpu
* fix 4bit tensor shape
* fix nf4 xpu finetune

Signed-off-by: jiqing-feng <[email protected]>
1 parent 7e6f865 commit f6025bc
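For context: double quant is the feature transformers exposes as bnb_4bit_use_double_quant, which previously raised NotImplementedError on CPU/XPU. A minimal usage sketch on an Intel CPU or XPU host, assuming the transformers BitsAndBytesConfig front end (the checkpoint name and device_map value are placeholders, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Hedged sketch: checkpoint and device_map are placeholders.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,   # enabled on Intel CPU/XPU by this commit
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "your/4bit-capable-checkpoint",   # placeholder
    quantization_config=bnb_config,
    device_map="cpu",                 # or "xpu" with intel_extension_for_pytorch installed
)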

File tree

4 files changed: +83 -26 lines changed

bitsandbytes/backends/cpu_xpu_common.py
bitsandbytes/backends/xpu.py
bitsandbytes/nn/modules.py
bitsandbytes/utils.py

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 53 additions & 18 deletions
@@ -3,11 +3,14 @@
 import warnings
 
 import torch
+import torch.nn.functional as F
 
 from bitsandbytes.functional import (
     QuantState,
+    create_dynamic_map,
     get_4bit_type,
 )
+from bitsandbytes.utils import reverse_4bit_compress_format
 
 try:
     # to support Intel CPU/GPU (XPU) backend
@@ -279,8 +282,9 @@ def mm_dequant_impl(
     0.8333333: 3,  # 0b0011
 }
 
+INT8_QUANT_TABLE = create_dynamic_map().tolist()
+
 
-@_maybe_torch_compile
 def quantize_4bit_impl(
     A: Tensor,
     absmax: Tensor = None,
@@ -314,7 +318,7 @@ def quantize_4bit_impl(
         tuple(torch.Tensor, torch.Size, torch.dtype, int):
             The quantization state to undo the quantization.
     """
-    if quant_type not in ["nf4", "fp4"]:
+    if quant_type not in ["nf4", "fp4", "int8"]:
         raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
     if quant_type == "fp4":
         warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.")
@@ -355,14 +359,34 @@ def quantize_4bit_impl(
         for key, val in FP4_QUANT_TABLE.items():
             out_uint8[abs_scaled_A > key] = val
         out_uint8 += sign.to(torch.uint8) * 8
-    if out_uint8.size(-1) % 2:
-        out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
-    out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
+    elif quant_type == "int8":
+        for i in range(len(INT8_QUANT_TABLE)):
+            out_uint8[scaled_A > INT8_QUANT_TABLE[i]] = i
 
-    code = get_4bit_type(quant_type, device=A.device)
+    if quant_type == "int8":
+        out = out_uint8
+        code = torch.Tensor(INT8_QUANT_TABLE).to(A.device)
+    else:
+        if out_uint8.size(-1) % 2:
+            out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
+        out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2])
+        code = get_4bit_type(quant_type, device=A.device)
 
     if compress_statistics:
-        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
+        offset = absmax.mean()
+        absmax -= offset
+        qabsmax, state2 = quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
+        del absmax
+        state = QuantState(
+            absmax=qabsmax,
+            shape=input_shape,
+            dtype=A.dtype,
+            blocksize=blocksize,
+            code=code,
+            quant_type=quant_type,
+            offset=offset,
+            state2=state2,
+        )
     else:
         state = QuantState(
             absmax=absmax,
@@ -373,7 +397,21 @@ def quantize_4bit_impl(
             quant_type=quant_type,
         )
 
-    return out.unsqueeze(0), state
+    return out.reshape(-1, 1), state
+
+
+def dequant_8bit(A, offset, quant_state):
+    assert A.dtype == torch.uint8
+    absmax = quant_state.code[A.reshape(-1).int()]
+    blocks = absmax.shape[-1] // 256
+    res = absmax.shape[-1] % 256
+    if res != 0:
+        absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0)
+    absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1)
+    absmax = absmax[: blocks * 256 + res]
+    absmax = absmax.reshape(A.shape)
+    absmax += offset
+    return absmax
 
 
 @_maybe_torch_compile
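For readers following the quantize path above, a self-contained sketch of the double-quant idea: the first-level absmax values are shifted by their mean, mapped to 8-bit codes from the dynamic map (the new INT8_QUANT_TABLE), and dequant_8bit reverses the lookup, rescales per 256-element block, and adds the offset back. This is illustrative only; block-remainder handling and device placement are simplified relative to the code above.

import torch
from bitsandbytes.functional import create_dynamic_map

code = create_dynamic_map()                 # same 256-entry table as INT8_QUANT_TABLE above
table = code.tolist()

absmax = torch.rand(1024)                   # toy first-level per-block scales
offset = absmax.mean()
shifted = absmax - offset                   # store absmax relative to its mean

blocks = shifted.reshape(-1, 256)           # second-level blocksize of 256, as in the diff
absmax2 = blocks.abs().max(dim=-1, keepdim=True).values
scaled = (blocks / absmax2).reshape(-1)     # scale each block into [-1, 1]

codes = torch.zeros_like(scaled, dtype=torch.uint8)
for i, threshold in enumerate(table):       # same thresholding loop as quantize_4bit_impl
    codes[scaled > threshold] = i

# Dequantization mirrors dequant_8bit(): table lookup, per-block rescale, add the offset back.
restored = (code[codes.long()].reshape(-1, 256) * absmax2).reshape(-1) + offset
print((restored - absmax).abs().max())      # small error from the 8-bit table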
@@ -411,12 +449,8 @@ def dequantize_4bit_impl(
         torch.Tensor:
             Dequantized tensor.
     """
-    if A.shape[0] == 1:
-        transpose = False
-        A = A.squeeze(0)
-    elif A.shape[1] == 1:
-        transpose = True
-        A = A.squeeze(1)
+    transpose = True if A.shape[0] == 1 else False
+    A = A.reshape(-1)
 
     if quant_state is None:
         assert absmax is not None and out is not None
@@ -438,17 +472,18 @@ def dequantize_4bit_impl(
         )
 
     if quant_state.nested:
-        raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
+        absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)
 
     if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
-        A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
+        ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
+        A = reverse_4bit_compress_format(ipex_weight)
         quant_state.ipex = False
 
     # Map nf4 to [-1, 1]
     out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
     n = out_dq.numel()
-    out_dq[::2] = A & 0xF
-    out_dq[1::2] = A >> 4
+    out_dq[1::2] = A & 0xF
+    out_dq[::2] = A >> 4
     # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
     quant_state.code = quant_state.code.to(quant_state.dtype)
     out_dq = quant_state.code[out_dq]
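Note on the nibble order above: after this change the element at the even index is packed into the high nibble (out_uint8[::2] << 4 | out_uint8[1::2]), and dequantize_4bit_impl unpacks in the matching order. A toy round trip of just the packing step, as a sanity sketch (not the shipped kernel):

import torch

codes = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.uint8)   # pretend 4-bit quantization codes

# Pack as quantize_4bit_impl does after this commit: even positions land in the high nibble.
packed = codes[::2].bitwise_left_shift(4).bitwise_or(codes[1::2])

# Unpack as dequantize_4bit_impl does after this commit.
unpacked = torch.empty(packed.size(0) * 2, dtype=torch.uint8)
unpacked[::2] = packed >> 4
unpacked[1::2] = packed & 0xF

assert torch.equal(unpacked, codes)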

bitsandbytes/backends/xpu.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def dequantize_4bit(
         if blocksize is None:
             blocksize = 64
         assert_on_xpu([A, absmax, out])
-        if quant_type == "nf4":
+        if quant_type == "nf4" and getattr(quant_state, "ipex", False):
            output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
         else:
             output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

bitsandbytes/nn/modules.py

Lines changed: 3 additions & 2 deletions
@@ -20,6 +20,7 @@
     LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
     OutlierTracer,
     enable_ipex_fusion,
+    reverse_4bit_compress_format,
 )
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -460,9 +461,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                 original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
                     self.weight, "nf4", self.weight.quant_state.shape, 2
                 )
-                self.weight.data = original_weight.data
+                self.weight.data = reverse_4bit_compress_format(original_weight.data)
             elif self.weight.device.type == "xpu":
-                self.weight.data = self.weight.data.reshape(1, -1)
+                self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
 
             self.weight.quant_state.ipex = False
 
bitsandbytes/utils.py

Lines changed: 26 additions & 5 deletions
@@ -200,18 +200,35 @@ def unpack_tensor_to_dict(tensor_data):
     return unpacked_dict
 
 
+def reverse_4bit_compress_format(weight):
+    out_1 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
+    out_2 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
+    out_1 = (weight & 0xF0) >> 4
+    out_2 = (weight & 0xF) << 4
+    out = out_1 | out_2
+    return out
+
+
 def enable_ipex_fusion(linear, x):
     from bitsandbytes.backends.cpu_xpu_common import (
         _ipex_cpu_version_prereq,
         _ipex_xpu_version_prereq,
+        dequant_8bit,
         ipex_cpu,
         ipex_xpu,
     )
 
+    quant_state = linear.weight.quant_state
+
+    if quant_state.nested:
+        quant_state.absmax = dequant_8bit(quant_state.absmax, quant_state.offset, quant_state.state2)
+        quant_state.nested = False
+        delattr(quant_state, "state2")
+
     if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
-        quant_state = linear.weight.quant_state
+        converted_weight = reverse_4bit_compress_format(linear.weight.data)
         new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
-            linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+            converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
             "nf4",
             quant_state.shape,  # weight shape
             quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
222239
2,
223240
)
224241
elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
225-
quant_state = linear.weight.quant_state
226-
new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
227-
242+
converted_weight = reverse_4bit_compress_format(linear.weight.data)
243+
new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
228244
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
229245
new_zeros = None
230246
compensation = None
247+
else:
248+
raise ValueError(
249+
"Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.5"
250+
)
251+
231252
linear.weight.data = new_weight.data
232253
linear.weight.quant_state.ipex = True
233254
linear.weight.quant_state.new_scales = new_scales
