Skip to content

Commit 0f63009

Browse files
Merge branch 'main' into upstream_main_npu_enabled
2 parents 19ce67d + 3f9f6f3 commit 0f63009

30 files changed

+811
-136
lines changed

.github/workflows/python-package.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ jobs:
171171
retention-days: 7
172172

173173
build-wheels:
174+
env:
175+
# Skip rebuilding the CPU library when building the wheels.
176+
BNB_SKIP_CMAKE: 1
174177
needs:
175178
- build-cpu
176179
- build-cuda

.github/workflows/tests.yml

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ concurrency:
1010
group: ${{ github.workflow }}-${{ github.ref }}
1111
cancel-in-progress: true
1212

13+
env:
14+
# Skip rebuilding the CPU library when installing the wheels.
15+
# We build the libraries in separate jobs and upload as artifacts.
16+
BNB_SKIP_CMAKE: 1
17+
1318
jobs:
1419

1520
build-cpu:
@@ -146,7 +151,7 @@ jobs:
146151
- name: Install dependencies
147152
run: |
148153
pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
149-
pip install -e ".[test]"
154+
pip install -e ".[test]" -v
150155
pip install pytest-cov
151156
152157
# We need to downgrade to numpy<2 for torch<2.4.1 compatibility on Windows
@@ -188,7 +193,7 @@ jobs:
188193
- name: Install dependencies
189194
run: |
190195
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu
191-
pip install -e ".[test]"
196+
pip install -e ".[test]" -v
192197
pip install pytest-cov
193198
194199
- name: Show installed packages
@@ -263,7 +268,7 @@ jobs:
263268

264269
- name: Install dependencies
265270
run: |
266-
pip install -e ".[test]"
271+
pip install -e ".[test]" -v
267272
pip install pytest-cov
268273
269274
- name: Show installed packages
@@ -321,7 +326,7 @@ jobs:
321326

322327
- name: Install dependencies
323328
run: |
324-
pip install -e ".[test]"
329+
pip install -e ".[test]" -v
325330
pip install pytest-cov
326331
327332
- name: Show installed packages
@@ -438,7 +443,7 @@ jobs:
438443
- name: Install dependencies
439444
run: |
440445
pip install --pre torch~=${{ matrix.torch_version }}.dev0 --index-url ${{ matrix.pypi_index }}
441-
pip install -e ".[test]"
446+
pip install -e ".[test]" -v
442447
pip install pytest-cov
443448
- name: Show installed packages
444449
run: pip list

CMakeLists.txt

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,17 @@ else()
9393
set(BUILD_MPS OFF)
9494
set(BUILD_XPU OFF)
9595
set(BUILD_NPU OFF)
96+
set(BUILD_CPU ON)
9697
endif()
9798

9899

100+
if (BUILD_CPU)
101+
set(CMAKE_CXX_STANDARD 17)
102+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
103+
string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" HOST_ARCH)
104+
find_package(OpenMP)
105+
endif()
106+
99107
if(BUILD_CUDA)
100108
# NVCC normally will only work with MSVC up to 1939. VS2022 17.10+ starts using versions 1940+.
101109
# Workaround: use --allow-unsupported-compiler
@@ -311,6 +319,34 @@ add_library(bitsandbytes SHARED ${SRC_FILES})
311319
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
312320
target_include_directories(bitsandbytes PUBLIC csrc include)
313321

322+
if (BUILD_CPU)
323+
if (OpenMP_CXX_FOUND)
324+
target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
325+
add_definitions(-DHAS_OPENMP)
326+
endif()
327+
328+
if ((HOST_ARCH MATCHES "x86_64|amd64") AND (NOT MSVC))
329+
include(CheckCXXCompilerFlag)
330+
check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG)
331+
check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG)
332+
if (HAS_AVX512F_FLAG)
333+
target_compile_options(bitsandbytes PRIVATE -mavx512f)
334+
endif()
335+
if (HAS_AVX512BF16_FLAG)
336+
target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
337+
endif()
338+
target_compile_options(
339+
bitsandbytes PRIVATE
340+
-mprefer-vector-width=256
341+
-mfma
342+
-mavx2
343+
-mlzcnt
344+
-mbmi
345+
-mbmi2
346+
)
347+
endif()
348+
endif()
349+
314350

315351
if(BUILD_CUDA)
316352
target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

bitsandbytes/autograd/_functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,9 @@ def matmul_4bit(
374374
bias: Optional[torch.Tensor] = None,
375375
):
376376
assert quant_state is not None
377+
# On CPU, make quant_state.dtype follow the dtype of the input A
378+
if A.device.type == "cpu":
379+
quant_state.dtype = A.dtype
377380

