Skip to content

Commit cafe58f

Browse files
committed
fix format
Signed-off-by: jiqing-feng <[email protected]>
1 parent 47589cd commit cafe58f

File tree

2 files changed

+5
-27
lines changed

2 files changed

+5
-27
lines changed

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from bitsandbytes.functional import (
88
QuantState,
9-
get_4bit_type,
109
create_dynamic_map,
10+
get_4bit_type,
1111
)
1212

1313
try:
@@ -361,15 +361,14 @@ def quantize_4bit_impl(
361361
for i in range(len(INT8_QUANT_TABLE)):
362362
out_uint8[scaled_A > INT8_QUANT_TABLE[i]] = i
363363

364-
if quant_type != "int8":
364+
if quant_type == "int8":
365+
out = out_uint8
366+
code = torch.Tensor(INT8_QUANT_TABLE, device=A.device)
367+
else:
365368
if out_uint8.size(-1) % 2:
366369
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
367370
out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
368-
369371
code = get_4bit_type(quant_type, device=A.device)
370-
else:
371-
out = out_uint8
372-
code = torch.Tensor(INT8_QUANT_TABLE, device=A.device)
373372

374373
if compress_statistics:
375374
offset = absmax.mean()

bitsandbytes/functional.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -728,27 +728,6 @@ def quantize_blockwise(
728728
else:
729729
quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)
730730

731-
732-
n = A.numel()
733-
blocks = n // blocksize
734-
blocks += 1 if n % blocksize > 0 else 0
735-
rem = n % blocksize
736-
has_rem = rem > 0
737-
# Scale tensor to [-1, 1]
738-
A_reshaped = A.reshape(n)
739-
A_com = A_reshaped[: n - rem]
740-
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
741-
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
742-
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
743-
scaled_A = scaled_A.reshape(-1)
744-
if has_rem:
745-
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
746-
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
747-
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
748-
B = torch.empty(A.shape, dtype=torch.uint8, device=A.device)
749-
for i in range(len(code)):
750-
B[scaled_A > code[i]] = i
751-
752731
return out, quant_state
753732

754733

0 commit comments

Comments (0)