
Commit 7f2d8a8

add cpu fp4 and rem
Signed-off-by: jiqing-feng <[email protected]>
1 parent f5c0b01 commit 7f2d8a8

2 files changed: +86 -44 lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 84 additions & 42 deletions
@@ -112,31 +112,62 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
     dtype=torch.float32,
     device="cpu",
 )
+_FP4_QUANT_TABLE = torch.tensor(
+    [
+        0.0000,
+        0.0052,
+        0.6667,
+        1.0000,
+        0.3333,
+        0.5000,
+        0.1667,
+        0.2500,
+        0.0000,
+        -0.0052,
+        -0.6667,
+        -1.0000,
+        -0.3333,
+        -0.5000,
+        -0.1667,
+        -0.2500,
+    ],
+    dtype=torch.float32,
+    device="cpu",
+)
+CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}
 
 
 @register_kernel("bitsandbytes::quantize_4bit", "cpu")
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
-    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
+    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}")
     torch._check(
         A.dtype in [torch.bfloat16, torch.float16, torch.float32],
         lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
     )
 
     n = A.numel()
-
-    # TODO: Support when weight matrix is not divisible by blocksize
-    torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
-
-    # Divide into blocks and normalize
-    blocks = A.reshape(-1, blocksize)
-    absmax = blocks.abs().max(dim=1).values.float()
-    scaled = blocks / absmax.unsqueeze(-1)
-
+    blocks = n // blocksize
+    blocks += 1 if n % blocksize > 0 else 0
+    rem = n % blocksize
+    has_rem = rem > 0
+
+    # Scale tensor to [-1, 1]
+    absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
+    A_reshaped = A.reshape(n)
+    A_com_reshaped = A_reshaped[: n - rem].reshape(n // blocksize, blocksize)
+    absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+    scaled = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+    scaled = scaled.reshape(-1)
+    if has_rem:
+        absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+        scaled_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+        scaled = torch.cat([scaled, scaled_rem], dim=0)
     # Quantize with the lookup table
-    quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
+    quant_table = CODE[quant_type]
+    quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - quant_table), dim=-1, keepdim=True).to(torch.uint8)
 
     # Pack two quantized values per byte
     packed = quantized[::2] << 4 | quantized[1::2]
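The quantize path above no longer requires the element count to be divisible by `blocksize`: the divisible prefix is reshaped into full blocks, and the trailing `rem` elements form one extra block with its own absmax. A minimal standalone sketch of that blocking (pure PyTorch; the helper name and toy code table are illustrative, not the library's API):

```python
import torch

# Illustrative only: mirrors the remainder-aware blocking above with a toy code table.
_TOY_CODE = torch.tensor([-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0])

def blockwise_scale(A: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Scale A into [-1, 1] per block, allowing one trailing partial block."""
    n = A.numel()
    rem = n % blocksize
    has_rem = rem > 0
    blocks = n // blocksize + (1 if has_rem else 0)

    absmax = torch.zeros(blocks, dtype=A.dtype)
    flat = A.reshape(n)

    # Full blocks: per-row max of |x|, then normalize.
    full = flat[: n - rem].reshape(-1, blocksize)
    absmax[: blocks - has_rem] = full.abs().max(dim=-1).values
    scaled = torch.clamp(full / absmax[: blocks - has_rem].unsqueeze(-1), -1, 1).reshape(-1)

    # Remainder block: scaled with its own absmax.
    if has_rem:
        absmax[-1] = flat[n - rem :].abs().max()
        scaled = torch.cat([scaled, torch.clamp(flat[n - rem :] / absmax[-1], -1, 1)])
    return scaled, absmax

A = torch.randn(10)  # blocksize 4 -> 2 full blocks + a remainder block of 2
scaled, absmax = blockwise_scale(A, blocksize=4)
idx = torch.argmin((scaled.view(-1, 1) - _TOY_CODE).abs(), dim=-1)  # nearest-code index per value
print(absmax.shape, idx.shape)  # torch.Size([3]) torch.Size([10])
```

The existing packing step (`quantized[::2] << 4 | quantized[1::2]`) then operates on this flat index tensor exactly as before.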
@@ -157,32 +188,47 @@ def _(
     dtype: torch.dtype,
 ) -> torch.Tensor:
     torch._check_is_size(blocksize)
-    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
+    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}")
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
         lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
     )
-    torch._check(
-        A.dtype == torch.uint8,
-        lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
-    )
-
-    A = A.view(-1, 1)
-
-    # Grab upper and lower nibbles. Using int64 for indexing in the LUT.
-    upper = (A >> 4).to(torch.int64)
-    lower = (A & 0x0F).to(torch.int64)
-
-    # Expand to blocks
-    blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
 
-    # Dequantize
-    blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
+    # Enable non uint8 dtype
+    device = A.device
+    if A.dtype != torch.uint8:
+        bytes_value = A.cpu().numpy().tobytes()
+        A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device)
+
+    A = A.reshape(-1)
+    # Map nf4 to [-1, 1]
+    out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
+    n = out_dq.numel()
+    out_dq[1::2] = A & 0xF
+    out_dq[::2] = A >> 4
+    # code is fp32, cast to dtype to avoid the mismatch issue
+    code = CODE[quant_type].to(dtype)
+    out_dq = code[out_dq]
+
+    # Apply scales
+    if out_dq.numel() != n:
+        assert out_dq.numel() == n + 1
+        out_dq = torch.narrow(out_dq, 0, 0, n)
+    blocks = n // blocksize
+    blocks += 1 if n % blocksize > 0 else 0
+    rem = n % blocksize
+    has_rem = rem > 0
+
+    out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1)
+    if has_rem:
+        out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1)
+        out[n - rem :] = out_dq[n - rem :] * absmax[-1]
+    else:
+        out = out_dq.view(-1, blocksize) * absmax.view(-1, 1)
+
+    out = out.reshape(-1, *shape[1:]).to(dtype)
 
-    # Reshape to original shape
-    blocks = blocks.reshape(-1, *shape[1:])
-
-    return blocks.to(dtype)
+    return out
 
 
 @register_kernel("bitsandbytes::gemv_4bit", "cpu")
@@ -194,17 +240,13 @@ def _(
     code: torch.Tensor,
     blocksize: int,
 ) -> torch.Tensor:
-    # TODO: We need to determine whether `code` is NF4, FP4, or other.
-    # Right now we assume NF4, as this is the only one supported on CPU.
-
-    B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
-        B,
-        absmax,
-        blocksize,
-        "nf4",
-        shape=shapeB,
-        dtype=A.dtype,
-    )
+    # Applied from dequantize_4bit
+    B = B.view(-1, 1)
+    upper = (B >> 4).to(torch.int64)
+    lower = (B & 0x0F).to(torch.int64)
+    blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
+    B_dq = code[blocks] * absmax[:, None]
+    B_dq = B_dq.reshape(-1, *shapeB[1:]).to(A.dtype)
 
     # User called gemv with B.t(), so we need to transpose it back.
     # if B.shape[0] == 1:
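The gemv kernel now dequantizes `B` with whatever `code` table the caller supplies instead of hard-coding NF4. The reason this matters: the same packed bytes decode to different values under different tables. A toy comparison (the NF4-like table is a stand-in; only the FP4 values are copied from the table added above):

```python
import torch

# Same packed byte, two different code tables, two different decoded values --
# which is why the kernel must use the caller's `code` rather than assuming NF4.
nf4_like = torch.linspace(-1.0, 1.0, 16)  # stand-in; not the real NF4 table
fp4_table = torch.tensor([0.0, 0.0052, 0.6667, 1.0, 0.3333, 0.5, 0.1667, 0.25,
                          0.0, -0.0052, -0.6667, -1.0, -0.3333, -0.5, -0.1667, -0.25])

packed = torch.tensor([0x23], dtype=torch.uint8)  # indices 2 and 3
idx = torch.stack([packed >> 4, packed & 0xF]).to(torch.int64).reshape(-1)
print(nf4_like[idx])   # tensor([-0.7333, -0.6000])
print(fp4_table[idx])  # tensor([0.6667, 1.0000])
```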

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 import torch.nn.functional as F
 
 import bitsandbytes as bnb
-from bitsandbytes.functional import QuantState, enable_ipex_fusion
+from bitsandbytes.functional import QuantState, enable_ipex_fusion, ipex_cpu, ipex_xpu
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import (
     INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
@@ -502,7 +502,7 @@ def set_ipex_linear(self, x: torch.Tensor):
 
     def forward(self, x: torch.Tensor):
         # Check if ipex fusion can be used
-        if not self.ipex_linear_is_set:
+        if not self.ipex_linear_is_set and (ipex_cpu or ipex_xpu):
             self.set_ipex_linear(x)
             self.ipex_linear_is_set = True
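The forward path now attempts the ipex fusion only when `ipex_cpu` or `ipex_xpu` is available. A rough standalone sketch of that guard pattern, assuming the flags are truthy only when intel_extension_for_pytorch is importable (their actual definition lives in bitsandbytes.functional and is not shown in this diff):

```python
# Standalone illustration of the guard, not the library code itself.
try:
    import intel_extension_for_pytorch  # noqa: F401  (optional dependency)
    ipex_available = True
except ImportError:
    ipex_available = False

def maybe_set_ipex_linear(layer, x):
    # Only touch the ipex path when the extension is present; otherwise the
    # layer keeps using the plain PyTorch forward.
    if not getattr(layer, "ipex_linear_is_set", False) and ipex_available:
        layer.set_ipex_linear(x)  # hook shown in the hunk header above
        layer.ipex_linear_is_set = True
```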