 import time

 import einops
-import numpy as np
 import pytest
 import torch

@@ -101,16 +100,16 @@ class Test8BitBlockwiseQuantizeFunctional:
     def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
         iters = 100

-        if device == "cpu":
+        if device != "cuda":
             iters = 10

-            # This test is slow on CPU, so avoid atypical use cases.
+            # This test is slow in our non-CUDA implementations, so avoid atypical use cases.
             if nested:
                 pytest.skip("Not a typical use case.")
             if blocksize != 256:
-                pytest.skip("Only blocksize 256 is used in CPU/XPU")
+                pytest.skip("Only blocksize 256 is used in CPU/MPS/XPU")
             if dtype != torch.float32:
-                pytest.skip("Only float32 is used in CPU/XPU")
+                pytest.skip("Only float32 is used in CPU/MPS/XPU")

         diffs = []
         reldiffs = []
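Side note on the skips above: the `nested`, `blocksize`, and `signed` parameters map directly onto the bitsandbytes functional API. A minimal round-trip sketch under those parameters follows; the keyword names and `create_dynamic_map` call are based on my reading of bitsandbytes.functional and should be treated as an assumption rather than part of this diff.

import torch
import bitsandbytes.functional as F

# Assumed API: create_dynamic_map(signed=...) builds the 8-bit dynamic code table,
# and quantize_blockwise accepts blocksize/nested keywords matching the test parametrization.
code = F.create_dynamic_map(signed=True)
A = torch.randn(4096, 4096, device="cuda")  # any supported device works; "cuda" is illustrative
C, SC = F.quantize_blockwise(A, code=code, blocksize=256, nested=False)
A2 = F.dequantize_blockwise(C, SC)
print((A - A2).abs().mean().item())  # mean absolute round-trip error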
@@ -239,7 +238,7 @@ def test_fp8_quant(self, device):

             abserr = []
             relerr = []
-            for i in range(100):
+            for i in range(10):
                 A1 = torch.randn(1024, 1024, device=device)
                 C, SC = F.quantize_blockwise(A1, code=code)
                 A2 = F.dequantize_blockwise(C, SC)
@@ -253,7 +252,7 @@ def test_fp8_quant(self, device):

             abserr = []
             relerr = []
-            for i in range(100):
+            for i in range(10):
                 A1 = torch.rand(1024, 1024, device=device)
                 C, SC = F.quantize_blockwise(A1, code=code)
                 A2 = F.dequantize_blockwise(C, SC)
@@ -267,7 +266,7 @@ def test_fp8_quant(self, device):

             abserr = []
             relerr = []
-            for i in range(100):
+            for i in range(10):
                 A1 = torch.randn(1024, 1024, device=device)
                 C, SC = F.quantize_blockwise(A1)
                 A2 = F.dequantize_blockwise(C, SC)
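For context, the three shortened loops above all time the same quantize/dequantize round trip and differ only in the input distribution and code table. A standalone sketch of one iteration is below; `F.create_fp8_map` and its argument order are an assumption based on how the library builds FP8 code tables, and the 4/3 exponent/precision split is illustrative.

import torch
import bitsandbytes.functional as F

# Assumption: create_fp8_map(signed, exponent_bits, precision_bits) returns a 256-entry code table.
code = F.create_fp8_map(True, 4, 3).to("cuda")
A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)   # 8-bit codes plus per-block absmax state
A2 = F.dequantize_blockwise(C, SC)            # reconstruct an approximation of A1
diff = torch.abs(A1 - A2)
reldiff = diff / torch.abs(A1 + 1e-8)
print(diff.mean().item(), reldiff.mean().item())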
@@ -1406,28 +1405,26 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-    @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
     @pytest.mark.skipif(
         HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
         reason="this test is not supported on ROCm with gfx90a architecture yet",
     )
-    def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
+    def test_gemv_eye_4bit(self, device, storage_type, dtype):
         if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
             pytest.skip("eye does not support bfloat16 on CPU in torch < 2.3")

         if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
             pytest.skip("This configuration is not supported on HPU.")

-        dims = 10
-        torch.random.manual_seed(np.random.randint(0, 412424242))
+        dims = 4
         dims = get_test_dims(0, 8192, n=dims)
         dims = [dim + (64 - (dim % 64)) for dim in dims]
         # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
         for dim in dims:
             A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device)
             B = torch.eye(dim, dtype=dtype, device=device)

-            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=False)
             C3 = torch.matmul(A, B.t())
             C2 = bnb.matmul_4bit(A, qB.t(), state)
             A.requires_grad = True
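The comparison this loop feeds into rests on the identity-matrix trick: with B = I, the 4-bit matmul should reproduce A up to quantization error. A minimal sketch of that check for a single dimension is below; the device, dtype, NF4 choice, variable names, and tolerances are illustrative and not taken from the test.

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

dim = 1024
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=torch.float16, device="cuda")
B = torch.eye(dim, dtype=torch.float16, device="cuda")

qB, state = F.quantize_4bit(B, quant_type="nf4", compress_statistics=False)
C_ref = torch.matmul(A, B.t())             # exact result; equals A because B is the identity
C_q = bnb.matmul_4bit(A, qB.t(), state)    # GEMV against the 4-bit quantized identity
torch.testing.assert_close(C_q, C_ref, atol=1e-2, rtol=0.1)  # loose, illustrative tolerances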