Commit 42e2d05

Fix some tests
1 parent 44e92a1 commit 42e2d05

3 files changed: +57 -23 lines changed

CMakeLists.txt

Lines changed: 9 additions & 1 deletion
@@ -286,7 +286,15 @@ if (BUILD_CPU)
         if (HAS_AVX512BF16_FLAG)
             target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
         endif()
-        target_compile_options(bitsandbytes PRIVATE -mprefer-vector-width=256)
+        target_compile_options(
+            bitsandbytes PRIVATE
+            -mprefer-vector-width=256
+            -mfma
+            -mavx2
+            -mlzcnt
+            -mbmi
+            -mbmi2
+        )
     endif()
 endif()
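Note: these flags compile the entire CPU target against AVX2, FMA, and the LZCNT/BMI1/BMI2 bit-manipulation extensions, so the resulting binary requires a CPU that exposes all of them. Below is a minimal sketch of a runtime guard, not part of this commit: it assumes torch's reported ISA capability is an acceptable proxy (in practice, x86 CPUs with AVX2 also ship FMA, LZCNT, and BMI1/2, but that pairing is a heuristic, not a guarantee).

```python
import torch

def can_use_native_cpu_kernels() -> bool:
    # torch reports the best ISA level its own CPU kernels can use; treating
    # AVX2/AVX512 as "new enough" for an -mavx2/-mfma/-mbmi* build is heuristic.
    return torch.backends.cpu.get_cpu_capability() in ("AVX2", "AVX512")
```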

bitsandbytes/backends/cpu/ops.py

Lines changed: 28 additions & 13 deletions
@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 import ctypes as ct
 import logging
+from math import prod

 import torch

@@ -132,6 +133,13 @@ def _(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
         lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
     )
+
+    # Odd shape is not supported by this kernel; fallback to generic implementation
+    if shape[-1] % 2 != 0:
+        from ..default.ops import _dequantize_4bit_impl
+
+        return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
+
     # Enable non uint8 dtype
     if A.dtype != torch.uint8:
         A = A.view(torch.uint8)
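The vectorized CPU kernel consumes two 4-bit codes per byte, so it requires an even last dimension; odd shapes now route to the generic implementation in bitsandbytes/backends/default/ops.py. A standalone sketch of the packing constraint (illustrative only, not library code; the nibble order here is arbitrary):

```python
import torch

codes = torch.tensor([1, 7, 3, 15, 9], dtype=torch.uint8)  # five 4-bit codes
padded = torch.cat([codes, codes.new_zeros(1)])            # odd length: must pad
packed = (padded[0::2] << 4) | padded[1::2]                # 3 bytes, 6 nibbles
high, low = packed >> 4, packed & 0xF                      # unpack both nibbles
unpacked = torch.stack([high, low], dim=1).flatten()[:5]   # drop the pad nibble
assert torch.equal(unpacked, codes)
```

A kernel that walks whole bytes has no clean stopping point at the dangling nibble, hence the fallback rather than padding inside the fast path.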
@@ -140,35 +148,42 @@ def _(
     if absmax.dtype != torch.float32:
         absmax = absmax.float()

-    A = A.reshape(shape[0], shape[1] // 2)
+    if len(shape) == 1:
+        shape = (1, shape[0])
+
+    m = prod(shape[:-1])
+    n = shape[-1]
+
+    A = A.reshape(m, n // 2)
     out = torch.empty(shape, dtype=dtype, device=A.device)
+
     if quant_type == "fp4":
         if dtype == torch.float32:
             lib.cdequantize_blockwise_cpu_fp4_fp32(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
             )
         elif dtype == torch.bfloat16:
             lib.cdequantize_blockwise_cpu_fp4_bf16(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
             )
         elif dtype == torch.float16:
             lib.cdequantize_blockwise_cpu_fp4_fp16(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
             )
     elif quant_type == "nf4":
         if dtype == torch.float32:
@@ -177,26 +192,26 @@ def _(
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
             )
         elif dtype == torch.bfloat16:
             lib.cdequantize_blockwise_cpu_nf4_bf16(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
             )
         elif dtype == torch.float16:
             lib.cdequantize_blockwise_cpu_nf4_fp16(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
             )
     else:
         raise ValueError
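The kernel previously hard-coded a 2-D view, A.reshape(shape[0], shape[1] // 2), which could not handle 1-D or higher-rank inputs. The change collapses all leading dimensions into m and keeps the last dimension as n; for contiguous row-major data this is equivalent, since the C kernels simply walk m rows of n elements. A standalone sketch of the normalization, mirroring the diff:

```python
from math import prod

def normalize_shape(shape: tuple[int, ...]) -> tuple[int, int]:
    if len(shape) == 1:                 # promote vectors to a single row
        shape = (1, shape[0])
    return prod(shape[:-1]), shape[-1]  # (row count m, row length n)

assert normalize_shape((128,)) == (1, 128)
assert normalize_shape((4, 8, 64)) == (32, 64)  # leading dims collapse into m
```

The output tensor is still allocated with the original shape, so only the C kernel's view of the buffer changes.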

bitsandbytes/backends/default/ops.py

Lines changed: 20 additions & 9 deletions
@@ -232,22 +232,14 @@ def _(
     return packed, absmax.float()


-@register_kernel("bitsandbytes::dequantize_4bit", "default")
-def _(
+def _dequantize_4bit_impl(
     A: torch.Tensor,
     absmax: torch.Tensor,
     blocksize: int,
     quant_type: str,
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
-    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
-    torch._check(
-        dtype in [torch.bfloat16, torch.float16, torch.float32],
-        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
-    )
-
     # Enable non uint8 dtype
     if A.dtype != torch.uint8:
         A = A.view(torch.uint8)
@@ -283,6 +275,25 @@ def _(
     return out


+@register_kernel("bitsandbytes::dequantize_4bit", "default")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
+    torch._check(
+        dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
+    )
+
+    return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
+
+
 @register_kernel("bitsandbytes::gemv_4bit", "default")
 def _(
     A: torch.Tensor,
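The body of the default dequantize_4bit kernel moves into a plain helper, _dequantize_4bit_impl, while the registered wrapper keeps the torch._check validation. This lets the CPU backend's odd-shape fallback call the helper directly, without re-running the checks or re-entering the dispatcher. A minimal sketch of the pattern with stand-in names (not the library's API):

```python
import torch

def _impl(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for _dequantize_4bit_impl: does the work, assumes valid inputs.
    return x.float()

def registered_entry(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the registered kernel: validate once, then delegate.
    torch._check(x.dtype == torch.uint8, lambda: f"expected uint8, got {x.dtype}")
    return _impl(x)

# A backend that has already validated its arguments (like the CPU odd-shape
# fallback above) can call _impl directly and skip the duplicate checks.
```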
