
Commit 9e589a2

Cpu C++ kernel (#1789)
* add template to support more dtypes
* update cmake list
* fix typo
* fix compile cpu
* make different dtype works
* use bf16 on CPU
* fix state2 dtype
* remove torch
* rm torch
* enable float to bf16
* rm dequantizeBlockwise4bitCpu
* fix check
* enable dequant 4bit kernel
* fix typo
* fix typo
* fix dequantize
* fix
* fix
* test
* fix
* fix
* fix
* fix
* fix
* change input param
* fix typo
* fix input param
* spliut 8bit and 4bit
* fix typo
* fix typo
* fix input params
* fix input params
* fix
* fix typo
* enable dequant4bit
* fix
* fix
* fix reverse
* fix dequant 4bit fallback path
* fix fp4 dequant
* rm _Float16
* fix cmake check
* fix lint
* fix datatypr
* fix include
* fix typo
* fix include
* add runtime check for avx512
* enable windows cpu build
* fix format
* Fix some tests
* Use larger shape for test

---------

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>
1 parent 63f538a commit 9e589a2
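
This commit adds templated C++ CPU kernels for blockwise 8-bit and 4-bit dequantization (fp32/bf16/fp16 outputs, NF4 and FP4), builds them with OpenMP and AVX2/AVX-512 flags when available, and routes the existing Python ops to them. A minimal sketch of the code path it accelerates, assuming the public bitsandbytes.functional API; shapes and blocksize are illustrative:

import torch
import bitsandbytes.functional as F

# Quantize an fp32 weight to NF4 on the CPU, then dequantize it back.
# The dequantize step is what now dispatches to the C++ CPU kernel.
w = torch.randn(1024, 1024, dtype=torch.float32)            # CPU tensor
packed, quant_state = F.quantize_4bit(w, blocksize=64, quant_type="nf4")
w_restored = F.dequantize_4bit(packed, quant_state)         # fp32 output on CPU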

File tree

14 files changed: +649, -44 lines changed


CMakeLists.txt

Lines changed: 36 additions & 0 deletions
@@ -78,9 +78,17 @@ else()
     set(BUILD_HIP OFF)
     set(BUILD_MPS OFF)
     set(BUILD_XPU OFF)
+    set(BUILD_CPU ON)
 endif()


+if (BUILD_CPU)
+    set(CMAKE_CXX_STANDARD 17)
+    set(CMAKE_CXX_STANDARD_REQUIRED ON)
+    string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" HOST_ARCH)
+    find_package(OpenMP)
+endif()
+
 if(BUILD_CUDA)
     # NVCC normally will only work with MSVC up to 1939. VS2022 17.10+ starts using versions 1940+.
     # Workaround: use --allow-unsupported-compiler

@@ -262,6 +270,34 @@ add_library(bitsandbytes SHARED ${SRC_FILES})
 target_compile_features(bitsandbytes PUBLIC cxx_std_17)
 target_include_directories(bitsandbytes PUBLIC csrc include)

+if (BUILD_CPU)
+    if (OpenMP_CXX_FOUND)
+        target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
+        add_definitions(-DHAS_OPENMP)
+    endif()
+
+    if ((HOST_ARCH MATCHES "x86_64|amd64") AND (NOT MSVC))
+        include(CheckCXXCompilerFlag)
+        check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG)
+        check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG)
+        if (HAS_AVX512F_FLAG)
+            target_compile_options(bitsandbytes PRIVATE -mavx512f)
+        endif()
+        if (HAS_AVX512BF16_FLAG)
+            target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
+        endif()
+        target_compile_options(
+            bitsandbytes PRIVATE
+            -mprefer-vector-width=256
+            -mfma
+            -mavx2
+            -mlzcnt
+            -mbmi
+            -mbmi2
+        )
+    endif()
+endif()
+

 if(BUILD_CUDA)
     target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

bitsandbytes/autograd/_functions.py

Lines changed: 3 additions & 0 deletions
@@ -374,6 +374,9 @@ def matmul_4bit(
     bias: Optional[torch.Tensor] = None,
 ):
     assert quant_state is not None
+    # Change dtype to bfloat16 on CPU
+    if A.device.type == "cpu":
+        quant_state.dtype = A.dtype

     if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
         if A.shape[-1] % quant_state.blocksize != 0:
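
The three added lines make the quantization state follow the activation dtype on CPU, so the weight is dequantized directly in the activation's precision (e.g. bfloat16) rather than its original dtype. A rough sketch of the affected call path; passing the transposed packed weight mirrors what Linear4bit does, but treat the exact layout as an assumption:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

w = torch.randn(256, 512, dtype=torch.float32)         # [out, in] weight on CPU
w_packed, qs = F.quantize_4bit(w, blocksize=64, quant_type="nf4")

x = torch.randn(4, 512, dtype=torch.bfloat16)          # bf16 activation on CPU
# With this change, matmul_4bit sets qs.dtype = x.dtype on CPU, so the weight
# is dequantized straight to bf16 before the matmul.
y = bnb.matmul_4bit(x, w_packed.t(), quant_state=qs)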

bitsandbytes/backends/cpu/ops.py

Lines changed: 121 additions & 3 deletions
@@ -1,5 +1,7 @@
+from collections.abc import Sequence
 import ctypes as ct
 import logging
+from math import prod

 import torch

@@ -76,10 +78,8 @@ def _(
     torch._check_is_size(blocksize)
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

-    # Only FP32 has c++ kernrl
+    out = torch.empty_like(A, dtype=dtype)
     if dtype == torch.float32:
-        out = torch.empty_like(A, dtype=dtype)
-
         lib.cdequantize_blockwise_cpu_fp32(
             get_ptr(code),
             get_ptr(A),

@@ -88,6 +88,24 @@ def _(
             ct.c_longlong(blocksize),
             ct.c_longlong(A.numel()),
         )
+    elif dtype == torch.bfloat16:
+        lib.cdequantize_blockwise_cpu_bf16(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_longlong(blocksize),
+            ct.c_longlong(A.numel()),
+        )
+    elif dtype == torch.float16:
+        lib.cdequantize_blockwise_cpu_fp16(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_longlong(blocksize),
+            ct.c_longlong(A.numel()),
+        )
     else:
         out = code[A.reshape(-1).int()]
         blocks = out.shape[-1] // blocksize

@@ -99,3 +117,103 @@ def _(
         out = out.reshape(A.shape)

     return out
+
+@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
+    torch._check(
+        dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
+    )
+
+    # Odd shape is not supported by this kernel; fallback to generic implementation
+    if shape[-1] % 2 != 0:
+        from ..default.ops import _dequantize_4bit_impl
+
+        return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
+
+    # Enable non uint8 dtype
+    if A.dtype != torch.uint8:
+        A = A.view(torch.uint8)
+
+    # TODO: support half precision absmax
+    if absmax.dtype != torch.float32:
+        absmax = absmax.float()
+
+    if len(shape) == 1:
+        shape = (1, shape[0])
+
+    m = prod(shape[:-1])
+    n = shape[-1]
+
+    A = A.reshape(m, n // 2)
+    out = torch.empty(shape, dtype=dtype, device=A.device)
+
+    if quant_type == "fp4":
+        if dtype == torch.float32:
+            lib.cdequantize_blockwise_cpu_fp4_fp32(
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_longlong(blocksize),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
+            )
+        elif dtype == torch.bfloat16:
+            lib.cdequantize_blockwise_cpu_fp4_bf16(
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_longlong(blocksize),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
+            )
+        elif dtype == torch.float16:
+            lib.cdequantize_blockwise_cpu_fp4_fp16(
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_longlong(blocksize),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
+            )
+    elif quant_type == "nf4":
+        if dtype == torch.float32:
+            lib.cdequantize_blockwise_cpu_nf4_fp32(
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_longlong(blocksize),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
+            )
+        elif dtype == torch.bfloat16:
+            lib.cdequantize_blockwise_cpu_nf4_bf16(
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_longlong(blocksize),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
+            )
+        elif dtype == torch.float16:
+            lib.cdequantize_blockwise_cpu_nf4_fp16(
+                get_ptr(A),
+                get_ptr(absmax),
+                get_ptr(out),
+                ct.c_longlong(blocksize),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
+            )
+    else:
+        raise ValueError
+
+    return out
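
With these kernels registered, the blockwise 8-bit path gains native bf16/fp16 branches and the 4-bit path gets a dedicated CPU kernel that views the packed tensor as (m, n // 2) and falls back to the generic implementation for odd last dimensions. A small sketch exercising the 8-bit path with a bfloat16 tensor on CPU (sizes and blocksize are illustrative):

import torch
import bitsandbytes.functional as F

# Blockwise 8-bit round trip on CPU. Because the tensor is bf16, dequantization
# should now take the cdequantize_blockwise_cpu_bf16 branch rather than the
# pure-PyTorch fallback.
x = torch.randn(8, 4096, dtype=torch.bfloat16)
q, state = F.quantize_blockwise(x, blocksize=256)
x_hat = F.dequantize_blockwise(q, state)
print(x_hat.dtype)  # torch.bfloat16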

bitsandbytes/backends/default/ops.py

Lines changed: 20 additions & 9 deletions
@@ -232,22 +232,14 @@ def _(
     return packed, absmax.float()


-@register_kernel("bitsandbytes::dequantize_4bit", "default")
-def _(
+def _dequantize_4bit_impl(
     A: torch.Tensor,
     absmax: torch.Tensor,
     blocksize: int,
     quant_type: str,
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check_is_size(blocksize)
-    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
-    torch._check(
-        dtype in [torch.bfloat16, torch.float16, torch.float32],
-        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
-    )
-
     # Enable non uint8 dtype
     if A.dtype != torch.uint8:
         A = A.view(torch.uint8)

@@ -283,6 +275,25 @@ def _(
     return out


+@register_kernel("bitsandbytes::dequantize_4bit", "default")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
+    torch._check(
+        dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
+    )
+
+    return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
+
+
 @register_kernel("bitsandbytes::gemv_4bit", "default")
 def _(
     A: torch.Tensor,
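
This refactor keeps the torch._check validation in the registered wrapper and leaves _dequantize_4bit_impl check-free, so other backends can delegate to it cheaply, which is exactly what the new CPU kernel's odd-shape fallback does. A minimal sketch of that reuse from another module; the absolute import path is the assumed equivalent of the relative ..default.ops import used in this commit:

from bitsandbytes.backends.default.ops import _dequantize_4bit_impl

def dequantize_4bit_fallback(A, absmax, blocksize, quant_type, shape, dtype):
    # Delegate to the shared pure-PyTorch implementation (no extra validation).
    return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)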

csrc/common.h

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,12 @@

 using namespace BinSearch;

+typedef enum DataType_t {
+    General8bit = 0,
+    FP4 = 1,
+    NF4 = 2,
+} DataType_t;
+
 struct quantize_block_args {
     BinAlgo<Scalar, float, Direct2>* bin_searcher;
     float* code;
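
The new DataType_t enum gives the C++ kernels a single tag for the three supported quantization formats. For reference, an illustrative Python-side mirror of those values (not code from this commit); "fp4" and "nf4" match the quant_type strings used by the ops layer, while General8bit covers the blockwise 8-bit path:

# Illustrative mapping only; mirrors the DataType_t values in csrc/common.h.
DATA_TYPE = {
    "general8bit": 0,  # General8bit
    "fp4": 1,          # FP4
    "nf4": 2,          # NF4
}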
