 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import ctypes as ct
+import itertools
 import operator
 import random
 from functools import reduce  # Required in Python 3
@@ -130,13 +131,59 @@ def get_instance(cls):
         return cls._instance
 
 
-def create_linear_map(signed=True):
-    if signed:
-        return torch.linspace(-1.0, 1.0, 256)
-    return torch.linspace(0.0, 1.0, 256)
+def create_linear_map(signed=True, total_bits=8):
+    sign = (-1.0 if signed else 0.0)
 
-
-def create_dynamic_map(signed=True, n=7):
+    values = torch.linspace(sign, 1.0, 2**total_bits)
+    gap = 256 - values.numel()
+    if gap == 0:
+        return values
+    else:
+        l = values.numel()//2
+        #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
+        return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
+
+
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
+    e = exponent_bits
+    p = precision_bits
+    has_sign = 1 if signed else 0
+    assert e + p == total_bits - has_sign
+    # the exponent is biased to 2^(e-1) - 1 == 0
+    evalues = []
+    pvalues = []
+    for i, val in enumerate(range(-(2**(exponent_bits - has_sign)), 2**(exponent_bits - has_sign), 1)):
+        evalues.append(2**val)
+
+
+    lst = list(itertools.product([0, 1], repeat=precision_bits))
+    for bit_pattern in lst:
+        value = 1
+        for i, pval in enumerate(list(bit_pattern)):
+            value += pval * (2**-(i + 1))
+        pvalues.append(value)
+
+    assert len(evalues) * len(pvalues) == 2**(total_bits - has_sign)
+    values = []
+    for ev in evalues:
+        for pv in pvalues:
+            if signed:
+                values.append(-ev * pv)
+            values.append(ev * pv)
+    if total_bits < 8:
+        gap = 256 - len(values)
+        for i in range(gap):
+            values.append(0)
+    values.sort()
+    code = torch.Tensor(values)
+    code /= code.max()
+    code[127] = 0
+
+    return code
+
+
+
+def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     """
     Creates the dynamic quantization map.
 
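For context on how these 256-entry codes are consumed, here is a standalone sketch (my own illustration, not part of the diff): values are scaled by their absolute maximum and snapped to the nearest entry of the code, which is the role the outputs of create_linear_map and create_fp8_map play for the blockwise kernels further down; the CUDA kernels do the equivalent lookup per block.

import torch

def quantize_with_code_sketch(x: torch.Tensor, code: torch.Tensor):
    # Scale into [-1, 1] by the absolute maximum, then snap every value
    # to the index of the nearest entry in the 256-value code.
    absmax = x.abs().max()
    scaled = (x / absmax).clamp(-1.0, 1.0)
    idx = (scaled.unsqueeze(-1) - code.view(1, -1)).abs().argmin(dim=-1)
    return idx.to(torch.uint8), absmax

def dequantize_with_code_sketch(idx: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor):
    # Inverse lookup: map each 8-bit index back to its code value and rescale.
    return code[idx.long()] * absmax

# Works with any of the maps above, e.g. a plain linear code:
code = torch.linspace(-1.0, 1.0, 256)
x = torch.randn(16)
idx, absmax = quantize_with_code_sketch(x, code)
x_hat = dequantize_with_code_sketch(idx, absmax, code)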
@@ -157,28 +204,32 @@ def create_dynamic_map(signed=True, n=7):
     # these are additional items that come from the case
     # where all the exponent bits are zero and no
     # indicator bit is present
-    additional_items = 2 ** (7 - n) - 1
+    non_sign_bits = total_bits - (1 if signed else 0)
+    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
     if not signed:
         additional_items = 2 * additional_items
-    for i in range(n):
-        fraction_items = (
-            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
-        )
+    for i in range(max_exponent_bits):
+        fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
-        data += ((10 ** (-(n - 1) + i)) * means).tolist()
+        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
         if signed:
-            data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
 
-    if additional_items > 0:
-        boundaries = torch.linspace(0.1, 1, additional_items + 1)
-        means = (boundaries[:-1] + boundaries[1:]) / 2.0
-        data += ((10 ** (-(n - 1) + i)) * means).tolist()
-        if signed:
-            data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+        if additional_items > 0:
+            boundaries = torch.linspace(0.1, 1, additional_items + 1)
+            means = (boundaries[:-1] + boundaries[1:]) / 2.0
+            data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
+            if signed:
+                data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
 
     data.append(0)
     data.append(1.0)
+
+    gap = 256 - len(data)
+    for i in range(gap):
+        data.append(0)
+
     data.sort()
     return Tensor(data)
 
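As a quick sanity check of the reworked create_dynamic_map, a hypothetical test snippet (it assumes this file is importable as bitsandbytes.functional; that import path is not shown in the diff): the returned map is sorted, contains 0 and 1.0, and is padded with zeros up to 256 entries even when total_bits < 8.

import torch
from bitsandbytes import functional as F  # assumption: this diff targets bitsandbytes/functional.py

code = F.create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8)
assert code.numel() == 256                      # full 8-bit code, no padding needed
assert code.max() == 1.0 and (code == 0).any()  # 1.0 and 0 are appended explicitly
assert torch.equal(code, code.sort().values)    # data.sort() runs before returning

# With fewer total bits, the real values are zero-padded up to 256 entries.
small = F.create_dynamic_map(signed=True, max_exponent_bits=3, total_bits=5)
assert small.numel() == 256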
@@ -322,9 +373,7 @@ def nvidia_transform(
     return out, new_state
 
 
-def estimate_quantiles(
-    A: Tensor, out: Tensor = None, offset: float = 1 / 512
-) -> Tensor:
+def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
     '''
     Estimates 256 equidistant quantiles on the input tensor eCDF.
 
@@ -344,25 +393,36 @@ def estimate_quantiles(
     out : torch.Tensor
         Tensor with the 256 estimated quantiles.
     offset : float
-        The offset for the first and last quantile from 0 and 1. Default: 1/512
+        The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
+    num_quantiles : int
+        The number of equally spaced quantiles.
 
     Returns
     -------
     torch.Tensor:
         The 256 quantiles in float32 datatype.
     '''
+    if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
+    if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
+    if num_quantiles < 256 and offset == 1/(512):
+        # override default arguments
+        offset = 1/(2*num_quantiles)
+
     if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
     is_on_gpu([A, out])
+    device = pre_call(A.device)
     if A.dtype == torch.float32:
-        lib.cestimate_quantiles_fp32(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     elif A.dtype == torch.float16:
-        lib.cestimate_quantiles_fp16(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     else:
         raise NotImplementedError(f"Not supported data type {A.dtype}")
+    post_call(device)
+
+    if num_quantiles < 256:
+        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
+        out = out[idx]
+
     return out
 
 
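Hypothetical usage of the new num_quantiles argument (needs a CUDA device and the compiled bitsandbytes binary; the import path is an assumption): fewer than 256 quantiles are produced by estimating the full 256 and then taking equally spaced indices, with offset defaulting to 1/(2*num_quantiles).

import torch
from bitsandbytes import functional as F  # assumption: this module is bitsandbytes.functional

A = torch.randn(1024 * 1024, device='cuda')       # needs at least 256 values

q256 = F.estimate_quantiles(A)                    # default: 256 quantiles
q16 = F.estimate_quantiles(A, num_quantiles=16)   # subsampled from the 256 estimates
assert q256.numel() == 256 and q16.numel() == 16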
@@ -395,15 +455,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         The quantization state to undo the quantization.
     """
 
+
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
-        code = code.to(A.device)
 
     if absmax is None:
         n = A.numel()
-        blocksize = (blocksize if A.device.type == 'cpu' else 4096)
         blocks = n // blocksize
         blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)
@@ -412,29 +471,33 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         out = torch.zeros_like(A, dtype=torch.uint8)
 
     if A.device.type != 'cpu':
-        is_on_gpu([code, A, absmax, out, rand])
+        assert blocksize in [4096, 2048, 1024, 512]
+        cblocksize = ct.c_int32(blocksize)
+        prev_device = pre_call(A.device)
+        code = code.to(A.device)
         if rand is not None:
+            is_on_gpu([code, A, out, absmax, rand])
+            assert blocksize == 4096
             assert rand.numel() >= 1024
             rand_offset = random.randint(0, 1023)
             if A.dtype == torch.float32:
                 lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
             elif A.dtype == torch.float16:
                 lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
             else:
-                raise ValueError(
-                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-                )
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         else:
+            is_on_gpu([code, A, out, absmax])
             if A.dtype == torch.float32:
-                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
             elif A.dtype == torch.float16:
-                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
+                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
             else:
-                raise ValueError(
-                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-                )
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+        post_call(A.device)
     else:
         # cpu
+        code = code.cpu()
         assert rand is None
         lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
 
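A hedged usage sketch for the extended GPU path (the import path and the (out, state) return shape are assumptions based on the docstring above; running it needs a CUDA device and the compiled binary): blocksize may now also be 512, 1024 or 2048 on GPU, with one absmax scale stored per block.

import torch
from bitsandbytes import functional as F  # assumption: this module is bitsandbytes.functional

A = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

# blocksize=512 is newly allowed on the GPU path; absmax gets one entry per 512-value block.
q, state = F.quantize_blockwise(A, blocksize=512)   # state assumed to be (absmax, code)
absmax, code = state
assert q.dtype == torch.uint8 and q.shape == A.shape
assert absmax.numel() == A.numel() // 512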
@@ -479,27 +542,30 @@ def dequantize_blockwise(
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
-        code = code.to(A.device)
 
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
     if quant_state is None:
         quant_state = (absmax, code)
+    else:
+        absmax, code = quant_state
 
 
     if A.device.type != 'cpu':
-        if blocksize not in [2048, 4096]:
-            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
+        device = pre_call(A.device)
+        code = code.to(A.device)
+        if blocksize not in [2048, 4096, 1024, 512]:
+            raise ValueError(f"The blocksize of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
         is_on_gpu([A, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         else:
-            raise ValueError(
-                f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-            )
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+        post_call(A.device)
     else:
+        code = code.cpu()
         lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
 
     return out
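And a matching round trip through the updated dequantize_blockwise (same assumptions as above; quant_state is the (absmax, code) tuple, as the unpacking added in this hunk shows). The same blocksize has to be passed again so the per-block scales line up.

import torch
from bitsandbytes import functional as F  # assumption: this module is bitsandbytes.functional

x = torch.randn(1024 * 1024, device='cuda')
q, state = F.quantize_blockwise(x, blocksize=2048)            # state assumed to be (absmax, code)
x_hat = F.dequantize_blockwise(q, quant_state=state, blocksize=2048)
mean_err = (x - x_hat).abs().mean()                           # small reconstruction error expected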