@@ -211,7 +211,7 @@ def forward(
 
         Args:
             ctx (torch.autograd.Function): Forward/Backward context object.
-            input_tensor_tensor (torch.FloatTensor): Tensor to be quantized.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
             num_bits (torch.IntTensor): Number of bits for quantization.
             clip_valn (torch.FloatTensor): Lower clip value bound.
             clip_val (torch.FloatTensor): Upper clip value bound.
@@ -283,7 +283,7 @@ def calc_qparams(
         n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
         scale = (clip_val - clip_valn) / n_levels
         zero_point = (
-            torch.zeros_like(scale)
+            torch.zeros_like(scale).to(torch.int)
             if symmetric
             else torch.round(-clip_valn / scale).to(torch.int)
         )
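For context, a minimal self-contained sketch (not part of the patch) of the affine quantization parameters this hunk computes; the `.to(torch.int)` change makes the symmetric zero point an integer tensor, matching the asymmetric branch:

```python
import torch

def qparams_sketch(num_bits, clip_valn, clip_val, symmetric=False, qlevel_lowering=True):
    # Number of quantization steps between the clip bounds.
    n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
    scale = (clip_val - clip_valn) / n_levels
    # Symmetric quantization keeps zero_point at 0; asymmetric shifts by clip_valn.
    zero_point = (
        torch.zeros_like(scale).to(torch.int)
        if symmetric
        else torch.round(-clip_valn / scale).to(torch.int)
    )
    return n_levels, scale, zero_point

# 8-bit asymmetric clips [-1, 5]: n_levels=254, scale ~= 0.0236, zero_point = 42
print(qparams_sketch(8, torch.tensor(-1.0), torch.tensor(5.0)))
```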
@@ -422,14 +422,17 @@ def forward(
 
         Args:
             ctx (torch.autograd.Function): Forward/Backward context object.
-            input_tensor_tensor (torch.FloatTensor): Tensor to be quantized.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
             num_bits (torch.IntTensor): Number of bits for quantization.
             clip_valn (torch.FloatTensor): Lower clip value bound.
             clip_val (torch.FloatTensor): Upper clip value bound.
             dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            use_code (bool, optional): Specify use of a specific SAWB code for quantization.
+                Defaults to False.
+            axis (int, optional): Specify an axis to quantize. Defaults to default_axis.
 
         Returns:
             torch.Tensor: Dequantized or quantized output tensor.
@@ -543,14 +546,15 @@ def forward(
 
         Args:
             ctx (torch.autograd.Function): Forward/Backward context object.
-            input_tensor_tensor (torch.FloatTensor): Tensor to be quantized.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
             num_bits (torch.IntTensor): Number of bits for quantization.
             clip_valn (torch.FloatTensor): Lower clip value bound.
             clip_val (torch.FloatTensor): Upper clip value bound.
             dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            use_code (bool, optional): Specify use of a specific SAWB code for quantization.
+                Defaults to False.
             axis (int, optional): Specify which tensor dimension to quantize individually.
                 Defaults to 0.
 
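The `axis` argument above selects which tensor dimension gets its own quantizer. A rough illustration (names are illustrative, not from the patch) of deriving per-channel clip values along axis 0:

```python
import torch

w = torch.randn(4, 16)  # 4 output channels
# Reduce over every dimension except the per-channel axis (0 here),
# yielding one clip pair per channel.
reduce_dims = [d for d in range(w.dim()) if d != 0]
clip_val = w.amax(dim=reduce_dims)           # shape (4,)
clip_valn = w.amin(dim=reduce_dims)          # shape (4,)
scale = (clip_val - clip_valn) / (2**8 - 1)  # one scale per channel
```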
@@ -615,7 +619,7 @@ def calc_qparams(
         n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
         scale = (clip_val - clip_valn) / n_levels
         zero_point = (
-            torch.zeros_like(scale)
+            torch.zeros_like(scale).to(torch.int)
             if symmetric
             else torch.round(-clip_valn / scale).to(torch.int)
         )
@@ -667,3 +671,262 @@ def backward(ctx, grad_output):
             torch.FloatTensor, None,...,None: STE Gradient
         """
         return grad_output, None, None, None, None, None, None
+
+
+class PerChannelSTEQmax(PerChannelSTE):
+    """
+    PerChannelSTE Base for Qmax
+
+    Extends:
+        PerChannelSTE
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        input_tensor: torch.FloatTensor,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        dequantize: bool = True,
+        symmetric: bool = False,
+        qlevel_lowering: bool = False,
+        use_minmax: bool = False,
+        axis: int = default_axis,
+    ):
+        """
+        General forward method:
+            Set clip values to dtype of the input tensor
+            Compute # of quantized levels, scale, and zero point
+            Save data for backward()
+            Perform linear quantization on the input tensor
+            Return output
+
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value bound.
+            clip_val (torch.FloatTensor): Upper clip value bound.
+            dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to False.
+            use_minmax (bool, optional): Specify to use Qminmax. Defaults to False.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+
+        Returns:
+            torch.Tensor: Dequantized or quantized output tensor.
+        """
+        clip_valn, clip_val = transform_clips(input_tensor.dtype, clip_valn, clip_val)
+        n_levels, scale, zero_point = PerChannelSTEQmax.calc_qparams(
+            num_bits,
+            clip_valn,
+            clip_val,
+            symmetric,
+            qlevel_lowering,
+            axis,
+            input_tensor.shape,
+            use_minmax,
+        )
+        PerChannelSTE.save_tensors(
+            ctx,
+            tensors=(input_tensor, n_levels, clip_valn, clip_val, scale, zero_point),
+        )
+        output = linear_quantization(
+            input_tensor,
+            num_bits,
+            scale,
+            zero_point,
+            dequantize,
+            symmetric,
+            qlevel_lowering,
+        )
+        return output
+
+    @classmethod
+    def calc_qparams(
+        cls,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        symmetric: bool = False,
+        qlevel_lowering: bool = True,
+        axis: int = default_axis,
+        tensor_shape: torch.Size = None,
+        use_minmax: bool = False,
+    ):
+        """
+        Compute the scale and zero_point from num_bits and clip values
+
+        Args:
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value.
+            clip_val (torch.FloatTensor): Upper clip value.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+            tensor_shape (torch.Size, optional): Shape used to broadcast scale/zero_point.
+                Defaults to None.
+            use_minmax (bool, optional): Specify to use Qminmax. Defaults to False.
+
+        Returns:
+            torch.IntTensor, torch.FloatTensor, torch.IntTensor: Quantized parameters
+        """
+        if use_minmax:  # asymmetric case
+            n_levels = 2**num_bits - 1
+            _, scale, zero_point = asymmetric_linear_quantization_params(
+                num_bits, clip_valn, clip_val, qlevel_lowering=False
+            )
+        else:
+            n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
+            _, scale, zero_point = symmetric_linear_quantization_params(
+                num_bits, clip_val, qlevel_lowering
+            )
+
+        # Broadcast scale, zero_point to tensor shape
+        scale, zero_point = per_channel_axis(scale, zero_point, tensor_shape, axis)
+
+        return n_levels, scale, zero_point
+
+
+class PerChannelSTEQmax_PTnative(PerChannelSTE_PTnative):
+    """
+    PerChannelSTEQmax_PTnative Base for Qmax
+
+    Extends:
+        PerChannelSTE_PTnative
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        input_tensor: torch.FloatTensor,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        dequantize: bool = True,
+        symmetric: bool = False,
+        qlevel_lowering: bool = False,
+        use_minmax: bool = False,
+        axis: int = default_axis,
+    ):
+        """
+        General forward method:
+            Set clip values to dtype of the input tensor
+            Compute # of quantized levels, scale, and zero point
+            Save data for backward()
+            Perform linear quantization on the input tensor
+            Return output
+
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value bound.
+            clip_val (torch.FloatTensor): Upper clip value bound.
+            dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to False.
+            use_minmax (bool, optional): Specify to use Qminmax. Defaults to False.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+
+        Returns:
+            torch.Tensor: Dequantized or quantized output tensor.
+        """
+        clip_valn, clip_val = transform_clips(
+            input_tensor.dtype,
+            clip_valn,
+            clip_val,
+        )
+        (
+            _,
+            scale,
+            zero_point,
+            qint_l,
+            qint_h,
+            qint_dtype,
+        ) = PerChannelSTEQmax_PTnative.calc_qparams(
+            num_bits,
+            clip_valn,
+            clip_val,
+            symmetric,
+            qlevel_lowering,
+            axis=axis,
+            tensor_shape=input_tensor.shape,
+            use_minmax=use_minmax,
+        )
+        output = PerChannelSTE_PTnative.linear_quantization(
+            input_tensor,
+            scale,
+            zero_point,
+            qint_l,
+            qint_h,
+            qint_dtype,
+            dequantize,
+            axis,
+        )
+        return output
+
+    @classmethod
+    def calc_qparams(
+        cls,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        symmetric: bool = False,
+        qlevel_lowering: bool = True,
+        axis: int = default_axis,
+        tensor_shape: torch.Size = None,
+        use_minmax: bool = False,
+    ):
+        """
+        Compute the scale and zero_point from num_bits and clip values
+
+        Args:
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value.
+            clip_val (torch.FloatTensor): Upper clip value.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+            tensor_shape (torch.Size, optional): Shape of the tensor to be quantized.
+                Defaults to None.
+            use_minmax (bool, optional): Specify to use Qminmax. Defaults to False.
+
+        Returns:
+            torch.IntTensor, torch.FloatTensor, torch.IntTensor, int, int, torch.dtype:
+                Quantized parameters and integer bounds/dtype.
+        """
+        if use_minmax:  # asymmetric case
+            n_levels = 2**num_bits - 1
+            _, scale, zero_point = asymmetric_linear_quantization_params(
+                num_bits, clip_valn, clip_val, qlevel_lowering=False
+            )
+        else:
+            n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
+            _, scale, zero_point = symmetric_linear_quantization_params(
+                num_bits, clip_val, qlevel_lowering
+            )
+
+        qint_min, qint_max, qint_dtype = PerChannelSTE_PTnative.qint_bounds(
+            num_bits, zero_point, symmetric, qlevel_lowering
+        )
+
+        # Note: PTnative doesn't require broadcast for scale/zero_point
+        return n_levels, scale, zero_point, qint_min, qint_max, qint_dtype
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        General STE backward method:
+            Return grad_output + None args to match forward inputs
+
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            grad_output (torch.FloatTensor): Gradient tensor
+
+        Returns:
+            torch.FloatTensor, None,...,None: STE Gradient
+        """
+        # forward() takes nine inputs after ctx, so return nine gradients.
+        return grad_output, None, None, None, None, None, None, None, None
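
The `per_channel_axis` helper is not defined in this diff; presumably it reshapes the length-C `scale`/`zero_point` vectors so they broadcast against the input tensor. A sketch of that assumed behavior (function name reused for illustration only):

```python
import torch

def per_channel_axis_sketch(scale, zero_point, tensor_shape, axis):
    # Reshape length-C vectors to (1, ..., C, ..., 1) so they broadcast
    # against a tensor whose size along `axis` is C.
    view = [1] * len(tensor_shape)
    view[axis] = tensor_shape[axis]
    return scale.reshape(view), zero_point.reshape(view)

s, zp = per_channel_axis_sketch(
    torch.rand(4), torch.zeros(4, dtype=torch.int), torch.Size([4, 16]), axis=0
)
print(s.shape, zp.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
```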
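`PerChannelSTE_PTnative.linear_quantization` is also outside this diff; for the `dequantize=True` path it presumably wraps PyTorch's built-in per-channel fake quantization, roughly along these lines:

```python
import torch

x = torch.randn(4, 16)
scale = torch.full((4,), 0.02)                  # one scale per channel
zero_point = torch.zeros(4, dtype=torch.int32)  # integer zero points

# Quantize to integer levels in [quant_min, quant_max] along axis 0,
# then dequantize back to float in one fused op.
y = torch.fake_quantize_per_channel_affine(
    x, scale, zero_point, axis=0, quant_min=-127, quant_max=127
)
```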
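Since these classes are `torch.autograd.Function` subclasses, a call site would invoke them through `.apply` with the positional arguments of `forward`. A hypothetical usage sketch (bit width and clip values are illustrative, and the inherited `PerChannelSTE` backward is assumed to be a plain straight-through estimator):

```python
import torch

w = torch.randn(4, 16, requires_grad=True)
w_q = PerChannelSTEQmax.apply(
    w,                       # input_tensor
    torch.tensor(8),         # num_bits
    w.amin(dim=1).detach(),  # clip_valn: per-channel lower bounds
    w.amax(dim=1).detach(),  # clip_val: per-channel upper bounds
    True,                    # dequantize
    True,                    # symmetric
    False,                   # qlevel_lowering
    False,                   # use_minmax
    0,                       # axis
)
# With an STE backward, gradients flow through unchanged:
w_q.sum().backward()         # w.grad mirrors grad_output (all ones here)
```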