@@ -559,7 +559,7 @@ def forward(ctx, input, num_bits, dequantize, inplace, objSAWB_clip_val):
559559 else :
560560 output = torch .quantize_per_channel (
561561 input , scale , zero_point , 0 , torch .qint8
562- ).int_repr ()
562+ ).int_repr (). clamp ( int_l , int_u )
563563 # NOTE return will be a torch.int8 tensor
564564
565565 return output
@@ -639,6 +639,13 @@ def sawb_params_code(num_bits, code, out, perCh=False):
639639 mu = torch .mean (out .abs (), dim = reduce_dim )
640640 std = torch .mean (out ** 2 , dim = reduce_dim ).sqrt ()
641641 clip_val_vec = coeff [1 ] * mu + coeff [0 ] * std
642+
643+ # Overwrite negative clip_vals with abs.max
644+ neg_clip_idx = clip_val_vec < 0.0
645+ if torch .any (neg_clip_idx ):
646+ clip_val_max = torch .max (out .abs (), dim = 1 ).values
647+ clip_val_vec [neg_clip_idx ] = clip_val_max [neg_clip_idx ]
648+
642649 return clip_val_vec , None
643650 else :
644651 # per-tensor
@@ -648,6 +655,10 @@ def sawb_params_code(num_bits, code, out, perCh=False):
648655
649656 clip_val = coeff [1 ] * mu + coeff [0 ] * std
650657
658+ # Overwrite negative clip_vals with abs.max
659+ if clip_val < 0.0 :
660+ clip_val = x .abs ().max ()
661+
651662 if code in [102 ]:
652663 nspace = 2 ** num_bits - 1
653664 elif code in [403 , 103 , 703 , 803 ]:
@@ -762,6 +773,10 @@ def sawb_params(num_bits, out):
762773 coeff = dic_coeff [num_bits ]
763774 clip_val = coeff [1 ] * mu + coeff [0 ] * std
764775
776+ # Overwrite negative clip_vals with abs.max
777+ if clip_val < 0.0 :
778+ clip_val = x .abs ().max ()
779+
765780 return clip_val
766781
767782
0 commit comments