@@ -94,15 +94,22 @@ class Test8BitBlockwiseQuantizeFunctional:
9494 @pytest .mark .parametrize ("blocksize" , [4096 , 2048 , 1024 , 512 , 256 , 128 , 64 ])
9595 @pytest .mark .parametrize ("signed" , TRUE_FALSE , ids = id_formatter ("signed" ))
9696 def test_dynamic_blockwise_quantization (self , device , dtype , nested , blocksize , signed ):
97- if device in ("cpu" , "xpu" ):
97+ iters = 100
98+
99+ if device == "cpu" :
100+ iters = 10
101+
102+ # This test is slow on CPU, so avoid atypical use cases.
103+ if nested :
104+ pytest .skip ("Not a typical use case." )
98105 if blocksize != 256 :
99106 pytest .skip ("Only blocksize 256 is used in CPU/XPU" )
100107 if dtype != torch .float32 :
101108 pytest .skip ("Only float32 is used in CPU/XPU" )
102109
103110 diffs = []
104111 reldiffs = []
105- for i in range (100 ):
112+ for i in range (iters ):
106113 A1 = torch .randn (1024 , 1024 , device = device , dtype = dtype )
107114 C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested )
108115 A2 = F .dequantize_blockwise (C , S )
@@ -112,15 +119,13 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
112119 reldiffs .append (reldiff .mean ().item ())
113120 abserr = sum (diffs ) / len (diffs )
114121 relerr = sum (reldiffs ) / len (reldiffs )
115- # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
116- # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
117122 assert abserr < 0.011
118123 assert relerr < 0.018
119124 assert A2 .dtype == dtype
120125
121126 diffs = []
122127 code = F .create_dynamic_map (signed = signed )
123- for i in range (100 ):
128+ for i in range (iters ):
124129 A1 = torch .rand (1024 , 1024 , device = device , dtype = dtype )
125130 C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested , code = code )
126131 A2 = F .dequantize_blockwise (C , S )
@@ -139,33 +144,29 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
139144 assert abserr < (0.00175 if device in ("cpu" , "xpu" ) else 0.0023 ) # parenthesized: a bare conditional binds looser than `<`, so other devices asserted the truthy constant 0.0023 and never failed
140145 assert relerr < 0.012
141146 assert A2 .dtype == dtype
142- # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
143- # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
144-
145- @pytest .mark .parametrize ("device" , get_available_devices ())
146- def test_blockwise_cpu_large (self , device ):
147- if device == "xpu" :
148- pytest .skip ("XPU will not build CPU C++ codes" )
149147
148+ @pytest .mark .skipif ("cpu" not in get_available_devices (), reason = "CPU is required" )
149+ @pytest .mark .parametrize ("hidden" , [128 ])
150+ @pytest .mark .parametrize ("blocksize" , [4096 , 16384 ])
151+ def test_blockwise_cpu_large (self , hidden , blocksize ):
150152 diffs = []
151153 reldiffs = []
152154 batch = 128
153155 seq = 128
154- for hidden in [128 ]: # , 14336]:
155- for blocksize in [4096 , 16384 ]:
156- for i in range (2 ):
157- A1 = torch .randn (batch , seq , hidden , device = "cpu" )
158- t0 = time .time ()
159- C , S = F .quantize_blockwise (A1 , blocksize = blocksize )
160- A2 = F .dequantize_blockwise (C , S , blocksize = blocksize )
161- print (time .time () - t0 )
162- diff = torch .abs (A1 - A2 )
163- reldiff = diff / torch .abs (A1 + 1e-8 )
164- diffs .append (diff .mean ().item ())
165- reldiffs .append (reldiff .mean ().item ())
166- assert diffs [- 1 ] < 0.011
167- # print(sum(diffs)/len(diffs))
168- # print(sum(reldiffs)/len(reldiffs))
156+
157+ for i in range (2 ):
158+ A1 = torch .randn (batch , seq , hidden , device = "cpu" )
159+ t0 = time .time ()
160+ C , S = F .quantize_blockwise (A1 , blocksize = blocksize )
161+ A2 = F .dequantize_blockwise (C , S , blocksize = blocksize )
162+ print (time .time () - t0 )
163+ diff = torch .abs (A1 - A2 )
164+ reldiff = diff / torch .abs (A1 + 1e-8 )
165+ diffs .append (diff .mean ().item ())
166+ reldiffs .append (reldiff .mean ().item ())
167+ assert diffs [- 1 ] < 0.011
168+ # print(sum(diffs)/len(diffs))
169+ # print(sum(reldiffs)/len(reldiffs))
169170
170171 @pytest .mark .parametrize ("device" , get_available_devices ())
171172 @pytest .mark .parametrize ("bits" , range (2 , 9 ), ids = id_formatter ("bits" ))
0 commit comments