@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
212212 assert param .data .data_ptr () == shallow_copy_param .data .data_ptr ()
213213
214214
215+ @pytest .mark .parametrize ("device" , get_available_devices ())
216+ @pytest .mark .parametrize ("quant_type" , ["nf4" , "fp4" ])
217+ def test_params4bit_torch_chunk_split (device , quant_type ):
218+ """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
219+ if device == "hpu" and not is_supported_on_hpu (quant_type , torch .float16 , torch .uint8 ):
220+ pytest .skip ("This configuration is not supported on HPU." )
221+
222+ if device == "cpu" :
223+ pytest .skip ("CPU quantization causes segfault, skipping CPU test" )
224+
225+ original_tensor = torch .randn (8 , 4 , dtype = torch .float16 , device = "cpu" )
226+
227+ params4bit = bnb .nn .Params4bit (data = original_tensor , quant_type = quant_type , requires_grad = False )
228+
229+ if device != "cpu" :
230+ params4bit = params4bit .to (device )
231+
232+ chunks = torch .chunk (params4bit , 2 , dim = 0 )
233+
234+ assert isinstance (chunks , tuple ), "torch.chunk should return tuple"
235+ for chunk in chunks :
236+ assert isinstance (chunk , bnb .nn .Params4bit ), "Chunk should preserve Params4bit subclass"
237+ assert hasattr (chunk , "quant_type" ), "Should preserve metadata"
238+ assert chunk .quant_type == params4bit .quant_type , "Should preserve quant_type value"
239+
240+ splits = torch .split (params4bit , 2 , dim = 0 )
241+
242+ assert isinstance (splits , tuple ), "torch.split should return tuple"
243+ assert len (splits ) > 0 , "Should have at least one split"
244+ for split in splits :
245+ assert isinstance (split , bnb .nn .Params4bit ), "Split should preserve Params4bit subclass"
246+ assert hasattr (split , "quant_type" ), "Should preserve metadata"
247+ assert split .quant_type == params4bit .quant_type , "Should preserve quant_type value"
248+
249+
215250@pytest .mark .parametrize ("device" , get_available_devices ())
216251@pytest .mark .parametrize ("quant_type" , ["nf4" , "fp4" ])
217252@pytest .mark .parametrize ("blocksize" , [64 , 128 ] if not HIP_ENVIRONMENT else [128 ])
0 commit comments