@@ -1101,6 +1101,9 @@ class TestQuantize4BitFunctional:
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
     def test_4bit_quant(self, device, dtype, quant_type, blocksize):
+        if device == "hpu" and quant_type != "nf4":
+            pytest.skip("fp4 dequantization is not supported on HPU")
+
         A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
         qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
         A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
@@ -1133,6 +1136,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
     def test_4bit_compressed_stats(self, device, quant_type, blocksize):
+        if device == "hpu" and quant_type != "nf4":
+            pytest.skip("fp4 dequantization is not supported on HPU")
+
         errs1 = []
         errs2 = []
         for i in range(10):
@@ -1205,6 +1211,9 @@ def test_bench_4bit_dequant(self, quant_type):
     )
     @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
     def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
+        if device == "hpu" and storage_type != "nf4":
+            pytest.skip("fp4 dequantization is not supported on HPU")
+
         errs1 = []
         errs2 = []
         errs3 = []
@@ -1354,6 +1363,9 @@ def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
         if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
             pytest.skip("eye doe not support bfloat16 on CPU in torch < 2.3")
 
+        if device == "hpu" and storage_type != "nf4":
+            pytest.skip("fp4 dequantization is not supported on HPU")
+
         dims = 10
         torch.random.manual_seed(np.random.randint(0, 412424242))
         dims = get_test_dims(0, 8192, n=dims)
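The same HPU/fp4 guard is repeated in all four tests. A minimal sketch (not part of the diff; the helper name skip_fp4_on_hpu is hypothetical) of factoring it into one place, assuming pytest is already imported at module scope:

    import pytest

    def skip_fp4_on_hpu(device, quant_type):
        # Only nf4 dequantization is supported on HPU in these tests;
        # skip any parametrization that exercises fp4 on that device.
        if device == "hpu" and quant_type != "nf4":
            pytest.skip("fp4 dequantization is not supported on HPU")

Each test could then call skip_fp4_on_hpu(device, quant_type) (or pass storage_type for the gemv tests) as its first statement instead of duplicating the conditional.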