
Commit ed34d1a

Merge branch 'main' into egor/8bit_opt2
2 parents: a13736a + 7bfe923

File tree

9 files changed: +132 additions, -29 deletions

.github/FUNDING.yml (1 addition, 0 deletions)

@@ -0,0 +1 @@
+open_collective: bitsandbytes

.github/scripts/build-cuda.sh (3 additions, 3 deletions)

@@ -15,10 +15,10 @@ elif [ "${build_arch}" = "aarch64" ]; then
     [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;90;100;120"
 else
     # By default, target Maxwell through Hopper.
-    build_capability="50;52;60;61;70;75;80;86;89;90"
+    build_capability="50;60;70;75;80;86;89;90"

-    # CUDA 12.8+: Add sm100 and sm120; remove < sm75 to align with PyTorch 2.7+cu128 minimum
-    [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;86;89;90;100;120"
+    # CUDA 12.8+: Add sm100 and sm120; remove < sm70 to align with PyTorch 2.8+cu128 minimum
+    [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="70;75;80;86;89;90;100;120"
 fi

 [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja

README.md (1 addition, 1 deletion)

@@ -26,7 +26,7 @@ bitsandbytes has the following minimum requirements for all platforms:
 #### Accelerator support:

 <small>Note: this table reflects the status of the current development branch. For the latest stable release, see the
-[document in the v0.46.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.46.0/README.md#accelerator-support).
+[document in the 0.47.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.47.0/README.md#accelerator-support).
 </small>

 ##### Legend:

bitsandbytes/__init__.py (1 addition, 2 deletions)

@@ -38,7 +38,6 @@
 if hasattr(torch, "xpu") and torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops

-
 if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
     # In case not automatically imported
     import habana_frameworks.torch
@@ -76,4 +75,4 @@ def _import_backends():
     "optim.optimizer.MockArgs": False,
 }

-__version__ = "0.47.0.dev0"
+__version__ = "0.48.0.dev0"

bitsandbytes/nn/modules.py (40 additions, 0 deletions)

@@ -356,6 +356,46 @@ def to(self, *args, **kwargs):

         return new_param

+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        if func in [torch.chunk, torch.split]:
+            tensor = args[0]
+
+            result = super().__torch_function__(func, types, args, kwargs)
+
+            if isinstance(result, tuple):
+                return tuple(
+                    cls(
+                        data=chunk,
+                        requires_grad=tensor.requires_grad,
+                        quant_state=tensor.quant_state,
+                        blocksize=tensor.blocksize,
+                        compress_statistics=tensor.compress_statistics,
+                        quant_type=tensor.quant_type,
+                        quant_storage=tensor.quant_storage,
+                        module=tensor.module,
+                        bnb_quantized=tensor.bnb_quantized,
+                    )
+                    for chunk in result
+                )
+            else:
+                return cls(
+                    data=result,
+                    requires_grad=tensor.requires_grad,
+                    quant_state=tensor.quant_state,
+                    blocksize=tensor.blocksize,
+                    compress_statistics=tensor.compress_statistics,
+                    quant_type=tensor.quant_type,
+                    quant_storage=tensor.quant_storage,
+                    module=tensor.module,
+                    bnb_quantized=tensor.bnb_quantized,
+                )
+
+        return super().__torch_function__(func, types, args, kwargs)
+

 def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
     if getattr(module.weight, "quant_state", None) is not None:
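
The practical effect of this override is that sharding-style ops no longer downcast the parameter to a plain torch.Tensor. A minimal usage sketch, assuming a CUDA device and this branch of bitsandbytes; it mirrors the new test added further below in tests/test_linear4bit.py:

```python
# Illustrative sketch only: torch.chunk on a Params4bit now returns Params4bit
# chunks that keep their quantization metadata (useful for FSDP2-style sharding).
import torch
import bitsandbytes as bnb

w = torch.randn(8, 4, dtype=torch.float16)
p = bnb.nn.Params4bit(data=w, quant_type="nf4", requires_grad=False).to("cuda")

for shard in torch.chunk(p, 2, dim=0):
    assert isinstance(shard, bnb.nn.Params4bit)  # subclass preserved
    assert shard.quant_type == "nf4"             # metadata travels with each shard
```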

csrc/kernels.cu (4 additions, 7 deletions)

@@ -431,7 +431,6 @@ __global__ void kQuantizeBlockwise(
            LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
        }

-        unsigned char packed_4bit = 0;
        switch (DATA_TYPE) {
        case General8bit:
#pragma unroll NUM_PER_TH
@@ -445,17 +444,15 @@ __global__ void kQuantizeBlockwise(
        case FP4:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH / 2; j++) {
-                packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
-                packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
-                qvals[j] = packed_4bit;
+                qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
+                qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
            }
            break;
        case NF4:
#pragma unroll NUM_PER_TH
            for (int j = 0; j < NUM_PER_TH / 2; j++) {
-                packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
-                packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
-                qvals[j] = packed_4bit;
+                qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
+                qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
            }
            break;
        }
setup.py (1 addition, 1 deletion)

@@ -31,7 +31,7 @@ def run(self):


 setup(
-    version="0.47.0.dev0",
+    version="0.48.0.dev0",
     packages=find_packages(),
     distclass=BinaryDistribution,
     cmake_source_dir=".",

tests/test_functional.py (46 additions, 15 deletions)

@@ -1125,21 +1125,52 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):

         # With larger block sizes, we can expect this to blow up.
         # At blocksize>=1024, don't even bother looking at relerr.
-        if blocksize <= 64:
-            assert err.item() < 0.1
-            assert relerr.item() < 0.28
-        elif blocksize <= 256:
-            assert err.item() < 0.11
-            assert relerr.item() < 0.30
-        elif blocksize <= 512:
-            assert err.item() < 0.12
-            assert relerr.item() < 0.31
-        elif quant_type == "fp4":
-            # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
-            assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
-        else:
-            # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
-            assert err.item() < math.log2(blocksize) * 8e-2
+        #
+        # Actually, the above is not true anymore after fixing the integer packing bug.
+        # The following values were taken from averaging 1k samples per test configuration after fixing the bug.
+        error_dict = dict()
+        error_dict["fp4"] = dict()
+        error_dict["nf4"] = dict()
+        error_dict["fp4"]["err"] = {
+            64: 0.096545,
+            128: 0.102947,
+            256: 0.108685,
+            512: 0.114087,
+            1024: 0.119312,
+            2048: 0.124460,
+            4096: 0.129573,
+        }
+        error_dict["fp4"]["rel_err"] = {
+            64: 0.260130,
+            128: 0.275734,
+            256: 0.289842,
+            512: 0.302852,
+            1024: 0.314982,
+            2048: 0.326402,
+            4096: 0.337228,
+        }
+
+        error_dict["nf4"]["err"] = {
+            64: 0.072792,
+            128: 0.076835,
+            256: 0.080326,
+            512: 0.083535,
+            1024: 0.086603,
+            2048: 0.089592,
+            4096: 0.092537,
+        }
+        error_dict["nf4"]["rel_err"] = {
+            64: 0.203299,
+            128: 0.215252,
+            256: 0.226044,
+            512: 0.236021,
+            1024: 0.245365,
+            2048: 0.254146,
+            4096: 0.262457,
+        }
+
+        assert err < error_dict[quant_type]["err"][blocksize] + 1e-3
+        assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3

     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])

tests/test_linear4bit.py (35 additions, 0 deletions)

@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
     assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()


+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+def test_params4bit_torch_chunk_split(device, quant_type):
+    """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
+    if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
+        pytest.skip("This configuration is not supported on HPU.")
+
+    if device == "cpu":
+        pytest.skip("CPU quantization causes segfault, skipping CPU test")
+
+    original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")
+
+    params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)
+
+    if device != "cpu":
+        params4bit = params4bit.to(device)
+
+    chunks = torch.chunk(params4bit, 2, dim=0)
+
+    assert isinstance(chunks, tuple), "torch.chunk should return tuple"
+    for chunk in chunks:
+        assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass"
+        assert hasattr(chunk, "quant_type"), "Should preserve metadata"
+        assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+    splits = torch.split(params4bit, 2, dim=0)
+
+    assert isinstance(splits, tuple), "torch.split should return tuple"
+    assert len(splits) > 0, "Should have at least one split"
+    for split in splits:
+        assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass"
+        assert hasattr(split, "quant_type"), "Should preserve metadata"
+        assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
