Skip to content

Commit d180d8e

Browse files
authored
fix double compress 8bit precision (#1582)
Signed-off-by: jiqing-feng <[email protected]>
1 parent d3658c5 commit d180d8e

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,9 @@ def quantize_4bit_impl(
369369
out_uint8[abs_scaled_A > key] = val
370370
out_uint8 += sign.to(torch.uint8) * 8
371371
elif quant_type == "int8":
372-
for i in range(len(INT8_QUANT_TABLE)):
373-
out_uint8[scaled_A > INT8_QUANT_TABLE[i]] = i
372+
map = torch.tensor(INT8_QUANT_TABLE, device=scaled_A.device)
373+
diff = torch.abs(scaled_A.unsqueeze(-1) - map)
374+
out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device)
374375

375376
if quant_type == "int8":
376377
out = out_uint8

0 commit comments

Comments (0)