
Commit 603f060

fix: Commented SAWBplusZeroperCh_new forward
Signed-off-by: Brandon Groth <[email protected]>
1 parent: e0d1a0a

File tree: 1 file changed (+67 −69 lines)


fms_mo/quant_refactor/sawb_new.py (67 additions, 69 deletions)
@@ -39,8 +39,6 @@
 qscheme_per_tensor = Qscheme(
     unit="perT",
     symmetric=False,
-    Nch=None,
-    Ngrp=None,
     single_sided=False,
     qlevel_lowering=False,
 )
@@ -560,75 +558,75 @@ def calc_qparams(
 # Placeholder classes for PerCh - need to rework #
 class SAWBPlusZeroPerChSTE_new(PerChannelSTESAWB):
     """
-    per-channel SAWB with zero alignment, can use 15 or 16 bins, i.e. [-7,7] or [-7,8]
+    per-channel SAWB with zero alignment, can use 15 or 16 bins, i.e. [-7,7] or [-7,8]
     """

-    @staticmethod
-    def forward(
-        ctx,
-        input_tensor: torch.FloatTensor,
-        num_bits: torch.IntTensor,
-        _clip_valn: torch.FloatTensor = clip_valn_default,
-        clip_val: torch.FloatTensor = clip_val_default,
-        dequantize: bool = True,
-        _symmetric: bool = False,
-        _qlevel_lowering: bool = False,
-        _use_code: bool = False,
-    ):
-        """
-        Forward function for SAWBPlusZeroPerChSTE
-
-        Args:
-            ctx (torch.autograd.Function): Forward/Backward context object.
-            input_tensor (torch.FloatTensor): Tensor to be quantized.
-            num_bits (torch.IntTensor): Number of bit for quantization.
-            clip_valn (torch.FloatTensor): Lower clip value bound.
-            clip_val (torch.FloatTensor): Upper clip value bound.
-            dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
-            symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
-            qlevel_lowering (bool, optional): Specify lowering of quantized levels.
-                Defaults to True.
-            use_code (bool, optional): Specify using SAWB code. Defaults to False.
-
-        Returns:
-            torch.Tensor: Dequantized or Quantized output tensor.
-        """
-        # assert num_bits in [4, 8], "only implemented for 4bit and 8bit"
-
-        SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
-        num_bits_int = (
-            num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
-        )
-        clip_val, _ = sawb_params_code(
-            num_bits_int, SAWBcode_mapping[num_bits_int], input_tensor, perCh=True
-        )
-
-        _nspace = 2**num_bits - 2  # + objSAWB.use16bins # Ignore 16bins for now
-        int_l = -(2 ** (num_bits - 1)) + 1
-        int_u = -int_l  # + objSAWB.use16bins # Ignore 16bins for now
-
-        scale = clip_val * 2 / (2**num_bits - 2)
-        # original SAWB assumes odd number of bins when calc clip_val
-        zero_point = torch.zeros_like(scale)  # SAWB always centers around 0 and align 0
-
-        if dequantize:
-            output = torch.fake_quantize_per_channel_affine(
-                input_tensor.float(),
-                scale.float(),
-                zero_point.float(),
-                axis=0,
-                quant_min=int_l,
-                quant_max=int_u,
-            ).to(
-                clip_val.dtype
-            )  # NOTE return will be a fp32 tensor; function only support float()
-        else:
-            output = torch.quantize_per_channel(
-                input_tensor, scale, zero_point, 0, torch.qint8
-            ).int_repr()
-            # NOTE return will be a torch.int8 tensor
-
-        return output
+    # @staticmethod
+    # def forward(
+    #     ctx,
+    #     input_tensor: torch.FloatTensor,
+    #     num_bits: torch.IntTensor,
+    #     _clip_valn: torch.FloatTensor = clip_valn_default,
+    #     clip_val: torch.FloatTensor = clip_val_default,
+    #     dequantize: bool = True,
+    #     _symmetric: bool = False,
+    #     _qlevel_lowering: bool = False,
+    #     _use_code: bool = False,
+    # ):
+    #     """
+    #     Forward function for SAWBPlusZeroPerChSTE
+
+    #     Args:
+    #         ctx (torch.autograd.Function): Forward/Backward context object.
+    #         input_tensor (torch.FloatTensor): Tensor to be quantized.
+    #         num_bits (torch.IntTensor): Number of bit for quantization.
+    #         clip_valn (torch.FloatTensor): Lower clip value bound.
+    #         clip_val (torch.FloatTensor): Upper clip value bound.
+    #         dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
+    #         symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
+    #         qlevel_lowering (bool, optional): Specify lowering of quantized levels.
+    #             Defaults to True.
+    #         use_code (bool, optional): Specify using SAWB code. Defaults to False.
+
+    #     Returns:
+    #         torch.Tensor: Dequantized or Quantized output tensor.
+    #     """
+    #     # assert num_bits in [4, 8], "only implemented for 4bit and 8bit"
+
+    #     SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
+    #     num_bits_int = (
+    #         num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
+    #     )
+    #     clip_val, _ = sawb_params_code(
+    #         num_bits_int, SAWBcode_mapping[num_bits_int], input_tensor, perCh=True
+    #     )
+
+    #     _nspace = 2**num_bits - 2  # + objSAWB.use16bins # Ignore 16bins for now
+    #     int_l = -(2 ** (num_bits - 1)) + 1
+    #     int_u = -int_l  # + objSAWB.use16bins # Ignore 16bins for now
+
+    #     scale = clip_val * 2 / (2**num_bits - 2)
+    #     # original SAWB assumes odd number of bins when calc clip_val
+    #     zero_point = torch.zeros_like(scale)  # SAWB always centers around 0 and align 0
+
+    #     if dequantize:
+    #         output = torch.fake_quantize_per_channel_affine(
+    #             input_tensor.float(),
+    #             scale.float(),
+    #             zero_point.float(),
+    #             axis=0,
+    #             quant_min=int_l,
+    #             quant_max=int_u,
+    #         ).to(
+    #             clip_val.dtype
+    #         )  # NOTE return will be a fp32 tensor; function only support float()
+    #     else:
+    #         output = torch.quantize_per_channel(
+    #             input_tensor, scale, zero_point, 0, torch.qint8
+    #         ).int_repr()
+    #         # NOTE return will be a torch.int8 tensor
+
+    #     return output

     @staticmethod
     def backward(ctx, grad_output):
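
For context, the forward pass that this commit comments out performs per-channel, zero-aligned symmetric fake quantization: each output channel gets a SAWB-derived clip value, the scale is clip_val * 2 / (2**num_bits - 2) so the quantization grid has an odd number of bins centered on zero, and the zero point is always 0. Below is a minimal standalone sketch of that math, not the repo's implementation: sawb_zero_perch_sketch is a hypothetical helper, and the per-channel clip values are passed in directly instead of being computed by sawb_params_code.

import torch

def sawb_zero_perch_sketch(input_tensor, clip_val, num_bits=4, dequantize=True):
    """Hypothetical sketch of zero-aligned per-channel SAWB-style quantization.

    clip_val is a per-output-channel clip bound (shape [out_channels]),
    standing in for the value sawb_params_code() computes in fms_mo.
    """
    int_l = -(2 ** (num_bits - 1)) + 1        # e.g. -7 for 4 bits
    int_u = -int_l                            # e.g. +7 -> 15 bins, zero-aligned
    scale = clip_val * 2 / (2**num_bits - 2)  # odd number of bins centered on 0
    zero_point = torch.zeros_like(scale)      # zero always maps exactly to 0

    if dequantize:
        # Quantize-then-dequantize ("fake quant"); output stays a float tensor.
        return torch.fake_quantize_per_channel_affine(
            input_tensor.float(),
            scale.float(),
            zero_point.to(torch.int32),
            axis=0,
            quant_min=int_l,
            quant_max=int_u,
        )
    # Integer path: return the raw int8 representation.
    return torch.quantize_per_channel(
        input_tensor, scale, zero_point.long(), 0, torch.qint8
    ).int_repr()

# Example: quantize random weights, using each channel's abs-max as a stand-in clip value.
w = torch.randn(8, 16)
clip = w.abs().amax(dim=1)
w_fq = sawb_zero_perch_sketch(w, clip, num_bits=4)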