378381
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu" and A.device.type != "npu":
379382
if A.shape[-1] % quant_state.blocksize != 0:

bitsandbytes/backends/cpu/ops.py

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from collections.abc import Sequence
12
import ctypes as ct
23
import logging
4+
from math import prod
35

46
import torch
57

@@ -76,10 +78,8 @@ def _(
7678
torch._check_is_size(blocksize)
7779
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
7880

79-
# Only FP32 has a C++ kernel
81+
out = torch.empty_like(A, dtype=dtype)
8082
if dtype == torch.float32:
81-
out = torch.empty_like(A, dtype=dtype)
82-
8383
lib.cdequantize_blockwise_cpu_fp32(
8484
get_ptr(code),
8585
get_ptr(A),
@@ -88,6 +88,24 @@ def _(
8888
ct.c_longlong(blocksize),
8989
ct.c_longlong(A.numel()),
9090
)
91+
elif dtype == torch.bfloat16:
92+
lib.cdequantize_blockwise_cpu_bf16(
93+
get_ptr(code),
94+
get_ptr(A),
95+
get_ptr(absmax),
96+
get_ptr(out),
97+
ct.c_longlong(blocksize),
98+
ct.c_longlong(A.numel()),
99+
)
100+
elif dtype == torch.float16:
101+
lib.cdequantize_blockwise_cpu_fp16(
102+
get_ptr(code),
103+
get_ptr(A),
104+
get_ptr(absmax),
105+
get_ptr(out),
106+
ct.c_longlong(blocksize),
107+
ct.c_longlong(A.numel()),
108+
)
91109
else:
92110
out = code[A.reshape(-1).int()]
93111
blocks = out.shape[-1] // blocksize
@@ -99,3 +117,103 @@ def _(
99117
out = out.reshape(A.shape)
100118

101119
return out
120+
121+
@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
122+
def _(
123+
A: torch.Tensor,
124+
absmax: torch.Tensor,
125+
blocksize: int,
126+
quant_type: str,
127+
shape: Sequence[int],
128+
dtype: torch.dtype,
129+
) -> torch.Tensor:
130+
torch._check_is_size(blocksize)
131+
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
132+
torch._check(
133+
dtype in [torch.bfloat16, torch.float16, torch.float32],
134+
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
135+
)
136+
137+
# An odd last dimension is not supported by this kernel; fall back to the generic implementation
138+
if shape[-1] % 2 != 0:
139+
from ..default.ops import _dequantize_4bit_impl
140+
141+
return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
142+
143+
# Support non-uint8 storage dtypes by viewing A as uint8
144+
if A.dtype != torch.uint8:
145+
A = A.view(torch.uint8)
146+
147+
# TODO: support half precision absmax
148+
if absmax.dtype != torch.float32:
149+
absmax = absmax.float()
150+
151+
if len(shape) == 1:
152+
shape = (1, shape[0])
153+
154+
m = prod(shape[:-1])
155+
n = shape[-1]
156+
157+
A = A.reshape(m, n // 2)
158+
out = torch.empty(shape, dtype=dtype, device=A.device)
159+
160+
if quant_type == "fp4":
161+
if dtype == torch.float32:
162+
lib.cdequantize_blockwise_cpu_fp4_fp32(
163+
get_ptr(A),
164+
get_ptr(absmax),
165+
get_ptr(out),
166+
ct.c_longlong(blocksize),
167+
ct.c_longlong(m),
168+
ct.c_longlong(n),
169+
)
170+
elif dtype == torch.bfloat16:
171+
lib.cdequantize_blockwise_cpu_fp4_bf16(
172+
get_ptr(A),
173+
get_ptr(absmax),
174+
get_ptr(out),
175+
ct.c_longlong(blocksize),
176+
ct.c_longlong(m),
177+
ct.c_longlong(n),
178+
)
179+
elif dtype == torch.float16:
180+
lib.cdequantize_blockwise_cpu_fp4_fp16(
181+
get_ptr(A),
182+
get_ptr(absmax),
183+
get_ptr(out),
184+
ct.c_longlong(blocksize),
185+
ct.c_longlong(m),
186+
ct.c_longlong(n),
187+
)
188+
elif quant_type == "nf4":
189+
if dtype == torch.float32:
190+
lib.cdequantize_blockwise_cpu_nf4_fp32(
191+
get_ptr(A),
192+
get_ptr(absmax),
193+
get_ptr(out),
194+
ct.c_longlong(blocksize),
195+
ct.c_longlong(m),
196+
ct.c_longlong(n),
197+
)
198+
elif dtype == torch.bfloat16:
199+
lib.cdequantize_blockwise_cpu_nf4_bf16(
200+
get_ptr(A),
201+
get_ptr(absmax),
202+
get_ptr(out),
203+
ct.c_longlong(blocksize),
204+
ct.c_longlong(m),
205+
ct.c_longlong(n),
206+
)
207+
elif dtype == torch.float16:
208+
lib.cdequantize_blockwise_cpu_nf4_fp16(
209+
get_ptr(A),
210+
get_ptr(absmax),
211+
get_ptr(out),
212+
ct.c_longlong(blocksize),
213+
ct.c_longlong(m),
214+
ct.c_longlong(n),
215+
)
216+
else:
217+
raise ValueError
218+
219+
return out

bitsandbytes/backends/cuda/ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
99

1010
from ..._ops import register_kernel
11-
from ...cextension import HIP_ENVIRONMENT, lib
11+
from ...cextension import ROCM_WARP_SIZE_64, lib
1212

1313

1414
@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -211,7 +211,7 @@ def _get_col_absmax(
211211
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
212212
torch._check_is_size(blocksize)
213213

214-
if HIP_ENVIRONMENT:
214+
if ROCM_WARP_SIZE_64:
215215
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
216216
else:
217217
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -269,7 +269,7 @@ def _(
269269
def _dequantize_blockwise_impl(
270270
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
271271
) -> None:
272-
if HIP_ENVIRONMENT:
272+
if ROCM_WARP_SIZE_64:
273273
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
274274
else:
275275
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -303,7 +303,7 @@ def _dequantize_blockwise_impl(
303303
def _(
304304
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
305305
) -> tuple[torch.Tensor, torch.Tensor]:
306-
if HIP_ENVIRONMENT:
306+
if ROCM_WARP_SIZE_64:
307307
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
308308
else:
309309
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -385,7 +385,7 @@ def _dequantize_4bit_impl(
385385
dtype: torch.dtype,
386386
out: torch.Tensor,
387387
) -> None:
388-
if HIP_ENVIRONMENT:
388+
if ROCM_WARP_SIZE_64:
389389
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
390390
else:
391391
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])

bitsandbytes/backends/default/ops.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -232,22 +232,14 @@ def _(
232232
return packed, absmax.float()
233233

234234

235-
@register_kernel("bitsandbytes::dequantize_4bit", "default")
236-
def _(
235+
def _dequantize_4bit_impl(
237236
A: torch.Tensor,
238237
absmax: torch.Tensor,
239238
blocksize: int,
240239
quant_type: str,
241240
shape: Sequence[int],
242241
dtype: torch.dtype,
243242
) -> torch.Tensor:
244-
torch._check_is_size(blocksize)
245-
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
246-
torch._check(
247-
dtype in [torch.bfloat16, torch.float16, torch.float32],
248-
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
249-
)
250-
251243
# Support non-uint8 storage dtypes by viewing A as uint8
252244
if A.dtype != torch.uint8:
253245
A = A.view(torch.uint8)
@@ -283,6 +275,25 @@ def _(
283275
return out
284276

285277

278+
@register_kernel("bitsandbytes::dequantize_4bit", "default")
279+
def _(
280+
A: torch.Tensor,
281+
absmax: torch.Tensor,
282+
blocksize: int,
283+
quant_type: str,
284+
shape: Sequence[int],
285+
dtype: torch.dtype,
286+
) -> torch.Tensor:
287+
torch._check_is_size(blocksize)
288+
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
289+
torch._check(
290+
dtype in [torch.bfloat16, torch.float16, torch.float32],
291+
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
292+
)
293+
294+
return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
295+
296+
286297
@register_kernel("bitsandbytes::gemv_4bit", "default")
287298
def _(
288299
A: torch.Tensor,

bitsandbytes/cextension.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
import torch
1111

1212
from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
13-
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch
13+
from bitsandbytes.cuda_specs import (
14+
CUDASpecs,
15+
get_cuda_specs,
16+
get_cuda_version_tuple,
17+
get_rocm_gpu_arch,
18+
get_rocm_warpsize,
19+
)
1420

1521
logger = logging.getLogger(__name__)
1622

@@ -307,6 +313,7 @@ def get_native_library() -> BNBNativeLibrary:
307313

308314

309315
ROCM_GPU_ARCH = get_rocm_gpu_arch()
316+
ROCM_WARP_SIZE_64 = True if get_rocm_warpsize() == 64 else False
310317

311318
HIP_ENVIRONMENT = False
312319
BNB_BACKEND = "CPU"

0 commit comments

Comments
 (0)