Commit d2bb34c

feat: Added Base PerChannelSTEQmax and PTnative classes
Signed-off-by: Brandon Groth <[email protected]>
1 parent 9dbd861 commit d2bb34c

File tree: 1 file changed (+268, -5 lines)

fms_mo/quant_refactor/per_channel_ste.py

Lines changed: 268 additions & 5 deletions
@@ -211,7 +211,7 @@ def forward(
 
         Args:
             ctx (torch.autograd.Function): Forward/Backward context object.
-            input_tensor_tensor (torch.FloatTensor): Tensor to be quantized.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
             num_bits (torch.IntTensor): Number of bits for quantization.
             clip_valn (torch.FloatTensor): Lower clip value bound.
             clip_val (torch.FloatTensor): Upper clip value bound.
@@ -283,7 +283,7 @@ def calc_qparams(
         n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
         scale = (clip_val - clip_valn) / n_levels
         zero_point = (
-            torch.zeros_like(scale)
+            torch.zeros_like(scale).to(torch.int)
             if symmetric
             else torch.round(-clip_valn / scale).to(torch.int)
         )
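The hunk above only changes the symmetric branch: zero_point is now cast to an integer tensor, matching the asymmetric branch. A self-contained sketch of the same affine-quantization math, with plain tensors and an invented function name for illustration:

import torch

def calc_qparams_sketch(num_bits, clip_valn, clip_val, symmetric=False, qlevel_lowering=True):
    # Number of representable steps between the clip bounds
    n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
    scale = (clip_val - clip_valn) / n_levels
    zero_point = (
        torch.zeros_like(scale).to(torch.int)  # int cast, as in this commit
        if symmetric
        else torch.round(-clip_valn / scale).to(torch.int)
    )
    return n_levels, scale, zero_point

# 8-bit over [-1.0, 1.0] with qlevel_lowering: 254 levels, zero_point = 127
n, s, z = calc_qparams_sketch(8, torch.tensor(-1.0), torch.tensor(1.0))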
@@ -422,14 +422,17 @@ def forward(
 
         Args:
             ctx (torch.autograd.Function): Forward/Backward context object.
-            input_tensor_tensor (torch.FloatTensor): Tensor to be quantized.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
             num_bits (torch.IntTensor): Number of bits for quantization.
             clip_valn (torch.FloatTensor): Lower clip value bound.
             clip_val (torch.FloatTensor): Upper clip value bound.
             dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            use_code (bool, optional): Specify a specific SAWB code for quantization.
+                Defaults to False.
+            axis (int, optional): Specify an axis to quantize. Defaults to default_axis.
 
         Returns:
             torch.Tensor: Dequantized or Quantized output tensor.
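For context, the linear_quantization() helper these forward methods call is not shown in the diff. What a linear (affine) quantizer of this shape typically computes, as a hedged sketch (the function name and quant_min/quant_max parameters here are illustrative, not the module's API):

import torch

def linear_quantization_sketch(x, scale, zero_point, quant_min, quant_max, dequantize=True):
    # Quantize: snap to the integer grid, then clamp into the allowed levels
    x_int = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    if not dequantize:
        return x_int.to(torch.int)
    # Dequantize: map the grid back to float (fake quantization for QAT)
    return (x_int - zero_point) * scale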
@@ -543,14 +546,15 @@ def forward(
 
         Args:
             ctx (torch.autograd.Function): Forward/Backward context object.
-            input_tensor_tensor (torch.FloatTensor): Tensor to be quantized.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
             num_bits (torch.IntTensor): Number of bits for quantization.
             clip_valn (torch.FloatTensor): Lower clip value bound.
             clip_val (torch.FloatTensor): Upper clip value bound.
             dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            use_code (bool, optional): Specify a specific SAWB code for quantization.
             axis (int, optional): Specify which tensor dimension to quantize individually.
                 Defaults to 0.
 
@@ -615,7 +619,7 @@ def calc_qparams(
         n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
         scale = (clip_val - clip_valn) / n_levels
         zero_point = (
-            torch.zeros_like(scale)
+            torch.zeros_like(scale).to(torch.int)
             if symmetric
             else torch.round(-clip_valn / scale).to(torch.int)
         )
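The per-channel classes added below rely on per_channel_axis() to make a length-C scale/zero_point vector broadcastable against the full tensor along the chosen axis. Its implementation is outside this diff; an assumed equivalent, for orientation only:

import torch

def per_channel_axis_sketch(scale, zero_point, tensor_shape, axis=0):
    # Reshape [C] -> [1, ..., C, ..., 1] so elementwise ops broadcast per channel
    view = [1] * len(tensor_shape)
    view[axis] = -1
    return scale.reshape(view), zero_point.reshape(view)

w = torch.randn(4, 16)                      # 4 channels along axis 0
s = w.abs().amax(dim=1) / 127.0             # one scale per channel
s_b, z_b = per_channel_axis_sketch(s, torch.zeros(4, dtype=torch.int), w.shape, axis=0)
w_int = torch.round(w / s_b)                # [4, 1] broadcasts over [4, 16]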
@@ -667,3 +671,262 @@ def backward(ctx, grad_output):
             torch.FloatTensor, None,...,None: STE Gradient
         """
         return grad_output, None, None, None, None, None, None
+
+
+class PerChannelSTEQmax(PerChannelSTE):
+    """
+    PerChannelSTE Base for Qmax
+
+    Extends:
+        PerChannelSTE
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        input_tensor: torch.FloatTensor,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        dequantize: bool = True,
+        symmetric: bool = False,
+        qlevel_lowering: bool = False,
+        use_minmax: bool = False,
+        axis: int = default_axis,
+    ):
+        """
+        General forward method:
+            Set clip values to dtype of the input tensor
+            Compute # of quantized levels, scale, and zero point
+            Save data for backward()
+            Perform linear quantization on the input tensor
+            return output
+
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value bound.
+            clip_val (torch.FloatTensor): Upper clip value bound.
+            dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to False.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+
+        Returns:
+            torch.Tensor: Dequantized or Quantized output tensor.
+        """
+        clip_valn, clip_val = transform_clips(input_tensor.dtype, clip_valn, clip_val)
+        n_levels, scale, zero_point = PerChannelSTEQmax.calc_qparams(
+            num_bits,
+            clip_valn,
+            clip_val,
+            symmetric,
+            qlevel_lowering,
+            axis,
+            input_tensor.shape,
+            use_minmax,
+        )
+        PerChannelSTE.save_tensors(
+            ctx,
+            tensors=(input_tensor, n_levels, clip_valn, clip_val, scale, zero_point),
+        )
+        output = linear_quantization(
+            input_tensor,
+            num_bits,
+            scale,
+            zero_point,
+            dequantize,
+            symmetric,
+            qlevel_lowering,
+        )
+        return output
+
+    @classmethod
+    def calc_qparams(
+        cls,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        symmetric: bool = False,
+        qlevel_lowering: bool = True,
+        axis: int = default_axis,
+        tensor_shape: torch.Size = None,
+        use_minmax: bool = False,
+    ):
+        """
+        Compute the scale and zero_point from num_bits and clip values
+
+        Args:
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value.
+            clip_val (torch.FloatTensor): Upper clip value.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+            use_minmax (bool, optional): Specify to use Qminmax. Defaults to False.
+
+        Returns:
+            torch.IntTensor, torch.FloatTensor, torch.IntTensor: Quantized parameters
+        """
+        if use_minmax:  # asymmetric case
+            n_levels = 2**num_bits - 1
+            _, scale, zero_point = asymmetric_linear_quantization_params(
+                num_bits, clip_valn, clip_val, qlevel_lowering=False
+            )
+        else:
+            n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
+            _, scale, zero_point = symmetric_linear_quantization_params(
+                num_bits, clip_val, qlevel_lowering
+            )
+
+        # Broadcast scale, zero_point to tensor shape
+        scale, zero_point = per_channel_axis(scale, zero_point, tensor_shape, axis)
+
+        return n_levels, scale, zero_point
+
+
+class PerChannelSTEQmax_PTnative(PerChannelSTE_PTnative):
+    """
+    PerChannelSTEQmax_PTnative Base for Qmax
+
+    Extends:
+        PerChannelSTE_PTnative
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        input_tensor: torch.FloatTensor,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        dequantize: bool = True,
+        symmetric: bool = False,
+        qlevel_lowering: bool = False,
+        use_minmax: bool = False,
+        axis: int = default_axis,
+    ):
+        """
+        General forward method:
+            Set clip values to dtype of the input tensor
+            Compute # of quantized levels, scale, and zero point
+            Save data for backward()
+            Perform linear quantization on the input tensor
+            return output
+
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value bound.
+            clip_val (torch.FloatTensor): Upper clip value bound.
+            dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to False.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+
+        Returns:
+            torch.Tensor: Dequantized or Quantized output tensor.
+        """
+        clip_valn, clip_val = transform_clips(
+            input_tensor.dtype,
+            clip_valn,
+            clip_val,
+        )
+        (
+            _,
+            scale,
+            zero_point,
+            qint_l,
+            qint_h,
+            qint_dtype,
+        ) = PerChannelSTEQmax_PTnative.calc_qparams(
+            num_bits,
+            clip_valn,
+            clip_val,
+            symmetric,
+            qlevel_lowering,
+            axis=axis,
+            tensor_shape=input_tensor.shape,
+            use_minmax=use_minmax,
+        )
+        output = PerChannelSTE_PTnative.linear_quantization(
+            input_tensor,
+            scale,
+            zero_point,
+            qint_l,
+            qint_h,
+            qint_dtype,
+            dequantize,
+            axis,
+        )
+        return output
+
+    @classmethod
+    def calc_qparams(
+        cls,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        symmetric: bool = False,
+        qlevel_lowering: bool = True,
+        axis: int = default_axis,
+        tensor_shape: torch.Size = None,
+        use_minmax: bool = False,
+    ):
+        """
+        Compute the scale and zero_point from num_bits and clip values
+
+        Args:
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value.
+            clip_val (torch.FloatTensor): Upper clip value.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+            use_minmax (bool, optional): Specify to use Qminmax. Defaults to False.
+
+        Returns:
+            torch.IntTensor, torch.FloatTensor, torch.IntTensor: Quantized parameters
+        """
+        if use_minmax:  # asymmetric case
+            n_levels = 2**num_bits - 1
+            _, scale, zero_point = asymmetric_linear_quantization_params(
+                num_bits, clip_valn, clip_val, qlevel_lowering=False
+            )
+        else:
+            n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
+            _, scale, zero_point = symmetric_linear_quantization_params(
+                num_bits, clip_val, qlevel_lowering
+            )
+
+        qint_min, qint_max, qint_dtype = PerChannelSTE_PTnative.qint_bounds(
+            num_bits, zero_point, symmetric, qlevel_lowering
+        )
+
+        # Note: PTnative doesn't require broadcast for scale/zero_point
+        return n_levels, scale, zero_point, qint_min, qint_max, qint_dtype
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        General STE backward method:
+            Return grad_output + None args to match forward inputs
+
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            grad_output (torch.FloatTensor): Gradient tensor
+
+        Returns:
+            torch.FloatTensor, None,...,None: STE Gradient
+        """
+        return grad_output, None, None, None, None, None, None, None, None
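Both new classes are torch.autograd.Function subclasses, so call sites would go through .apply() with positional arguments in the forward() order above. A hypothetical usage sketch, assuming the module imports resolve; the clip values and bit-width are made up for illustration:

import torch
from fms_mo.quant_refactor.per_channel_ste import PerChannelSTEQmax

w = torch.randn(8, 32, requires_grad=True)   # quantize per channel along axis 0
clip_val = w.detach().abs().amax(dim=1)      # one upper clip per output channel
clip_valn = -clip_val                        # symmetric lower clips

w_fq = PerChannelSTEQmax.apply(
    w,
    torch.tensor(8),  # num_bits
    clip_valn,
    clip_val,
    True,             # dequantize
    True,             # symmetric
    True,             # qlevel_lowering
    False,            # use_minmax
    0,                # axis
)
w_fq.sum().backward()  # straight-through estimator: w.grad is all ones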

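The PTnative variant presumably wraps PyTorch's native per-channel fake-quant kernel inside PerChannelSTE_PTnative.linear_quantization, which is not shown in this diff. For reference, the corresponding core-PyTorch call, assuming 8-bit symmetric qint8 bounds:

import torch

x = torch.randn(4, 16)
scale = torch.full((4,), 0.02)                   # one scale per channel
zero_point = torch.zeros(4, dtype=torch.int32)   # integer zero points
# args: input, scale, zero_point, axis, quant_min, quant_max
x_fq = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 0, -128, 127)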