Skip to content

Commit 86a82ee

Browse files
committed
fix: Clamp per-channel (perCh) quantized output in quantizers_new and replace negative clip_vals from sawb_params with abs-max
Signed-off-by: Brandon Groth <[email protected]>
1 parent d66340c commit 86a82ee

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

fms_mo/quant_refactor/quantizers_new.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -559,7 +559,7 @@ def forward(ctx, input, num_bits, dequantize, inplace, objSAWB_clip_val):
559559
else:
560560
output = torch.quantize_per_channel(
561561
input, scale, zero_point, 0, torch.qint8
562-
).int_repr()
562+
).int_repr().clamp(int_l, int_u)
563563
# NOTE return will be a torch.int8 tensor
564564

565565
return output
@@ -639,6 +639,13 @@ def sawb_params_code(num_bits, code, out, perCh=False):
639639
mu = torch.mean(out.abs(), dim=reduce_dim)
640640
std = torch.mean(out**2, dim=reduce_dim).sqrt()
641641
clip_val_vec = coeff[1] * mu + coeff[0] * std
642+
643+
# Overwrite negative clip_vals with abs.max
644+
neg_clip_idx = clip_val_vec < 0.0
645+
if torch.any(neg_clip_idx):
646+
clip_val_max = torch.max(out.abs(), dim=1).values
647+
clip_val_vec[neg_clip_idx] = clip_val_max[neg_clip_idx]
648+
642649
return clip_val_vec, None
643650
else:
644651
# per-tensor
@@ -648,6 +655,10 @@ def sawb_params_code(num_bits, code, out, perCh=False):
648655

649656
clip_val = coeff[1] * mu + coeff[0] * std
650657

658+
# Overwrite negative clip_vals with abs.max
659+
if clip_val < 0.0:
660+
clip_val = x.abs().max()
661+
651662
if code in [102]:
652663
nspace = 2**num_bits - 1
653664
elif code in [403, 103, 703, 803]:
@@ -762,6 +773,10 @@ def sawb_params(num_bits, out):
762773
coeff = dic_coeff[num_bits]
763774
clip_val = coeff[1] * mu + coeff[0] * std
764775

776+
# Overwrite negative clip_vals with abs.max
777+
if clip_val < 0.0:
778+
clip_val = x.abs().max()
779+
765780
return clip_val
766781

767782

0 commit comments

Comments
 (0)