Skip to content

Commit 2938c73

Browse files
committed
test_params4bit_torch_chunk_split
1 parent 1dbe602 commit 2938c73

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/test_linear4bit.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
212212
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
213213

214214

215+
@pytest.mark.parametrize("device", get_available_devices())
216+
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
217+
def test_params4bit_torch_chunk_split(device, quant_type):
218+
"""Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
219+
if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
220+
pytest.skip("This configuration is not supported on HPU.")
221+
222+
if device == "cpu":
223+
pytest.skip("CPU quantization causes segfault, skipping CPU test")
224+
225+
original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")
226+
227+
params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)
228+
229+
if device != "cpu":
230+
params4bit = params4bit.to(device)
231+
232+
chunks = torch.chunk(params4bit, 2, dim=0)
233+
234+
assert isinstance(chunks, tuple), "torch.chunk should return tuple"
235+
for chunk in chunks:
236+
assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass"
237+
assert hasattr(chunk, "quant_type"), "Should preserve metadata"
238+
assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value"
239+
240+
splits = torch.split(params4bit, 2, dim=0)
241+
242+
assert isinstance(splits, tuple), "torch.split should return tuple"
243+
assert len(splits) > 0, "Should have at least one split"
244+
for split in splits:
245+
assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass"
246+
assert hasattr(split, "quant_type"), "Should preserve metadata"
247+
assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"
248+
249+
215250
@pytest.mark.parametrize("device", get_available_devices())
216251
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
217252
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])

0 commit comments

Comments
 (0)