Commit 4265392

Merge pull request #1719 from ved1beta/fsdp_integration2

Fix Params4bit tensor subclass handling

2 parents e54dc12 + 0ecb8fb

File tree

2 files changed: +75 −0 lines changed

bitsandbytes/nn/modules.py
tests/test_linear4bit.py


bitsandbytes/nn/modules.py

Lines changed: 40 additions & 0 deletions
@@ -356,6 +356,46 @@ def to(self, *args, **kwargs):
 
         return new_param
 
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        if func in [torch.chunk, torch.split]:
+            tensor = args[0]
+
+            result = super().__torch_function__(func, types, args, kwargs)
+
+            if isinstance(result, tuple):
+                return tuple(
+                    cls(
+                        data=chunk,
+                        requires_grad=tensor.requires_grad,
+                        quant_state=tensor.quant_state,
+                        blocksize=tensor.blocksize,
+                        compress_statistics=tensor.compress_statistics,
+                        quant_type=tensor.quant_type,
+                        quant_storage=tensor.quant_storage,
+                        module=tensor.module,
+                        bnb_quantized=tensor.bnb_quantized,
+                    )
+                    for chunk in result
+                )
+            else:
+                return cls(
+                    data=result,
+                    requires_grad=tensor.requires_grad,
+                    quant_state=tensor.quant_state,
+                    blocksize=tensor.blocksize,
+                    compress_statistics=tensor.compress_statistics,
+                    quant_type=tensor.quant_type,
+                    quant_storage=tensor.quant_storage,
+                    module=tensor.module,
+                    bnb_quantized=tensor.bnb_quantized,
+                )
+
+        return super().__torch_function__(func, types, args, kwargs)
+
 
 def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
     if getattr(module.weight, "quant_state", None) is not None:
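
Why the override is needed: torch.nn.Parameter disables the __torch_function__ protocol by default (it assigns _disabled_torch_function_impl), so on a Parameter subclass like Params4bit, torch.chunk and torch.split hand back plain Tensors that have lost Python-level metadata such as quant_state and blocksize, which is exactly what FSDP2 needs intact when it shards parameters. Below is a minimal, self-contained sketch of the same re-wrapping pattern on a hypothetical TaggedParam subclass; the class and its tag attribute are illustrative only and not part of bitsandbytes:

import torch

class TaggedParam(torch.nn.Parameter):
    """Hypothetical Parameter subclass carrying one piece of metadata."""

    def __new__(cls, data, tag=None, requires_grad=False):
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.tag = tag
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Same pattern as the patch: intercept chunk/split, let the parent
        # implementation produce plain-Tensor views, then re-wrap each view
        # in the subclass while copying the metadata from the source tensor.
        if func in [torch.chunk, torch.split]:
            tensor = args[0]
            result = super().__torch_function__(func, types, args, kwargs)
            return tuple(cls(chunk, tag=tensor.tag) for chunk in result)
        return super().__torch_function__(func, types, args, kwargs)

p = TaggedParam(torch.randn(8, 4), tag="nf4")
chunk = torch.chunk(p, 2, dim=0)[0]
assert isinstance(chunk, TaggedParam) and chunk.tag == "nf4"
# Without the override, Parameter's disabled __torch_function__ would have
# returned plain torch.Tensor chunks with no .tag attribute.

Reconstructing each result through cls(...) rather than Tensor.as_subclass is what lets the patch carry the full quantization state (quant_state, blocksize, compress_statistics, and so on) onto every shard.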

tests/test_linear4bit.py

Lines changed: 35 additions & 0 deletions
@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
     assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
 
 
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+def test_params4bit_torch_chunk_split(device, quant_type):
+    """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
+    if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
+        pytest.skip("This configuration is not supported on HPU.")
+
+    if device == "cpu":
+        pytest.skip("CPU quantization causes segfault, skipping CPU test")
+
+    original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")
+
+    params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)
+
+    if device != "cpu":
+        params4bit = params4bit.to(device)
+
+    chunks = torch.chunk(params4bit, 2, dim=0)
+
+    assert isinstance(chunks, tuple), "torch.chunk should return tuple"
+    for chunk in chunks:
+        assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass"
+        assert hasattr(chunk, "quant_type"), "Should preserve metadata"
+        assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+    splits = torch.split(params4bit, 2, dim=0)
+
+    assert isinstance(splits, tuple), "torch.split should return tuple"
+    assert len(splits) > 0, "Should have at least one split"
+    for split in splits:
+        assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass"
+        assert hasattr(split, "quant_type"), "Should preserve metadata"
+        assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
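
With the patch applied, the behavior the test asserts is visible directly in a few lines (a minimal sketch, assuming bitsandbytes at this commit plus a CUDA device, since the test skips CPU quantization):

import torch
import bitsandbytes as bnb

# Quantization happens when the parameter is moved to an accelerator.
w = bnb.nn.Params4bit(
    data=torch.randn(8, 4, dtype=torch.float16),
    quant_type="nf4",
    requires_grad=False,
).to("cuda")

# torch.chunk now yields Params4bit shards that keep their quantization metadata.
for chunk in torch.chunk(w, 2, dim=0):
    print(type(chunk).__name__, chunk.quant_type)  # Params4bit nf4

The new test itself can be selected with pytest tests/test_linear4bit.py -k test_params4bit_torch_chunk_split.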
