Skip to content

Commit aa0cf92

Browse files
authored
Merge branch 'main' into main
2 parents d66f93d + ed9c8fc commit aa0cf92

File tree

6 files changed

+84
-44
lines changed

6 files changed

+84
-44
lines changed

.github/workflows/tests.yml

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -101,8 +101,8 @@ jobs:
101101
fail-fast: false
102102
matrix:
103103
os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025, macos-15]
104-
# Test with the oldest supported torch version and the two newest.
105-
torch_version: ["2.2.2", "2.6.0", "2.7.1"]
104+
# Test with the oldest supported torch version and the two newest stable/RC releases.
105+
torch_version: ["2.2.2", "2.7.1", "2.8.0"]
106106
include:
107107
- os: ubuntu-22.04
108108
arch: x86_64
@@ -144,7 +144,7 @@ jobs:
144144

145145
- name: Install dependencies
146146
run: |
147-
pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
147+
pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/${{ (matrix.torch_version == '2.8.0' && 'test/cpu') || 'cpu' }}
148148
pip install -e ".[test]"
149149
pip install pytest-cov
150150
@@ -372,7 +372,7 @@ jobs:
372372
pypi_index: "https://download.pytorch.org/whl/cu128"
373373
- cuda_version: "12.9.1"
374374
torch_version: "2.8.0"
375-
pypi_index: "https://download.pytorch.org/whl/nightly/cu129"
375+
pypi_index: "https://download.pytorch.org/whl/test/cu129"
376376

377377

378378
# Linux L40S runners

MANIFEST.in

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
include CMakeLists.txt
2+
graft csrc
3+
graft include

README.md

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -71,11 +71,11 @@ bitsandbytes has the following minimum requirements for all platforms:
7171
<td>🟥 AMD GPU <br><code>cuda</code></td>
7272
<td>
7373
CDNA: gfx90a, gfx942<br>
74-
RDNA: gfx1100, gfx1200
74+
RDNA: gfx1100
7575
</td>
76-
<td>🚧</td>
77-
<td>🚧</td>
78-
<td>🚧</td>
76+
<td></td>
77+
<td>〰️</td>
78+
<td></td>
7979
</tr>
8080
<tr>
8181
<td></td>
@@ -85,8 +85,8 @@ bitsandbytes has the following minimum requirements for all platforms:
8585
Arc A-Series (Alchemist)<br>
8686
Arc B-Series (Battlemage)
8787
</td>
88-
<td>🚧</td>
89-
<td>🚧</td>
88+
<td></td>
89+
<td></td>
9090
<td>🚧</td>
9191
</tr>
9292
<tr>
@@ -108,7 +108,7 @@ bitsandbytes has the following minimum requirements for all platforms:
108108
<tr>
109109
<td></td>
110110
<td>🟩 NVIDIA GPU <br><code>cuda</code></td>
111-
<td>SM75, SM80, SM90, SM100</td>
111+
<td>SM75+</td>
112112
<td>✅</td>
113113
<td>✅</td>
114114
<td>✅</td>
@@ -139,8 +139,8 @@ bitsandbytes has the following minimum requirements for all platforms:
139139
Arc A-Series (Alchemist) <br>
140140
Arc B-Series (Battlemage)
141141
</td>
142-
<td>🚧</td>
143-
<td>🚧</td>
142+
<td></td>
143+
<td></td>
144144
<td>🚧</td>
145145
</tr>
146146
<tr>

bitsandbytes/backends/triton/ops.py

Lines changed: 39 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99
# from bitsandbytes.functional import get_4bit_type
1010
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
1111
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
12+
device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
13+
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
1214

1315

1416
def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
@@ -21,7 +23,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
2123
absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
2224
out = torch.empty_like(A.flatten(), dtype=torch.uint8)
2325

24-
triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
26+
with torch_accelerator_module.device(A.device):
27+
triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
28+
2529
out = out.reshape(A.shape)
2630

