1313 describe_dtype ,
1414 get_available_devices ,
1515 id_formatter ,
16+ is_supported_on_hpu ,
1617 torch_load_from_buffer ,
1718 torch_save_to_buffer ,
1819)
2728
2829@pytest .mark .parametrize ("device" , get_available_devices ())
2930@pytest .mark .parametrize ("quant_storage" , ["uint8" , "float16" , "bfloat16" , "float32" ])
31+ @pytest .mark .parametrize ("original_dtype" , [torch .float16 , torch .bfloat16 ])
3032@pytest .mark .parametrize ("bias" , TRUE_FALSE , ids = id_formatter ("bias" ))
3133@pytest .mark .parametrize ("compress_statistics" , TRUE_FALSE , ids = id_formatter ("compress_statistics" ))
3234@pytest .mark .parametrize ("quant_type" , ["nf4" , "fp4" ])
3335@pytest .mark .parametrize ("save_before_forward" , TRUE_FALSE , ids = id_formatter ("save_before_forward" ))
34- def test_linear_serialization (device , quant_type , compress_statistics , bias , quant_storage , save_before_forward ):
35- original_dtype = torch .float16
36+ def test_linear_serialization (
37+ device , quant_type , original_dtype , compress_statistics , bias , quant_storage , save_before_forward
38+ ):
39+ if device == "hpu" and not is_supported_on_hpu (quant_type , original_dtype , storage [quant_storage ]):
40+ pytest .skip ("This configuration is not supported on HPU." )
41+
3642 compute_dtype = None
3743 layer_shape = (300 , 400 )
3844
@@ -188,6 +194,9 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua
188194@pytest .mark .parametrize ("blocksize" , [64 , 128 ])
189195@pytest .mark .parametrize ("compress_statistics" , TRUE_FALSE , ids = id_formatter ("compress_statistics" ))
190196def test_copy_param (device , quant_type , blocksize , compress_statistics ):
197+ if device == "hpu" and not is_supported_on_hpu (quant_type ):
198+ pytest .skip ("This configuration is not supported on HPU." )
199+
191200 tensor = torch .randn (300 , 400 )
192201 param = bnb .nn .Params4bit (
193202 data = tensor ,
@@ -207,6 +216,9 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
207216@pytest .mark .parametrize ("blocksize" , [64 , 128 ])
208217@pytest .mark .parametrize ("compress_statistics" , TRUE_FALSE , ids = id_formatter ("compress_statistics" ))
209218def test_deepcopy_param (device , quant_type , blocksize , compress_statistics ):
219+ if device == "hpu" and not is_supported_on_hpu (quant_type ):
220+ pytest .skip ("This configuration is not supported on HPU." )
221+
210222 tensor = torch .randn (300 , 400 )
211223 param = bnb .nn .Params4bit (
212224 data = tensor ,
@@ -233,6 +245,9 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
233245@pytest .mark .parametrize ("blocksize" , [64 , 128 ])
234246@pytest .mark .parametrize ("compress_statistics" , TRUE_FALSE , ids = id_formatter ("compress_statistics" ))
235247def test_params4bit_real_serialization (device , quant_type , blocksize , compress_statistics ):
248+ if device == "hpu" and not is_supported_on_hpu (quant_type ):
249+ pytest .skip ("This configuration is not supported on HPU." )
250+
236251 original_tensor = torch .randn (300 , 400 )
237252 original_param = bnb .nn .Params4bit (
238253 data = original_tensor ,
@@ -270,6 +285,9 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
270285@pytest .mark .parametrize ("mode" , ["default" , "reduce-overhead" ], ids = id_formatter ("mode" ))
271286@pytest .mark .skipif (torch .__version__ < (2 , 4 ), reason = "Not supported in torch < 2.4" )
272287def test_linear4bit_torch_compile (device , quant_type , compute_dtype , compress_statistics , bias , fullgraph , mode ):
288+ if device == "hpu" and not is_supported_on_hpu (quant_type ):
289+ pytest .skip ("This configuration is not supported on HPU." )
290+
273291 if fullgraph and torch .__version__ < (2 , 8 , 0 , "dev" ):
274292 pytest .skip ("fullgraph mode requires torch 2.8 or higher" )
275293
@@ -314,7 +332,8 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
314332 ref_output = net (x )
315333
316334 # Compile the model
317- compiled_net = torch .compile (net , fullgraph = fullgraph , mode = mode )
335+ compile_backend = "hpu_backend" if device == "hpu" else "inductor"
336+ compiled_net = torch .compile (net , fullgraph = fullgraph , mode = mode , backend = compile_backend )
318337
319338 # Get output from compiled model
320339 with torch .no_grad ():
0 commit comments