
Commit f2029c6 (merge of parents c5e1894 and 3f9f6f3)

rebase

Signed-off-by: jiqing-feng <[email protected]>

33 files changed: +354 −681 lines

.github/workflows/python-package.yml (3 additions, 0 deletions)

@@ -171,6 +171,9 @@ jobs:
       retention-days: 7
 
   build-wheels:
+    env:
+      # Skip rebuilding the CPU library when building the wheels.
+      BNB_SKIP_CMAKE: 1
     needs:
       - build-cpu
       - build-cuda
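
Note: BNB_SKIP_CMAKE is consumed at pip-install time so the wheel jobs reuse the native libraries built earlier in the workflow. A minimal sketch of how a build script might honor such a flag; the helper below is illustrative, not the project's actual setup code:

import os
import subprocess

def maybe_build_native_library() -> None:
    # Hypothetical helper: skip the CMake step entirely when CI has already
    # built the shared library and attached it as an artifact.
    if os.environ.get("BNB_SKIP_CMAKE", "0") == "1":
        print("BNB_SKIP_CMAKE=1: reusing prebuilt native library")
        return
    subprocess.check_call(["cmake", "-B", "build", "-S", "."])
    subprocess.check_call(["cmake", "--build", "build", "--config", "Release"])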

.github/workflows/tests.yml (11 additions, 6 deletions)

@@ -10,6 +10,11 @@ concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
   cancel-in-progress: true
 
+env:
+  # Skip rebuilding the CPU library when installing the wheels.
+  # We build the libraries in separate jobs and upload as artifacts.
+  BNB_SKIP_CMAKE: 1
+
 jobs:
 
   build-cpu:
@@ -103,7 +108,7 @@ jobs:
       matrix:
         os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025, macos-15]
         # Test with the oldest supported torch version, the newest two stable/RC.
-        torch_version: ["2.3.1", "2.7.1", "2.8.0"]
+        torch_version: ["2.3.1", "2.8.0", "2.9.0"]
         include:
           - os: ubuntu-22.04
             arch: x86_64
@@ -146,7 +151,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
-          pip install -e ".[test]"
+          pip install -e ".[test]" -v
           pip install pytest-cov
 
       # We need to downgrade to numpy<2 for torch<2.4.1 compatibility on Windows
@@ -188,7 +193,7 @@ jobs:
      - name: Install dependencies
        run: |
          pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu
-          pip install -e ".[test]"
+          pip install -e ".[test]" -v
          pip install pytest-cov
 
      - name: Show installed packages
@@ -263,7 +268,7 @@ jobs:
 
      - name: Install dependencies
        run: |
-          pip install -e ".[test]"
+          pip install -e ".[test]" -v
          pip install pytest-cov
 
      - name: Show installed packages
@@ -321,7 +326,7 @@ jobs:
 
      - name: Install dependencies
        run: |
-          pip install -e ".[test]"
+          pip install -e ".[test]" -v
          pip install pytest-cov
 
      - name: Show installed packages
@@ -438,7 +443,7 @@ jobs:
      - name: Install dependencies
        run: |
          pip install --pre torch~=${{ matrix.torch_version }}.dev0 --index-url ${{ matrix.pypi_index }}
-          pip install -e ".[test]"
+          pip install -e ".[test]" -v
          pip install pytest-cov
      - name: Show installed packages
        run: pip list

CMakeLists.txt (37 additions, 20 deletions)

@@ -85,6 +85,7 @@ endif()
 if (BUILD_CPU)
     set(CMAKE_CXX_STANDARD 17)
     set(CMAKE_CXX_STANDARD_REQUIRED ON)
+    string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" HOST_ARCH)
     find_package(OpenMP)
 endif()
 
@@ -270,30 +271,46 @@ target_compile_features(bitsandbytes PUBLIC cxx_std_17)
 target_include_directories(bitsandbytes PUBLIC csrc include)
 
 if (BUILD_CPU)
-    target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
-    include(CheckCXXCompilerFlag)
-    check_cxx_compiler_flag(-mavx512f HAS_AVX512F)
-    check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16)
-    check_cxx_compiler_flag(-mavx512dq HAS_AVX512DQ)
-    check_cxx_compiler_flag(-mavx512bw HAS_AVX512BW)
-    check_cxx_compiler_flag(-mavx512vl HAS_AVX512VL)
-    if(HAS_AVX512F)
-        target_compile_options(bitsandbytes PRIVATE -mavx512f)
+    if (OpenMP_CXX_FOUND)
+        target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
+        add_definitions(-DHAS_OPENMP)
     endif()
-    if(HAS_AVX512BF16)
-        target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
-    endif()
-    if(HAS_AVX512DQ)
-        target_compile_options(bitsandbytes PRIVATE -mavx512dq)
-    endif()
-    if(HAS_AVX512BW)
-        target_compile_options(bitsandbytes PRIVATE -mavx512bw)
-    endif()
-    if(HAS_AVX512VL)
-        target_compile_options(bitsandbytes PRIVATE -mavx512vl)
+
+    if ((HOST_ARCH MATCHES "x86_64|amd64") AND (NOT MSVC))
+        include(CheckCXXCompilerFlag)
+        check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG)
+        check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG)
+        check_cxx_compiler_flag(-mavx512dq HAS_AVX512DQ)
+        check_cxx_compiler_flag(-mavx512bw HAS_AVX512BW)
+        check_cxx_compiler_flag(-mavx512vl HAS_AVX512VL)
+        if (HAS_AVX512F_FLAG)
+            target_compile_options(bitsandbytes PRIVATE -mavx512f)
+        endif()
+        if (HAS_AVX512BF16_FLAG)
+            target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
+        endif()
+        if(HAS_AVX512DQ)
+            target_compile_options(bitsandbytes PRIVATE -mavx512dq)
+        endif()
+        if(HAS_AVX512BW)
+            target_compile_options(bitsandbytes PRIVATE -mavx512bw)
+        endif()
+        if(HAS_AVX512VL)
+            target_compile_options(bitsandbytes PRIVATE -mavx512vl)
+        endif()
+        target_compile_options(
+            bitsandbytes PRIVATE
+            -mprefer-vector-width=256
+            -mfma
+            -mavx2
+            -mlzcnt
+            -mbmi
+            -mbmi2
+        )
     endif()
 endif()
 
+
 if(BUILD_CUDA)
     target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
     target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
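
Note: the AVX512 flag probes are now gated on the host architecture, so non-x86 and MSVC builds no longer test or apply irrelevant compiler flags, while x86-64 builds always get the AVX2/FMA/BMI baseline. A hedged runtime sketch of the same distinction, assuming torch's CPU-capability report is a fair proxy for what the host supports (this probe is illustrative and not how bitsandbytes selects kernels):

import platform

import torch

# Mirrors the CMake gate above: AVX flags only matter on x86-64 hosts.
arch = platform.machine().lower()
if arch in ("x86_64", "amd64"):
    print("x86-64 host; torch CPU capability:", torch.backends.cpu.get_cpu_capability())
else:
    print(f"non-x86 host ({arch}); the AVX flags above are not applied")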

bitsandbytes/autograd/__init__.py (0 additions, 1 deletion)

@@ -1 +0,0 @@
-from ._functions import get_inverse_transform_indices, undo_layout

bitsandbytes/autograd/_functions.py (1 addition, 60 deletions)

@@ -1,12 +1,10 @@
-from collections.abc import Callable
 from dataclasses import dataclass
 from math import prod
 from typing import Optional
 import warnings
 from warnings import warn
 
 import torch
-from typing_extensions import deprecated
 
 import bitsandbytes.functional as F
 
@@ -50,66 +48,9 @@ def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)
 
 
-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def get_inverse_transform_indices(
-    transform_tile: Callable[[torch.Tensor], torch.Tensor],
-    tile_size: tuple[int, int],
-):
-    """
-    Compute a permutation of indices that invert the specified (tiled) matrix transformation
-
-    :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
-    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
-    :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
-    :example: transform_tile function for the turing layout (bitsandbytes.functional as F)
-    :returns: indices
-    """
-    d1, d2 = tile_size
-    assert 0 < d1 * d2 < 2**64
-    tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
-    # encode each position in tile as a tuple of <= 8 unique bytes
-    permuted_tile_indices = torch.zeros_like(tile_indices)
-    for i in range(8):
-        # select i-th byte, apply transformation and trace where each index ended up
-        ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
-        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
-        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
-        permuted_tile_i = transform_tile(sample_tile_i)
-        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
-        permuted_tile_indices += ith_permuted_indices * (256**i)
-        if d1 * d2 < 256**i:
-            break  # if all indices fit in i bytes, stop early
-    return permuted_tile_indices
-
-
 _is_compiling = torch.compiler.is_compiling
 
 
-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
-    """
-    Undo a tiled permutation such as turing or ampere layout
-
-    :param permuted_tensor: torch tensor in a permuted layout
-    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
-    :return: contiguous row-major tensor
-    """
-    (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
-    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
-    tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
-    outputs = torch.empty_like(tensor)  # note: not using .index_copy because it was slower on cuda
-    outputs[tile_indices.flatten()] = tensor
-    outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
-    outputs = outputs.permute(3, 0, 2, 1)  # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
-    return outputs.reshape(rows, cols).contiguous()
-
-
 @dataclass
 class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None  # TODO: remove
@@ -433,7 +374,7 @@ def matmul_4bit(
     bias: Optional[torch.Tensor] = None,
 ):
     assert quant_state is not None
-    # Change dtype to bfloat16 on CPU
+    # Change dtype to input dtype on CPU
     if A.device.type == "cpu":
         quant_state.dtype = A.dtype
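
Note: the comment fix above tracks what the code actually does: on CPU, `quant_state.dtype` follows the input tensor rather than being pinned to bfloat16. A tiny sketch of the rule; the helper name is hypothetical, for illustration only:

import torch

def cpu_compute_dtype(A: torch.Tensor, default: torch.dtype = torch.bfloat16) -> torch.dtype:
    # Mirrors `quant_state.dtype = A.dtype` for CPU inputs.
    return A.dtype if A.device.type == "cpu" else default

assert cpu_compute_dtype(torch.zeros(2, dtype=torch.float16)) == torch.float16
assert cpu_compute_dtype(torch.zeros(2, dtype=torch.float32)) == torch.float32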

bitsandbytes/backends/cpu/ops.py (28 additions, 13 deletions)

@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 import ctypes as ct
 import logging
+from math import prod
 
 import torch
 
@@ -132,6 +133,13 @@ def _(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
         lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
     )
+
+    # Odd shape is not supported by this kernel; fallback to generic implementation
+    if shape[-1] % 2 != 0:
+        from ..default.ops import _dequantize_4bit_impl
+
+        return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
+
     # Enable non uint8 dtype
     if A.dtype != torch.uint8:
         A = A.view(torch.uint8)
@@ -140,35 +148,42 @@ def _(
     if absmax.dtype != torch.float32:
         absmax = absmax.float()
 
-    A = A.reshape(shape[0], shape[1] // 2)
+    if len(shape) == 1:
+        shape = (1, shape[0])
+
+    m = prod(shape[:-1])
+    n = shape[-1]
+
+    A = A.reshape(m, n // 2)
     out = torch.empty(shape, dtype=dtype, device=A.device)
+
     if quant_type == "fp4":
         if dtype == torch.float32:
             lib.cdequantize_blockwise_cpu_fp4_fp32(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
             )
         elif dtype == torch.bfloat16:
             lib.cdequantize_blockwise_cpu_fp4_bf16(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
             )
         elif dtype == torch.float16:
             lib.cdequantize_blockwise_cpu_fp4_fp16(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
            )
     elif quant_type == "nf4":
         if dtype == torch.float32:
@@ -177,26 +192,26 @@ def _(
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
             )
         elif dtype == torch.bfloat16:
             lib.cdequantize_blockwise_cpu_nf4_bf16(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
            )
         elif dtype == torch.float16:
             lib.cdequantize_blockwise_cpu_nf4_fp16(
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
                 ct.c_longlong(blocksize),
-                ct.c_longlong(shape[0]),
-                ct.c_longlong(shape[1]),
+                ct.c_longlong(m),
+                ct.c_longlong(n),
            )
     else:
         raise ValueError
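
Note: the kernel previously assumed a 2-D `shape`; this change generalizes it to arbitrary rank by flattening all leading dimensions into `m` and keeping the innermost dimension as `n`. The logic, extracted as a self-contained sketch:

from math import prod

def to_kernel_dims(shape: tuple[int, ...]) -> tuple[int, int]:
    # The 2-D CPU kernel receives (m, n): n is the innermost dimension,
    # m absorbs all leading ones; a 1-D input is promoted to a single row.
    if len(shape) == 1:
        shape = (1, shape[0])
    return prod(shape[:-1]), shape[-1]

assert to_kernel_dims((128,)) == (1, 128)
assert to_kernel_dims((4, 3, 64)) == (12, 64)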

bitsandbytes/backends/cuda/ops.py (5 additions, 5 deletions)

@@ -8,7 +8,7 @@
 from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
 
 from ..._ops import register_kernel
-from ...cextension import HIP_ENVIRONMENT, lib
+from ...cextension import ROCM_WARP_SIZE_64, lib
 
 
 @register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -211,7 +211,7 @@ def _get_col_absmax(
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
 
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -269,7 +269,7 @@ def _(
 def _dequantize_blockwise_impl(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ) -> None:
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -303,7 +303,7 @@ def _dequantize_blockwise_impl(
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -385,7 +385,7 @@ def _dequantize_4bit_impl(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
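
Note: the import swap narrows the blocksize restriction from all HIP builds (HIP_ENVIRONMENT) to only the devices the flag name suggests: ROCm GPUs with a 64-wide wavefront. A hedged sketch of the resulting rule; the interpretation is assumed from the flag name, and the device probe below is illustrative rather than the library's cextension logic:

import torch

def supported_blocksizes(warp_size: int) -> list[int]:
    # On wave64 devices the 64-element blocksize kernels are unavailable.
    sizes = [4096, 2048, 1024, 512, 256, 128, 64]
    return sizes[:-1] if warp_size == 64 else sizes

if torch.cuda.is_available():
    warp = getattr(torch.cuda.get_device_properties(0), "warp_size", 32)
    print("supported 4-bit blocksizes:", supported_blocksizes(warp))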
