Skip to content

Commit 47589cd

Browse files
committed
support double quant on intel cpu and xpu
Signed-off-by: jiqing-feng <[email protected]>
1 parent 7e6f865 commit 47589cd

File tree

3 files changed

+68
-10
lines changed

3 files changed

+68
-10
lines changed

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from bitsandbytes.functional import (
88
QuantState,
99
get_4bit_type,
10+
create_dynamic_map,
1011
)
1112

1213
try:
@@ -279,8 +280,9 @@ def mm_dequant_impl(
279280
0.8333333: 3, # 0b0011
280281
}
281282

283+
INT8_QUANT_TABLE = create_dynamic_map().tolist()
284+
282285

283-
@_maybe_torch_compile
284286
def quantize_4bit_impl(
285287
A: Tensor,
286288
absmax: Tensor = None,
@@ -314,7 +316,7 @@ def quantize_4bit_impl(
314316
tuple(torch.Tensor, torch.Size, torch.dtype, int):
315317
The quantization state to undo the quantization.
316318
"""
317-
if quant_type not in ["nf4", "fp4"]:
319+
if quant_type not in ["nf4", "fp4", "int8"]:
318320
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
319321
if quant_type == "fp4":
320322
warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.")
@@ -355,14 +357,35 @@ def quantize_4bit_impl(
355357
for key, val in FP4_QUANT_TABLE.items():
356358
out_uint8[abs_scaled_A > key] = val
357359
out_uint8 += sign.to(torch.uint8) * 8
358-
if out_uint8.size(-1) % 2:
359-
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
360-
out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
360+
elif quant_type == "int8":
361+
for i in range(len(INT8_QUANT_TABLE)):
362+
out_uint8[scaled_A > INT8_QUANT_TABLE[i]] = i
361363

362-
code = get_4bit_type(quant_type, device=A.device)
364+
if quant_type != "int8":
365+
if out_uint8.size(-1) % 2:
366+
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
367+
out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
368+
369+
code = get_4bit_type(quant_type, device=A.device)
370+
else:
371+
out = out_uint8
372+
code = torch.Tensor(INT8_QUANT_TABLE, device=A.device)
363373

364374
if compress_statistics:
365-
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
375+
offset = absmax.mean()
376+
absmax -= offset
377+
qabsmax, state2 = quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
378+
del absmax
379+
state = QuantState(
380+
absmax=qabsmax,
381+
shape=input_shape,
382+
dtype=A.dtype,
383+
blocksize=blocksize,
384+
code=code,
385+
quant_type=quant_type,
386+
offset=offset,
387+
state2=state2,
388+
)
366389
else:
367390
state = QuantState(
368391
absmax=absmax,
@@ -376,6 +399,14 @@ def quantize_4bit_impl(
376399
return out.unsqueeze(0), state
377400

378401

402+
def dequant_8bit(A, offset, quant_state):
403+
assert A.dtype == torch.uint8
404+
absmax = quant_state.code[A.reshape(-1).int()]
405+
absmax += offset
406+
absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)
407+
return absmax
408+
409+
379410
@_maybe_torch_compile
380411
def dequantize_4bit_impl(
381412
A: Tensor,
@@ -438,7 +469,7 @@ def dequantize_4bit_impl(
438469
)
439470

440471
if quant_state.nested:
441-
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
472+
absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)
442473

443474
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
444475
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)

bitsandbytes/functional.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,27 @@ def quantize_blockwise(
728728
else:
729729
quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)
730730

731+
732+
n = A.numel()
733+
blocks = n // blocksize
734+
blocks += 1 if n % blocksize > 0 else 0
735+
rem = n % blocksize
736+
has_rem = rem > 0
737+
# Scale tensor to [-1, 1]
738+
A_reshaped = A.reshape(n)
739+
A_com = A_reshaped[: n - rem]
740+
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
741+
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
742+
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
743+
scaled_A = scaled_A.reshape(-1)
744+
if has_rem:
745+
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
746+
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
747+
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
748+
B = torch.empty(A.shape, dtype=torch.uint8, device=A.device)
749+
for i in range(len(code)):
750+
B[scaled_A > code[i]] = i
751+
731752
return out, quant_state
732753

733754

bitsandbytes/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,17 @@ def enable_ipex_fusion(linear, x):
206206
_ipex_xpu_version_prereq,
207207
ipex_cpu,
208208
ipex_xpu,
209+
dequant_8bit,
209210
)
210211

212+
quant_state = linear.weight.quant_state
213+
214+
if quant_state.nested:
215+
quant_state.absmax = dequant_8bit(quant_state.absmax, quant_state.offset, quant_state.state2)
216+
quant_state.nested = False
217+
delattr(quant_state, "state2")
218+
211219
if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
212-
quant_state = linear.weight.quant_state
213220
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
214221
linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
215222
"nf4",
@@ -222,7 +229,6 @@ def enable_ipex_fusion(linear, x):
222229
2,
223230
)
224231
elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
225-
quant_state = linear.weight.quant_state
226232
new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
227233

228234
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)

0 commit comments

Comments
 (0)