
Commit a852ca8

feat: Added SAWB perCh PTnative and removed 16 bins
Signed-off-by: Brandon Groth <[email protected]>
1 parent 87f6907 commit a852ca8

2 files changed (+164, -86 lines)


fms_mo/quant_refactor/per_channel_ste.py

Lines changed: 159 additions & 8 deletions
@@ -28,6 +28,7 @@
     per_channel_axis,
 )
 
+default_axis = int(0)
 
 class PerChannelSTE(torch.autograd.Function):
     """Base class for customized forward/backward functions that is NOT using PT native func.
@@ -51,7 +52,7 @@ def forward(
         dequantize: bool = True,
         symmetric: bool = False,
         qlevel_lowering: bool = False,
-        axis: int = 0,
+        axis: int = default_axis,
     ):
         """
         General forward method:
@@ -104,7 +105,7 @@ def calc_qparams(
         clip_val: torch.FloatTensor,
         symmetric: bool = False,
         qlevel_lowering: bool = True,
-        axis: int = 0,
+        axis: int = default_axis,
         tensor_shape: torch.Size = None,
     ):
         """
@@ -195,7 +196,7 @@ def forward(
         dequantize: bool = True,
         symmetric: bool = False,
         qlevel_lowering: bool = False,
-        axis: int = 0,
+        axis: int = default_axis,
     ):
         """
         General forward method:
@@ -293,7 +294,7 @@ def qint_bounds(
         num_bits_int = (
             num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
         )
-        if symmetric and zero_point == 0:
+        if symmetric and torch.sum(zero_point) == 0:
             qlevel_symmetric = 1 if qlevel_lowering else 0
             qint_l, qint_h = (
                 -(2 ** (num_bits_int - 1)) + qlevel_symmetric,
@@ -315,7 +316,7 @@ def linear_quantization(
         qint_h: int,
         qint_dtype: torch.dtype,
         dequantize: bool = True,
-        axis: int = 0,
+        axis: int = default_axis,
     ) -> torch.Tensor:
         """
         Linear quantization for PTnative STE
@@ -392,7 +393,7 @@ def forward(
         symmetric: bool = False,
         qlevel_lowering: bool = False,
         use_code: bool = False,
-        axis: int = 0,
+        axis: int = default_axis,
     ):
         """
         General forward method:
@@ -450,7 +451,7 @@ def calc_qparams(
         clip_val: torch.FloatTensor,
         symmetric: bool = False,
         qlevel_lowering: bool = True,
-        axis: int = 0,
+        axis: int = default_axis,
         tensor_shape: torch.Size = None,
         use_code: bool = False,
     ):
@@ -486,4 +487,154 @@ def calc_qparams(
             output = n_levels, scale, zero_point
         else:
             raise ValueError("SAWB has non-symmetric Qscheme")
-        return output
+        return output
+
+class PerChannelSTESAWB_PTnative(PerChannelSTE_PTnative):
+    """Base class for customized forward/backward functions.
+    There's a family of non-learnable quantizers, such as SAWB, MinMax,
+    whose forward can leverage PT native functions and backward is simply STE.
+    We just need to calculate scale in the upper-level quantizer class, then those quantizers
+    can all use the same base "STE function".
+
+    Math should be consistent with PyTorch: https://pytorch.org/docs/stable/quantization.html
+        x_int = round(x/scale + zp)
+        x_dq = (x_int - zp) * scale
+
+    This type of class will be used by Quantizer.forward().
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        input_tensor: torch.FloatTensor,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        dequantize: bool = True,
+        symmetric: bool = False,
+        qlevel_lowering: bool = False,
+        use_code: bool = False,
+        axis: int = default_axis,
+    ):
+        """
+        General forward method:
+            Cast clip values to the dtype of the input tensor
+            Compute # of quantized levels, scale, and zero point
+            Perform PTnative linear quantization on the input tensor
+            Return output
+
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value bound.
+            clip_val (torch.FloatTensor): Upper clip value bound.
+            dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to False.
+            use_code (bool, optional): Specify using SAWB code. Defaults to False.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+
+        Returns:
+            torch.Tensor: Dequantized or Quantized output tensor.
+        """
+        clip_valn, clip_val = transform_clips(
+            input_tensor.dtype,
+            clip_valn,
+            clip_val,
+        )
+        (
+            _,
+            scale,
+            zero_point,
+            qint_l,
+            qint_h,
+            qint_dtype,
+        ) = PerChannelSTESAWB_PTnative.calc_qparams(
+            num_bits, clip_valn, clip_val, symmetric, qlevel_lowering,
+        )
+        output = PerChannelSTE_PTnative.linear_quantization(
+            input_tensor, scale, zero_point, qint_l, qint_h, qint_dtype, dequantize, axis
+        )
+        return output
+
+    @classmethod
+    def calc_qparams(
+        cls,
+        num_bits: torch.IntTensor,
+        clip_valn: torch.FloatTensor,
+        clip_val: torch.FloatTensor,
+        symmetric: bool = False,
+        qlevel_lowering: bool = False,
+    ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.IntTensor, int, int]:
+        """
+        Compute the scale and zero_point from num_bits and clip values.
+        Also, compute qint bounds for PT clamping.
+
+        Args:
+            num_bits (torch.IntTensor): Number of bits for quantization.
+            clip_valn (torch.FloatTensor): Lower clip value.
+            clip_val (torch.FloatTensor): Upper clip value.
+            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+                Defaults to False.
+
+        Returns:
+            Tuple[torch.IntTensor, torch.FloatTensor, torch.IntTensor]: Quantized parameters
+        """
+        n_levels = 2**num_bits - 2 if qlevel_lowering else 2**num_bits - 1
+        scale = (clip_val - clip_valn) / n_levels
+        zero_point = (
+            torch.zeros_like(scale)
+            if symmetric
+            else torch.round(-clip_valn / scale).to(torch.int)
+        )
+        qint_l, qint_h, qint_dtype = PerChannelSTESAWB_PTnative.qint_bounds(
+            num_bits, zero_point, symmetric, qlevel_lowering
+        )
+        # Note: fake_quantize_per_channel_affine does not need matching dimensions for scale/zp to tensor
+        return n_levels, scale, zero_point, qint_l, qint_h, qint_dtype
+
+    @classmethod
+    def qint_bounds(
+        cls,
+        num_bits: torch.IntTensor,
+        zero_point: torch.IntTensor,
+        symmetric: bool = False,
+        qlevel_lowering: bool = True,
+    ) -> Tuple[int, int, torch.dtype]:
+        """
+        qlevel_symmetric: shift qlevel from [-2**(b-1), 2**(b-1)-1] to [-2**(b-1)+1, 2**(b-1)-1]
+            For int8: [-127,127] ; for int4: [-7,7]
+        qint bounds must be ints, not tensors
+        """
+        num_bits_int = (
+            num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
+        )
+        if symmetric and torch.sum(zero_point) == 0:
+            qlevel_symmetric = 1 if qlevel_lowering else 0
+            qint_l, qint_h = (
+                -(2 ** (num_bits_int - 1)) + qlevel_symmetric,
+                2 ** (num_bits_int - 1) - 1,
+            )
+            qint_dtype = torch.qint8
+        else:  # single_sided or zero_point != 0
+            qint_l, qint_h = 0, 2**num_bits_int - 1
+            qint_dtype = torch.quint8
+        return qint_l, qint_h, qint_dtype
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        General STE backward method:
+            Return grad_output + None args to match forward inputs
+
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            grad_output (torch.FloatTensor): Gradient tensor
+
+        Returns:
+            torch.FloatTensor, None,...,None: STE Gradient
+        """
+        return grad_output, None, None, None, None, None, None
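To make the quantization math quoted in the new PerChannelSTESAWB_PTnative docstring concrete, here is a small illustrative sketch (not part of the commit; the tensor, the names weights/n_bits, and the symmetric, qlevel_lowering settings are assumptions chosen for the example). It applies the same scale/zero_point formulas as calc_qparams and checks torch.fake_quantize_per_channel_affine against x_int = round(x/scale + zp), x_dq = (x_int - zp) * scale:

import torch

torch.manual_seed(0)
weights = torch.randn(4, 16)                 # 4 output channels, per-channel axis = 0
n_bits = 8
clip_val = weights.abs().amax(dim=1)         # per-channel symmetric clip, shape (4,)
clip_valn = -clip_val

# Same formulas as calc_qparams with symmetric=True and qlevel_lowering=True
n_levels = 2**n_bits - 2                     # 254 levels -> integer bounds [-127, 127]
scale = (clip_val - clip_valn) / n_levels
zero_point = torch.zeros_like(scale, dtype=torch.int32)
qint_l, qint_h = -(2 ** (n_bits - 1)) + 1, 2 ** (n_bits - 1) - 1

# PT-native path, as used by linear_quantization(..., dequantize=True)
dq_native = torch.fake_quantize_per_channel_affine(
    weights, scale, zero_point, axis=0, quant_min=qint_l, quant_max=qint_h
)

# Manual path from the docstring math: x_int = round(x/scale + zp), x_dq = (x_int - zp)*scale
x_int = torch.clamp(torch.round(weights / scale[:, None]), qint_l, qint_h)
dq_manual = x_int * scale[:, None]
print(torch.allclose(dq_native, dq_manual))  # expected: True

# Why the diff replaces `zero_point == 0` with `torch.sum(zero_point) == 0`:
# with a per-channel zero_point tensor, `if zero_point == 0:` would raise
# "Boolean value of Tensor with more than one element is ambiguous".
print(torch.sum(zero_point) == 0)            # tensor(True)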

fms_mo/quant_refactor/sawb_new.py

Lines changed: 5 additions & 78 deletions
@@ -29,7 +29,7 @@
 )
 from fms_mo.quant_refactor.per_channel_ste import (
     PerChannelSTESAWB,
-    # PerChannelSTESAWB_PTnative,
+    PerChannelSTESAWB_PTnative,
 )
 from fms_mo.quant_refactor.linear_utils import linear_dequantize, linear_quantize
 from fms_mo.quant_refactor.sawb_utils import sawb_params, sawb_params_code
@@ -132,14 +132,10 @@ def set_quantizer(self):
        if self.use_PT_native_Qfunc:
            if self.perCh:
                self.use_code = self.qscheme.qlevel_lowering
-               # self.quantizer = PerChannelSTESAWB_PTnative
+               self.quantizer = PerChannelSTESAWB_PTnative
            else:
-               # if self.use_extended_range_4bits:
-               #     self.use_code = True
-               #     self.quantizer = SAWBPlus16ZeroSTE_PTnative
-               # else:
-                   self.use_code = self.qscheme.qlevel_lowering
-                   self.quantizer = PerTensorSTESAWB_PTnative
+               self.use_code = self.qscheme.qlevel_lowering
+               self.quantizer = PerTensorSTESAWB_PTnative
 
        else: # Non-PTnative quantizers
            self.use_code = self.qscheme.qlevel_lowering
@@ -148,8 +144,6 @@ def set_quantizer(self):
            self.quantizer = (
                SAWBPlusZeroPerChSTE_new
                if self.perCh and self.num_bits in [2, 4, 8]
-               # else SAWBPlus16ZeroSTE_new
-               # if self.extended_ranged and self.num_bits == 4
                else SAWBPlusZeroSTE_new
            )
        else:
@@ -558,79 +552,11 @@ def backward(ctx, grad_output):
     # return n_levels, clip_val, scale, zero_point, qint_l, qint_h, qint_dtype
 
 
-# Placeholder classes for PerCh - need to rework #
 class SAWBPlusZeroPerChSTE_new(PerChannelSTESAWB):
     """
     per-channel SAWB with zero alignment, can use 15 or 16 bins, i.e. [-7,7] or [-8,7]
     """
 
-    # @staticmethod
-    # def forward(
-    #     ctx,
-    #     input_tensor: torch.FloatTensor,
-    #     num_bits: torch.IntTensor,
-    #     _clip_valn: torch.FloatTensor = clip_valn_default,
-    #     clip_val: torch.FloatTensor = clip_val_default,
-    #     dequantize: bool = True,
-    #     _symmetric: bool = False,
-    #     _qlevel_lowering: bool = False,
-    #     _use_code: bool = False,
-    # ):
-    #     """
-    #     Forward function for SAWBPlusZeroPerChSTE
-
-    #     Args:
-    #         ctx (torch.autograd.Function): Forward/Backward context object.
-    #         input_tensor (torch.FloatTensor): Tensor to be quantized.
-    #         num_bits (torch.IntTensor): Number of bit for quantization.
-    #         clip_valn (torch.FloatTensor): Lower clip value bound.
-    #         clip_val (torch.FloatTensor): Upper clip value bound.
-    #         dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
-    #         symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
-    #         qlevel_lowering (bool, optional): Specify lowering of quantized levels.
-    #             Defaults to True.
-    #         use_code (bool, optional): Specify using SAWB code. Defaults to False.
-
-    #     Returns:
-    #         torch.Tensor: Dequantized or Quantized output tensor.
-    #     """
-    #     # assert num_bits in [4, 8], "only implemented for 4bit and 8bit"
-
-    #     SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
-    #     num_bits_int = (
-    #         num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
-    #     )
-    #     clip_val, _ = sawb_params_code(
-    #         num_bits_int, SAWBcode_mapping[num_bits_int], input_tensor, perCh=True
-    #     )
-
-    #     _nspace = 2**num_bits - 2 # + objSAWB.use16bins # Ignore 16bins for now
-    #     int_l = -(2 ** (num_bits - 1)) + 1
-    #     int_u = -int_l # + objSAWB.use16bins # Ignore 16bins for now
-
-    #     scale = clip_val * 2 / (2**num_bits - 2)
-    #     # original SAWB assumes odd number of bins when calc clip_val
-    #     zero_point = torch.zeros_like(scale) # SAWB always centers around 0 and align 0
-
-    #     if dequantize:
-    #         output = torch.fake_quantize_per_channel_affine(
-    #             input_tensor.float(),
-    #             scale.float(),
-    #             zero_point.float(),
-    #             axis=0,
-    #             quant_min=int_l,
-    #             quant_max=int_u,
-    #         ).to(
-    #             clip_val.dtype
-    #         ) # NOTE return will be a fp32 tensor; function only support float()
-    #     else:
-    #         output = torch.quantize_per_channel(
-    #             input_tensor, scale, zero_point, 0, torch.qint8
-    #         ).int_repr()
-    #         # NOTE return will be a torch.int8 tensor
-
-    #     return output
-
     @staticmethod
     def backward(ctx, grad_output):
         """
@@ -645,3 +571,4 @@ def backward(ctx, grad_output):
         """
         grad_input = grad_output.clone()
         return grad_input, None, None, None, None, None, None, None
+
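As a companion to the commit title's "removed 16 bins", here is a minimal sketch (not from the repository) that mirrors the qint_bounds logic from per_channel_ste.py with plain Python ints. The 15-bin [-7, 7] range is the SAWB per-channel path this commit wires up, while the 16-bin [-8, 7] range corresponds to the extended-range variants whose commented-out remnants are deleted in sawb_new.py:

# Illustrative helper mirroring PerChannelSTESAWB_PTnative.qint_bounds for plain ints
def qint_bounds_sketch(num_bits: int, symmetric: bool, qlevel_lowering: bool):
    if symmetric:                                   # zero_point assumed all-zero
        low = -(2 ** (num_bits - 1)) + (1 if qlevel_lowering else 0)
        high = 2 ** (num_bits - 1) - 1
    else:                                           # single-sided / non-zero zero_point
        low, high = 0, 2**num_bits - 1
    return low, high

print(qint_bounds_sketch(8, True, True))    # (-127, 127)
print(qint_bounds_sketch(4, True, True))    # (-7, 7)   -> 15-bin SAWB range
print(qint_bounds_sketch(4, True, False))   # (-8, 7)   -> 16 bins
print(qint_bounds_sketch(4, False, True))   # (0, 15)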