2731
return out, absmax.float()
@@ -35,13 +39,14 @@ def dequantize_blockwise(
3539
# torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
3640

3741
out = torch.empty_like(A, dtype=dtype, device=A.device)
38-
triton_kernels.dequant_int8_blockwise(
39-
A,
40-
code,
41-
absmax,
42-
out,
43-
blocksize,
44-
)
42+
with torch_accelerator_module.device(A.device):
43+
triton_kernels.dequant_int8_blockwise(
44+
A,
45+
code,
46+
absmax,
47+
out,
48+
blocksize,
49+
)
4550

4651
return out
4752

@@ -55,13 +60,14 @@ def dequantize_blockwise_inplace(
5560
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
5661
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
5762

58-
triton_kernels.dequant_int8_blockwise(
59-
A,
60-
code,
61-
absmax,
62-
out,
63-
blocksize,
64-
)
63+
with torch_accelerator_module.device(A.device):
64+
triton_kernels.dequant_int8_blockwise(
65+
A,
66+
code,
67+
absmax,
68+
out,
69+
blocksize,
70+
)
6571

6672

6773
def quantize_4bit(
@@ -84,9 +90,10 @@ def quantize_4bit(
8490
absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
8591
out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
8692

87-
triton_kernels.quantize_4bit_blockwise_triton(
88-
A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
89-
)
93+
with torch_accelerator_module.device(A.device):
94+
triton_kernels.quantize_4bit_blockwise_triton(
95+
A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
96+
)
9097
packed = out
9198

9299
if quant_storage != torch.uint8:
@@ -119,7 +126,9 @@ def dequantize_4bit(
119126

120127
out = torch.empty(shape, dtype=dtype, device=A.device)
121128

122-
triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
129+
with torch_accelerator_module.device(A.device):
130+
triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
131+
123132
return out
124133

125134

@@ -134,7 +143,8 @@ def dequantize_4bit_inplace(
134143
) -> None:
135144
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
136145
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
137-
triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
146+
with torch_accelerator_module.device(A.device):
147+
triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
138148

139149

140150
def gemv_4bit(
@@ -150,14 +160,15 @@ def gemv_4bit(
150160

151161
B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
152162

153-
triton_kernels._dequantize_4bit_impl_passing_code(
154-
B,
155-
absmax,
156-
blocksize,
157-
code,
158-
dtype=A.dtype,
159-
out=B_dq_triton,
160-
)
163+
with torch_accelerator_module.device(A.device):
164+
triton_kernels._dequantize_4bit_impl_passing_code(
165+
B,
166+
absmax,
167+
blocksize,
168+
code,
169+
dtype=A.dtype,
170+
out=B_dq_triton,
171+
)
161172

162173
return torch.nn.functional.linear(
163174
A,

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[build-system]
2-
requires = ["setuptools >= 63.0.0"]
3-
build-backend = "setuptools.build_meta"
2+
requires = ["scikit-build-core", "setuptools >= 63.0.0"]
3+
build-backend = "scikit_build_core.setuptools.build_meta"
44

55
[project]
66
name = "bitsandbytes"

setup.py

Lines changed: 27 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,11 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from distutils.errors import DistutilsModuleError
6+
from warnings import warn
7+
58
from setuptools import find_packages, setup
9+
from setuptools.command.build_py import build_py
610
from setuptools.dist import Distribution
711

812

@@ -12,4 +16,26 @@ def has_ext_modules(self):
1216
return True
1317

1418

15-
setup(version="0.47.0.dev0", packages=find_packages(), distclass=BinaryDistribution)
19+
class ExtBuildPy(build_py):
20+
def run(self):
21+
# build_cmake needs to be called prior to build_py, as the latter
22+
# collects the files output into the package directory.
23+
try:
24+
self.run_command("build_cmake")
25+
except DistutilsModuleError:
26+
warn(
27+
"scikit-build-core not installed, CMake will not be invoked automatically. "
28+
"Please install scikit-build-core or run CMake manually to build extensions."
29+
)
30+
super().run()
31+
32+
33+
setup(
34+
version="0.47.0.dev0",
35+
packages=find_packages(),
36+
distclass=BinaryDistribution,
37+
cmake_source_dir=".",
38+
cmdclass={
39+
"build_py": ExtBuildPy,
40+
},
41+
)

0 commit comments

Comments
 (0)