     per_channel_axis,
 )

+default_axis = int(0)

 class PerChannelSTE(torch.autograd.Function):
3334 """Base class for customized forward/backward functions that is NOT using PT native func.
@@ -51,7 +52,7 @@ def forward(
         dequantize: bool = True,
         symmetric: bool = False,
         qlevel_lowering: bool = False,
-        axis: int = 0,
+        axis: int = default_axis,
     ):
         """
         General forward method:
@@ -104,7 +105,7 @@ def calc_qparams(
         clip_val: torch.FloatTensor,
         symmetric: bool = False,
         qlevel_lowering: bool = True,
-        axis: int = 0,
+        axis: int = default_axis,
         tensor_shape: torch.Size = None,
     ):
         """
@@ -195,7 +196,7 @@ def forward(
         dequantize: bool = True,
         symmetric: bool = False,
         qlevel_lowering: bool = False,
-        axis: int = 0,
+        axis: int = default_axis,
     ):
         """
         General forward method:
@@ -293,7 +294,7 @@ def qint_bounds(
         num_bits_int = (
             num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
         )
-        if symmetric and zero_point == 0:
+        if symmetric and torch.sum(zero_point) == 0:
             qlevel_symmetric = 1 if qlevel_lowering else 0
             qint_l, qint_h = (
                 -(2 ** (num_bits_int - 1)) + qlevel_symmetric,
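The substantive fix in this hunk: with per-channel quantization, zero_point is a tensor with one entry per channel, so the old scalar test zero_point == 0 yields an element-wise boolean tensor, and truth-testing it raises a RuntimeError once there is more than one channel. A minimal sketch of the failure mode and the reduction used here:

    import torch

    zp = torch.zeros(4, dtype=torch.int32)  # per-channel zero points, one per channel
    # bool(zp == 0) raises: "Boolean value of Tensor with more than one element is ambiguous"
    print(bool(torch.sum(zp) == 0))  # True: summing reduces the all-zero check to one scalar

The sum test behaves like torch.all(zero_point == 0) in practice, since the asymmetric branch produces non-negative zero points that cannot cancel.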
@@ -315,7 +316,7 @@ def linear_quantization(
         qint_h: int,
         qint_dtype: torch.dtype,
         dequantize: bool = True,
-        axis: int = 0,
+        axis: int = default_axis,
     ) -> torch.Tensor:
         """
         Linear quantization for PTnative STE
@@ -392,7 +393,7 @@ def forward(
         symmetric: bool = False,
         qlevel_lowering: bool = False,
         use_code: bool = False,
-        axis: int = 0,
+        axis: int = default_axis,
     ):
         """
         General forward method:
@@ -450,7 +451,7 @@ def calc_qparams(
         clip_val: torch.FloatTensor,
         symmetric: bool = False,
         qlevel_lowering: bool = True,
-        axis: int = 0,
+        axis: int = default_axis,
         tensor_shape: torch.Size = None,
         use_code: bool = False,
     ):
@@ -486,4 +487,154 @@ def calc_qparams(
             output = n_levels, scale, zero_point
         else:
             raise ValueError("SAWB has non-symmetric Qscheme")
-        return output
+        return output
+
+class PerChannelSTESAWB_PTnative(PerChannelSTE_PTnative):
493+ """Base class for customized forward/backward functions.
494+ There's a family of non-learnable quantizers, such as SAWB, MinMax,
495+ whose forward can leverage PT native functions and backward is simply STE.
496+ We just need to calculate scale in the upper level quantizer class then those quantizers
497+ could all be using the same base "STE function"
498+
499+ math should be consistent with pytorch: https://pytorch.org/docs/stable/quantization.html
500+ x_int = round(x/scale + zp)
501+ x_dq = (x_int - zp) * scale
502+
503+ This type of class will be used by Quantizer.forward(), e.g.
504+ """
505+
+    @staticmethod
+    def forward(
+        ctx,
+        input_tensor: torch.FloatTensor,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        dequantize: bool = True,
+        symmetric: bool = False,
+        qlevel_lowering: bool = False,
+        use_code: bool = False,
+        axis: int = default_axis,
+    ):
519+ """
520+ General forward method:
521+ Set clip values to dtype of input_tensor tensor
522+ Compute # of quantized levels, scale, and zero point
523+ Perform PTnative linear quantization on input_tensor tensor
524+ return output
525+
526+ Args:
527+ ctx (torch.autograd.Function): Forward/Backward context object.
528+ input_tensor_tensor (torch.FloatTensor): Tensor to be quantized.
529+ num_bits (torch.IntTensor): Number of bit for quantization.
530+ clip_valn (torch.FloatTensor): Lower clip value bound.
531+ clip_val (torch.FloatTensor): Upper clip value bound.
532+ dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
533+ symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
534+ qlevel_lowering (bool, optional): Specify lowering of quantized levels.
535+ Defaults to True.
536+ axis (int, optional): Specify which tensor dimension to quantize indiviually.
537+ Defaults to 0.
538+
539+ Returns:
540+ torch.Tensor: Dequantized or Quantized output tensor.
541+ """
+        clip_valn, clip_val = transform_clips(
+            input_tensor.dtype,
+            clip_valn,
+            clip_val,
+        )
+        (
+            _,
+            scale,
+            zero_point,
+            qint_l,
+            qint_h,
+            qint_dtype,
+        ) = PerChannelSTESAWB_PTnative.calc_qparams(
+            num_bits, clip_valn, clip_val, symmetric, qlevel_lowering,
+        )
+        output = PerChannelSTE_PTnative.linear_quantization(
+            input_tensor, scale, zero_point, qint_l, qint_h, qint_dtype, dequantize, axis
+        )
+        return output
+
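For orientation, a minimal sketch (illustrative shapes and values, not part of the PR) of the PT-native per-channel call that linear_quantization ultimately wraps, matching the math in the class docstring:

    import torch

    x = torch.randn(4, 8)                           # axis 0 = 4 channels
    scale = torch.full((4,), 0.1)                   # one scale per channel
    zero_point = torch.zeros(4, dtype=torch.int32)  # one zero point per channel
    x_dq = torch.fake_quantize_per_channel_affine(
        x, scale, zero_point, axis=0, quant_min=-127, quant_max=127
    )
    # per channel: x_dq = (clamp(round(x / scale + zp), quant_min, quant_max) - zp) * scale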
+    @classmethod
+    def calc_qparams(
+        cls,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        symmetric: bool = False,
+        qlevel_lowering: bool = False,
+    ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.IntTensor, int, int, torch.dtype]:
571+ """
572+ Compute the scale and zero_point from num_bits and clip values.
573+ Also, compute qint bounds for PT clamping.
574+
575+ Args:
576+ num_bits (torch.IntTensor): Number of bit for quantization.
577+ clip_valn (torch.FloatTensor): Lower clip value.
578+ clip_val (torch.FloatTensor): Upper clip value.
579+ symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
580+ qlevel_lowering (bool, optional): Specify lowering of quantized levels.
581+ Defaults to True.
582+
583+ Returns:
584+ Tuple[torch.IntTensor, torch.FloatTensor, torch.IntTensor]: Quantized parameters
585+ """
+        n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
+        scale = (clip_val - clip_valn) / n_levels
+        zero_point = (
+            torch.zeros_like(scale)
+            if symmetric
+            else torch.round(-clip_valn / scale).to(torch.int)
+        )
+        qint_l, qint_h, qint_dtype = PerChannelSTESAWB_PTnative.qint_bounds(
+            num_bits, zero_point, symmetric, qlevel_lowering
+        )
+        # Note: fake_quantize_per_channel_affine does not require scale/zero_point
+        # dimensions to match the input tensor
+        return n_levels, scale, zero_point, qint_l, qint_h, qint_dtype
+
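A worked example of the qparams math above (illustrative clip values):

    import torch

    num_bits = 8
    clip_valn, clip_val = torch.tensor(-0.5), torch.tensor(1.5)
    n_levels = 2**num_bits - 1                    # 255 (qlevel_lowering=False)
    scale = (clip_val - clip_valn) / n_levels     # 2/255 ≈ 0.00784
    zero_point = torch.round(-clip_valn / scale).to(torch.int)  # round(63.75) = 64

In the symmetric case with qlevel_lowering=True and clips (-c, c), this reduces to n_levels = 254, scale = c / 127, and zero_point = 0.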
+    @classmethod
+    def qint_bounds(
+        cls,
+        num_bits: torch.IntTensor,
+        zero_point: torch.IntTensor,
+        symmetric: bool = False,
+        qlevel_lowering: bool = True,
+    ) -> Tuple[int, int, torch.dtype]:
+        """
+        qlevel_symmetric: shift qlevels from [-2**(b-1), 2**(b-1)-1] to [-2**(b-1)+1, 2**(b-1)-1].
+        For int8: [-127, 127]; for int4: [-7, 7].
+        qint bounds must be ints, not tensors.
+        """
+        num_bits_int = (
+            num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
+        )
+        if symmetric and torch.sum(zero_point) == 0:
+            qlevel_symmetric = 1 if qlevel_lowering else 0
+            qint_l, qint_h = (
+                -(2 ** (num_bits_int - 1)) + qlevel_symmetric,
+                2 ** (num_bits_int - 1) - 1,
+            )
+            qint_dtype = torch.qint8
+        else:  # single-sided or zero_point != 0
+            qint_l, qint_h = 0, 2**num_bits_int - 1
+            qint_dtype = torch.quint8
+        return qint_l, qint_h, qint_dtype
+
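The branches above yield the following integer ranges, recomputed standalone for 8 bits:

    num_bits = 8
    for qlevel_lowering in (True, False):
        qint_l = -(2 ** (num_bits - 1)) + (1 if qlevel_lowering else 0)
        qint_h = 2 ** (num_bits - 1) - 1
        print(qint_l, qint_h)   # (-127, 127) with lowering, (-128, 127) without
    print(0, 2**num_bits - 1)   # (0, 255) for the asymmetric/quint8 branch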
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        General STE backward method:
+            Return grad_output + None args to match the forward inputs
+
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            grad_output (torch.FloatTensor): Gradient tensor
+
+        Returns:
+            torch.FloatTensor, None, ..., None: STE Gradient
+        """
+        # one gradient per forward arg after ctx: input_tensor gets grad_output,
+        # the remaining eight (num_bits ... axis) get None
+        return grad_output, None, None, None, None, None, None, None, None
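The backward is the standard straight-through estimator: the gradient flows through the quantizer as identity to input_tensor, and every other forward argument gets None. A self-contained sketch of the same pattern (a hypothetical RoundSTE, not code from this PR):

    import torch

    class RoundSTE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor, scale: float = 1.0):
            return torch.round(x / scale) * scale  # non-differentiable rounding

        @staticmethod
        def backward(ctx, grad_output):
            # straight-through: pass the gradient to x unchanged, None for scale
            return grad_output, None

    x = torch.randn(3, requires_grad=True)
    RoundSTE.apply(x, 0.5).sum().backward()
    print(x.grad)  # tensor([1., 1., 1.]): round() treated as identity in backward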