@@ -186,7 +186,7 @@ def test_few_bit_quant(self, device, bits, method):
186186 code = F .create_dynamic_map (True , bits - 0 , bits ).to (device )
187187 elif method == "quantile" :
188188 if device != "cuda" :
189- pytest .xfail ("Quantile map only works on CUDA" )
189+ pytest .skip ("Quantile map only works on CUDA" )
190190 values = torch .randn (2048 , 2048 , device = "cuda" )
191191 code = F .create_quantile_map (values , bits ).cuda ()
192192 # for some data types we have no zero
@@ -593,7 +593,7 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims):
593593
594594 A = A .view (- 1 , A .shape [- 1 ])
595595
596- CA , _ , statsA , _ , _ = F .int8_double_quant (A )
596+ CA , statsA , _ = F .int8_vectorwise_quant (A )
597597 CB , statsB , _ = F .int8_vectorwise_quant (B )
598598 output = F .int8_mm_dequant (F .int8_linear_matmul (CA , CB ), statsA , statsB )
599599
@@ -1102,6 +1102,9 @@ class TestQuantize4BitFunctional:
11021102 @pytest .mark .parametrize ("quant_type" , ["fp4" , "nf4" ])
11031103 @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 , 1024 , 2048 , 4096 ])
11041104 def test_4bit_quant (self , device , dtype , quant_type , blocksize ):
1105+ if device == "cpu" and quant_type != "nf4" :
1106+ pytest .xfail ("fp4 quantization is not supported on CPU" )
1107+
11051108 A1 = torch .randn (1024 , 1024 , device = device , dtype = dtype )
11061109 qa , SA = F .quantize_4bit (A1 , blocksize = blocksize , quant_type = quant_type )
11071110 A2 = F .dequantize_4bit (qa , SA , blocksize = blocksize , quant_type = quant_type )
@@ -1134,6 +1137,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11341137 @pytest .mark .parametrize ("quant_type" , ["fp4" , "nf4" ])
11351138 @pytest .mark .parametrize ("blocksize" , [64 , 128 ], ids = id_formatter ("blocksize" ))
11361139 def test_4bit_compressed_stats (self , device , quant_type , blocksize ):
1140+ if device == "cpu" and quant_type != "nf4" :
1141+ pytest .xfail ("fp4 quantization is not supported on CPU" )
1142+
11371143 errs1 = []
11381144 errs2 = []
11391145 for i in range (10 ):
@@ -1206,6 +1212,12 @@ def test_bench_4bit_dequant(self, quant_type):
12061212 )
12071213 @pytest .mark .parametrize ("dim" , [128 , 256 , 512 , 1024 ], ids = id_formatter ("dim" ))
12081214 def test_gemv_4bit (self , device , dim , dtype , storage_type , quant_storage , double_quant , kind ):
1215+ if device == "cpu" :
1216+ if storage_type != "nf4" :
1217+ pytest .xfail ("fp4 quantization is not supported on CPU" )
1218+ if quant_storage != torch .uint8 :
1219+ pytest .xfail ("Only uint8 storage is supported on CPU" )
1220+
12091221 errs1 = []
12101222 errs2 = []
12111223 errs3 = []
@@ -1216,7 +1228,11 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
12161228 max_errs2 = []
12171229 max_errs3 = []
12181230
1219- for i in range (100 ):
1231+ # Large number of iterations is excessive and slow on CPU.
1232+ # Keep for CUDA for now.
1233+ iters = 100 if device == "cuda" else 10
1234+
1235+ for i in range (iters ):
12201236 if kind == "fc1" :
12211237 A = torch .randn (1 , dim , dtype = dtype , device = device )
12221238 B = torch .randn (dim * 4 , dim , dtype = dtype , device = device ) / math .sqrt (dim )
@@ -1337,6 +1353,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
13371353 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ], ids = describe_dtype )
13381354 @pytest .mark .parametrize ("double_quant" , [False ], ids = ["DQ_True" ])
13391355 def test_gemv_eye_4bit (self , device , storage_type , dtype , double_quant ):
1356+ if device == "cpu" and storage_type != "nf4" :
1357+ pytest .xfail ("fp4 quantization is not supported on CPU" )
1358+
13401359 dims = 10
13411360 torch .random .manual_seed (np .random .randint (0 , 412424242 ))
13421361 dims = get_test_dims (0 , 8192 , n = dims )
0 commit comments