@@ -129,6 +129,7 @@ def test_quantile_quantization():
129129 assert diff < 0.001
130130
131131
132+
132133def test_dynamic_quantization ():
133134 diffs = []
134135 reldiffs = []
@@ -141,8 +142,8 @@ def test_dynamic_quantization():
141142 diffs .append (diff .mean ().item ())
142143 reldiffs .append (reldiff .mean ().item ())
143144 assert diff .mean ().item () < 0.0135
144- # print(sum(diffs)/len(diffs))
145- # print(sum(reldiffs)/len(reldiffs))
145+ print (sum (diffs )/ len (diffs ))
146+ print (sum (reldiffs )/ len (reldiffs ))
146147
147148 for i in range (100 ):
148149 A1 = torch .rand (1024 , 1024 , device = "cuda" )
@@ -157,7 +158,8 @@ def test_dynamic_quantization():
157158@pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float16 , torch .bfloat16 ], ids = ["fp32" , "fp16" , "bf16" ])
158159@pytest .mark .parametrize ("nested" , [False , True ], ids = ["False" , "True" ])
159160@pytest .mark .parametrize ("blocksize" , [4096 , 2048 , 1024 , 512 , 256 , 128 , 64 ])
160- def test_dynamic_blockwise_quantization (dtype , nested , blocksize ):
161+ @pytest .mark .parametrize ("signed" , [True , False ], ids = ['signed_True' , 'signed_False' ])
162+ def test_dynamic_blockwise_quantization (dtype , nested , blocksize , signed ):
161163 #print('')
162164 diffs = []
163165 reldiffs = []
@@ -178,9 +180,10 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
178180 assert A2 .dtype == dtype
179181
180182 diffs = []
183+ code = F .create_dynamic_map (signed = signed )
181184 for i in range (100 ):
182185 A1 = torch .rand (1024 , 1024 , device = "cuda" , dtype = dtype )
183- C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested )
186+ C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested , code = code )
184187 A2 = F .dequantize_blockwise (C , S )
185188 diff = torch .abs (A1 - A2 ).float ()
186189 reldiff = diff / torch .abs (A1 .float () + 1e-8 )
@@ -189,11 +192,15 @@ def test_dynamic_blockwise_quantization(dtype, nested, blocksize):
189192 #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
190193 abserr = sum (diffs )/ len (diffs )
191194 relerr = sum (reldiffs )/ len (reldiffs )
192- assert abserr < 0.0035
193- assert relerr < 0.015
195+ if signed :
196+ assert abserr < 0.0035
197+ assert relerr < 0.015
198+ else :
199+ assert abserr < 0.00175
200+ assert relerr < 0.012
194201 assert A2 .dtype == dtype
195- #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
196- #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
202+ #print('signed=', signed, ' nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
203+ #print('signed=', signed, ' nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
197204
198205
199206
0 commit comments