
Commit 63396cf

feat: Implemented Base PerChannelSTEQmax classes
Signed-off-by: Brandon Groth <[email protected]>
Parent: d2bb34c

1 file changed (+49, -66 lines): fms_mo/quant_refactor/qmax_new.py
@@ -24,6 +24,11 @@

 # Local
 from fms_mo.quant_refactor.base_quant import Qscheme, Quantizer
+from fms_mo.quant_refactor.per_channel_ste import (
+    PerChannelSTE_PTnative,
+    PerChannelSTEQmax,
+    PerChannelSTEQmax_PTnative,
+)
 from fms_mo.quant_refactor.per_tensor_ste import (
     PerTensorSTE_PTnative,
     PerTensorSTEQmax,
@@ -123,13 +128,29 @@ def set_quantizer(self):
         self.copy_legacy_vars()  # copy align_zero

         if self.use_PT_native_Qfunc:
-            if self.extend_act_range:
-                self.quantizer_name = "QmaxExtend"
-                self.set_clip_ratio()
-                self.quantizer = QmaxExtendRangeSTE_PTnative
+            if self.minmax:
+                self.quantizer_name = "Qminmax"
+                if self.perGrp:
+                    # self.quantizer = QmaxPerGrpSTE_PTnative
+                    pass
+                else:
+                    self.quantizer = (
+                        QmaxPerChSTE_PTnative
+                        if self.perCh
+                        else PerTensorSTEQmax_PTnative
+                    )
             else:
-                self.quantizer_name = "Qminmax" if self.minmax else "Qmax"
-                self.quantizer = PerTensorSTEQmax_PTnative
+                if self.extend_act_range:
+                    self.quantizer_name = "QmaxExtend"
+                    self.set_clip_ratio()
+                    self.quantizer = QmaxExtendRangeSTE_PTnative
+                else:
+                    self.quantizer_name = "Qminmax" if self.minmax else "Qmax"
+                    self.quantizer = (
+                        QmaxPerChSTE_PTnative
+                        if self.perCh
+                        else PerTensorSTEQmax_PTnative
+                    )
         else:
             if self.minmax:
                 self.quantizer_name = "Qminmax"
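The rewritten PT-native branch checks `minmax` first, handles `extend_act_range` only on the non-minmax path, and picks the per-channel class whenever `perCh` is set. A minimal sketch of that dispatch, runnable on its own (`QuantCfg` and `select_ptnative_quantizer` are illustrative names, not fms_mo API):

from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantCfg:
    minmax: bool = False
    perCh: bool = False
    perGrp: bool = False
    extend_act_range: bool = False

def select_ptnative_quantizer(cfg: QuantCfg) -> Optional[str]:
    """Mirror the branch order of the new set_quantizer() PT-native path."""
    if cfg.minmax:
        if cfg.perGrp:
            return None  # per-group PT-native class is still commented out upstream
        return "QmaxPerChSTE_PTnative" if cfg.perCh else "PerTensorSTEQmax_PTnative"
    if cfg.extend_act_range:
        return "QmaxExtendRangeSTE_PTnative"
    return "QmaxPerChSTE_PTnative" if cfg.perCh else "PerTensorSTEQmax_PTnative"

# Both minmax and plain Qmax per-channel configs resolve to the same class.
assert select_ptnative_quantizer(QuantCfg(minmax=True, perCh=True)) == "QmaxPerChSTE_PTnative"
assert select_ptnative_quantizer(QuantCfg(extend_act_range=True)) == "QmaxExtendRangeSTE_PTnative"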
@@ -171,14 +192,14 @@ def forward(self, input_tensor: torch.FloatTensor) -> torch.Tensor:
         if self.perCh:
             if self.minmax:
                 clip_val_new = torch.max(
-                    input_tensor.reshape([self.perCh, -1]), dim=1
+                    input_tensor.reshape([self.qscheme.Nch, -1]), dim=1
                 ).values
                 clip_valn_new = torch.min(
-                    input_tensor.reshape([self.perCh, -1]), dim=1
+                    input_tensor.reshape([self.qscheme.Nch, -1]), dim=1
                 ).values
             else:
                 clip_val_new = torch.max(
-                    input_tensor.abs().reshape([self.perCh, -1]), dim=1
+                    input_tensor.abs().reshape([self.qscheme.Nch, -1]), dim=1
                 ).values
                 clip_valn_new = -clip_val_new
             assert (
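The substantive change in this hunk is that the reshape now uses `self.qscheme.Nch`, the explicit channel count, rather than overloading the `perCh` flag. A self-contained sketch of the clip-value computation (channel axis assumed to be dim 0; `per_channel_clip_vals` is an illustrative helper, not fms_mo API):

import torch

def per_channel_clip_vals(w: torch.Tensor, n_ch: int, minmax: bool):
    """Return (clip_valn, clip_val) vectors of shape [n_ch]."""
    flat = w.reshape([n_ch, -1])  # one row per channel
    if minmax:
        # Asymmetric: track per-channel min and max independently.
        return flat.min(dim=1).values, flat.max(dim=1).values
    # Symmetric: clip at +/- the per-channel absolute max.
    clip_val = flat.abs().max(dim=1).values
    return -clip_val, clip_val

# A [4, 16] weight yields two length-4 clip vectors.
cvn, cv = per_channel_clip_vals(torch.randn(4, 16), n_ch=4, minmax=True)
assert cvn.shape == cv.shape == (4,)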
@@ -487,50 +508,32 @@ def calc_qparams(
     return n_levels, clip_val, scale, zero_point, qint_min, qint_max, qint_dtype


-# Placeholder classes for PerCh/PerGp - need to rework #
-class QmaxPerChSTE_new(torch.autograd.Function):
+class QmaxPerChSTE_new(PerChannelSTEQmax):
     """
     Max with zero alignment (symmetric)
     "dequantize=False" option is functional
     """

     @staticmethod
-    def forward(
-        ctx,
-        input_tensor,
-        num_bits,
-        _dequantize,
-        inplace,
-        _cvn,
-        cv,
-        align_zero,
-    ) -> torch.FloatTensor:
+    def backward(ctx, grad_output):
         """
-        TODO (bmgroth): docstring
+        Backward function for Qmax Per Channel STE
+
+        Args:
+            ctx (torch.autograd.Function): Context object.
+            grad_output (torch.FloatTensor): Gradient to clip
+
+        Returns:
+            [torch.FloatTensor, None,...,None]: Gradients
         """
-        if inplace:
-            ctx.mark_dirty(input)
-        scale = (2**num_bits - 2) if align_zero else (2**num_bits - 1)
-        zero_point = 0.0
-        _clip_val = cv
-        # here use symmetric similar to sawbperCh code
-        _nspace = 2**num_bits - 2  # lose one level
-        int_l = -(2 ** (num_bits - 1)) + 1
-        int_u = -int_l  # symmetric
-        scale = (
-            cv * 2 / (2**num_bits - 2)
-        )  # original SAWB assumes odd number of bins when calc clip_val
-        zero_point = torch.zeros_like(scale)  # centers around 0 and align 0
-        # FIXME, fake quantize function only support float.
-        output = torch.fake_quantize_per_channel_affine(
-            input_tensor.float(),
-            scale.float(),
-            zero_point.float(),
-            axis=0,
-            quant_min=int_l,
-            quant_max=int_u,
-        ).to(input_tensor.dtype)
-        return output
+        return grad_output, None, None, None, None, None, None
+
+
+class QmaxPerChSTE_PTnative(PerChannelSTEQmax_PTnative):
+    """
+    Max with zero alignment (symmetric)
+    "dequantize=False" option is functional
+    """

     @staticmethod
     def backward(ctx, grad_output):
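The deleted forward() is the symmetric Qmax recipe: drop one integer level so the grid is symmetric around an exactly representable zero, then let PyTorch fake-quantize per channel. That logic presumably now lives in the PerChannelSTEQmax/PerChannelSTEQmax_PTnative bases; the sketch below replays the deleted math for reference and is not a verified fms_mo API:

import torch

def qmax_per_channel(w: torch.Tensor, cv: torch.Tensor, num_bits: int = 8):
    """Symmetric per-channel fake quantization, following the deleted forward()."""
    # Symmetric integer grid with one level dropped: [-(2**(b-1)) + 1, 2**(b-1) - 1].
    int_l = -(2 ** (num_bits - 1)) + 1
    int_u = -int_l
    scale = cv * 2 / (2**num_bits - 2)    # per-channel step size; cv has shape [Nch]
    zero_point = torch.zeros_like(scale)  # symmetric grid -> zero_point is 0
    # As in the deleted code, cast to float since the fake-quantize op expects floats.
    return torch.fake_quantize_per_channel_affine(
        w.float(), scale.float(), zero_point.float(),
        axis=0, quant_min=int_l, quant_max=int_u,
    ).to(w.dtype)

out = qmax_per_channel(torch.randn(4, 16), cv=torch.ones(4))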
@@ -606,31 +609,11 @@ def backward(ctx, grad_output):
         return grad_output, None, None, None, None, None, None


-class QminmaxPerChSTE_new(torch.autograd.Function):
+class QminmaxPerChSTE_new(PerChannelSTEQmax):
     """
     per channel minmax with zero alignment (asymmetric)
     """

-    @staticmethod
-    def forward(
-        ctx, input_tensor, num_bits, _dequantize, inplace, cv, cvn, align_zero
-    ) -> torch.FloatTensor:
-        """
-        TODO (bmgroth): docstring
-        """
-        if inplace:
-            ctx.mark_dirty(input_tensor)
-        cv, cvn = cv.to(input_tensor.dtype), cvn.to(input_tensor.dtype)
-        scale = (2**num_bits - 1) / (cv - cvn)
-        zero_point = cvn * scale
-        if align_zero:
-            zero_point = torch.round(zero_point)
-        output = (input_tensor.clamp(cvn[:, None], cv[:, None]) - cvn[:, None]) * scale[
-            :, None
-        ]
-        output = (torch.round(output) + zero_point[:, None]) / scale[:, None]
-        return output
-
     @staticmethod
     def backward(ctx, grad_output):
         """
