@@ -261,7 +261,7 @@ def forward(self, x):
         scale = torch.tensor([1.0]).to(x.dtype).to(x.device)

         # pylint: disable = access-member-before-definition
-        if self.calib_counter:
+        if self.calib_counter > 0:
             with torch.no_grad():
                 qinput = self.quantize_calib_feature(x / scale)
                 qweight = self.quantize_calib_weight(self.weight * scale)
@@ -733,6 +733,8 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             chunk_size: some HW may have specific chunk size (BLOCK SIZE, especially in k-dim) for
                 the reason to avoid overflow/underflow problem. This can be simulated using
                 PyTorch (break a matmul into serial smaller matmuls, slow) or Triton kernel
+            useDynMaxQfunc: -1 or -2 selects the reduce dim for dynamic max quantization,
+                0 < val <= 64 adds an artificial zero-shift, False -> normal static quantization.

         Returns:
             A QLinearINT8Deploy object initialized with the weights and biases from the
@@ -761,7 +763,11 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         )
         qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False)
         qlin_int.useDynMaxQfunc = kwargs.get("use_dynamic_max_act_Qfunc", False)
-        qlin_int.useSymAct = "sym" in fms_mo_qlinear.qa_mode
+        qlin_int.useSymAct = (
+            "sym" in fms_mo_qlinear.qa_mode
+            or fms_mo_qlinear.qa_mode in ["pertokenmax", "max"]
+            # these are the symmetric quantizers with no "sym" in their names
+        )
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
@@ -778,26 +784,49 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         with torch.no_grad():
             Qa = fms_mo_qlinear.quantize_feature
             Qw = fms_mo_qlinear.quantize_weight
+            # If no calibration has been run before swapping, the clipvals stored in Qw will be
+            # the original ones, e.g. per-tensor. To experiment with new quantizers, run at
+            # least one fwd pass first, which will update the clipvals.
+            Qw(fms_mo_qlinear.weight)
             w_cv = Qw.clip_val
-            if qlin_int.useDynMaxQfunc in [-1, -2]:  # [-1, -2] indicates reduce_dim
-                # dynamic Qmax has no clipvals, reg fake ones, won't be used in real calc
-                Qa.register_buffer("clip_val", torch.tensor(8.0, device=tar_dev))
-                Qa.register_buffer("clip_valn", torch.tensor(-8.0, device=tar_dev))
-            a_cv = Qa.clip_val
-            a_cvn = Qa.clip_valn
+            a_cv = getattr(Qa, "clip_val", torch.tensor(8.0, device=tar_dev))
+            a_cvn = getattr(Qa, "clip_valn", torch.tensor(-8.0, device=tar_dev))
             # Store original cv_a and cv_w in python floats (instead of tensors) will be more
             # accurate, but not compatible for per-ch and per-token.
-            qlin_int.cvs = [a_cv, a_cvn, w_cv]  # TODO remove the need of this.
+            qlin_int.cvs = [a_cv, a_cvn, w_cv]  # TODO remove the need of this?
+
+            # prepare the smoothQuant scale, = (smQ_a_scale ** alpha) / (smQ_w_scale ** (1 - alpha))
+            smq_scale = torch.tensor([1.0], device=tar_dev, dtype=fms_mo_w_dtype)
+            if getattr(fms_mo_qlinear, "smoothq", False):
+                smq_a_scale = fms_mo_qlinear.smoothq_act_scale
+                smq_w_scale = (
+                    fms_mo_qlinear.weight.abs()
+                    .max(dim=0, keepdim=True)[0]
+                    .clamp(min=1e-5)
+                )
+                smq_alpha = fms_mo_qlinear.smoothq_alpha
+                if torch.all(smq_a_scale != 0).item():
+                    smq_scale = (
+                        (smq_a_scale**smq_alpha / smq_w_scale ** (1.0 - smq_alpha))
+                        .clamp(min=1e-5)
+                        .to(smq_a_scale.dtype)
+                    )

-            # may need to trigger Qw.clipval re-calc for SAWB here, (if needed?)
+            # could trigger Qw.clipval re-calc for SAWB here, if needed
+            input_scale = torch.tensor(1.0, device=tar_dev)
+            w_scale = w_cv * 2 / w_levels
+            qlin_int.use_fake_zero_shift = False
             if qlin_int.useDynMaxQfunc in [-1, -2]:
-                input_scale = torch.tensor(1.0, device=tar_dev)
-                input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev)
-                w_scale = w_cv * 2 / w_levels
+                input_zero_point = torch.tensor(
+                    128 - qlin_int.useSymAct, device=tar_dev
+                )
+            elif 0 < qlin_int.useDynMaxQfunc < 65:
+                # introduce a fake zero-shift, input_scale will be calculated dynamically
+                qlin_int.use_fake_zero_shift = True
+                input_zero_point = torch.tensor(qlin_int.useDynMaxQfunc, device=tar_dev)
             elif qlin_int.usePTnativeQfunc:
                 input_scale = torch.tensor([(a_cv - a_cvn) / a_levels], device=tar_dev)
-                input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int)
-                w_scale = w_cv * 2 / w_levels
+                input_zero_point = torch.round(-a_cvn / input_scale)
             else:
                 # fms_mo formula is a bit different from conventional PT formula
                 quant_scale = a_levels / torch.tensor([a_cv - a_cvn], device=tar_dev)
@@ -812,48 +841,70 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
                 qlin_int.register_buffer("quant_zero_point", quant_zero_point)
             w_zp = torch.zeros_like(w_scale, dtype=torch.int)

+            input_zero_point = input_zero_point.to(torch.int)  # note 2 in pre-compute
             qlin_int.register_buffer("input_scale", input_scale)
             qlin_int.register_buffer("input_zp", input_zero_point)
             qlin_int.register_buffer("w_scale", w_scale)
             qlin_int.register_buffer("w_zp", w_zp)
+            qlin_int.register_buffer("smq_scale", smq_scale)

             # NOTE:
             # 1. Keep W transposed to prevent confusion, hence (W.t()/scale).t()
-            # 2. only a few quantizer have .dequantize working correctly
+            # 2. only a few quantizers have .dequantize working correctly, e.g. SAWB
+            # 3. the smooth_quant factor is folded into W here and also applied in the forward
             if isinstance(Qw, SAWB):
                 Qw.dequantize = False
-                w_int8 = Qw(fms_mo_qlinear.weight.float())
+                w_int8 = Qw(fms_mo_qlinear.weight.float() * smq_scale)
             else:
                 w_int8 = (
-                    torch.round(fms_mo_qlinear.weight.t() / w_scale)
+                    torch.round((fms_mo_qlinear.weight * smq_scale).t() / w_scale)
                     .clamp(-w_levels / 2, w_levels / 2)
                     .t()
                 )
-
+            w_int8 = w_int8.to(
+                torch.int
+            )  # stored as int32 because the correction term needs sum()
             qlin_int.weight = nn.Parameter(w_int8.to(torch.int8), requires_grad=False)

-            corr_term = (
-                (input_zero_point - 128 + qlin_int.useSymAct)
-                * (w_int8.sum(dim=1))
-                * w_scale.float()
-                * input_scale.float()
-            )
-            # dim=1 because w_int is in [out,in], after sum shape=[out,], same as w_scale and bias.
-            # (zp-128)*w_int8.sum(dim=1) can be >> fp16.max, use fp32 scales
-            # to make sure dtype is large enough
-            qlin_int.register_buffer("corr_term", corr_term.half())  # [DEBUG only]
-            if fms_mo_qlinear.bias is not None:
-                qlin_int.bias = nn.Parameter(
-                    (fms_mo_qlinear.bias - corr_term).to(fms_mo_w_dtype),
-                    requires_grad=False,
-                )
+            # Pre-compute the "correction term" for the zero-shift of asym activation quantizers
+            # NOTE:
+            # 1. sym act should have corr_term=0, unless we want to introduce a fake zero-shift
+            # 2. sum reduces dim=1 because w_int is [out,in]; after the sum, shape=[out,], same as
+            #    w_scale (per-Ch) and bias.
+            # 3. calc the INT part, i.e. (zp-128)*w_int8.sum(dim=1), in INT32 first, because it can
+            #    easily be >> fp16.max (only ~65504); make sure not to cast the INT32 to FP16
+            #    during the calc, simply cast the scales to FP32
+            # 4. for the "fake zero-shift" case, input_scale will be max/(127-fake_zero_shift)
+            #    instead of max/127, see qa_dyn_max_fake_zero_shift()
+            # 5. Combine the correction term into linear.bias for non-dynamic cases. For dyn quant,
+            #    input_scale is a placeholder for now and will be calc'ed on the fly later.
+            if qlin_int.useSymAct:
+                corr_term_int = 0
+                if qlin_int.use_fake_zero_shift:
+                    # one exception, fake zero-shift
+                    corr_term_int = input_zero_point * (w_int8.sum(dim=1))
+            else:
+                corr_term_int = (input_zero_point - 128) * (w_int8.sum(dim=1))

-                qlin_int.org_model_has_bias = True
+            qlin_int.register_buffer(
+                "corr_term", corr_term_int * w_scale.float() * input_scale.float()
+            )  # keep in FP32, cast at the end
+
+            qlin_int.org_model_has_bias = fms_mo_qlinear.bias is not None
+            # Combine the correction term into linear.bias when possible. NOTE the magnitudes of
+            # these 2 terms could differ a lot; use fp32 to avoid underflow and loss of accuracy.
+            if qlin_int.org_model_has_bias:
+                new_bias = fms_mo_qlinear.bias.float() - qlin_int.corr_term
             else:
-                delattr(qlin_int, "bias")
-                # even if bias is None, reg_buffer() is still unhappy about it
-                qlin_int.register_buffer("bias", -corr_term.to(fms_mo_w_dtype))
-                qlin_int.org_model_has_bias = False
+                new_bias = -qlin_int.corr_term
+
+            if qlin_int.use_fake_zero_shift:
+                # dyn sym act but with a fake zp, remove corr_term from bias
+                new_bias += qlin_int.corr_term
+
+            delattr(qlin_int, "bias")
+            # sometimes reg_buffer() is unhappy about an existing bias
+            qlin_int.register_buffer("bias", new_bias.to(fms_mo_w_dtype))

             # redundant variables to be cleaned up
             # qlin_int.register_buffer("Qa_clip_val", Qa.clip_val.detach())
@@ -1039,9 +1090,25 @@ def qa_dynamic_max_qfunc(self, x):
         """
         amax = x.abs().max(dim=self.useDynMaxQfunc, keepdim=True)[0]
         levels = 2 ** (self.nbits_a - 1) - 1
+        self.cvs[0] = amax
+        self.cvs[1] = -amax
         self.input_scale = amax.clamp(min=1e-5).div(levels)
         return torch.round(x / self.input_scale).to(torch.int8)

+    def qa_dyn_max_fake_zero_shift(self, x):
+        """Dynamic max quantizer with a fake zero-shift in order to accommodate "zero-centered"
+        activations. A "partial" correction term has been pre-computed in from_fms_mo() but still
+        needs to be multiplied by input_scale here. (Assuming per-tensor; can shift left or right.)
+        """
+        amax = x.abs().max()
+        shift_dir = 1 if amax == x.max() else -1
+        levels = 2 ** (self.nbits_a - 1) - 1 - self.input_zp
+        self.cvs[0] = amax
+        self.cvs[1] = -amax
+        self.input_scale = amax.clamp(min=1e-5) / levels
+        xq = torch.round(x / self.input_scale) + self.input_zp
+        return xq.to(torch.int8)
+
     def iaddmm_int(self, bias, m1, m2):
         """
         Performs integer matrix multiplication with optional addition of a bias term.
@@ -1061,11 +1128,14 @@ def iaddmm_int(self, bias, m1, m2):

         if self.useDynMaxQfunc in [-1, -2]:
             m1 = self.qa_dynamic_max_qfunc(m1)
+        elif self.use_fake_zero_shift:
+            m1 = self.qa_dyn_max_fake_zero_shift(m1)
         elif self.usePTnativeQfunc:
             m1 = self.qa_raw_qfunc(m1)
         else:
             m1 = self.qa_fmo_mo_qfunc(m1)

+        # NOTE simulating the chunking behavior in PyTorch is serial and slow; use Triton when possible
         if m1.shape[1] > self.chunk_size and self.use_int_kernel != "triton":
             idx = list(range(0, m1.shape[1], self.chunk_size))
             Nchunk = len(idx)
@@ -1099,11 +1169,19 @@ def iaddmm_int(self, bias, m1, m2):
                 accumulator
                 * (trun_scale * self.input_scale * self.w_scale)  # .to(torch.float16)
                 + bias
-            ).to(self.acc_dtype)
-            # The safest casting, i32 -> f32
+            ).to(self.acc_dtype)  # safest casting would be i32 -> f32
+
         imm_out = torch.ops.fms_mo.imatmul(m1, m2)
+
+        updated_bias = bias
+        if self.use_fake_zero_shift:
+            # Do NOT change the stored self.corr_term and self.bias
+            updated_bias = bias - self.input_scale * self.corr_term
+
+        # cast to fp16 could be modified based on real HW behavior/design
         return (
-            imm_out.float() * (self.input_scale * self.w_scale).to(torch.float16) + bias
+            imm_out.float() * (self.input_scale * self.w_scale).to(torch.float16)
+            + updated_bias
         ).to(self.acc_dtype)

     def iaddmm_FP(self, bias, m1, m2):
@@ -1247,9 +1325,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             self.weight.shape[0],
         )  # W.shape=[out,in]

-        x = self.iaddmm(self.bias, x.view(re_shape), self.weight.t()).reshape(
-            tar_shape
-        )
+        if torch.all(self.smq_scale != 1).item():
+            x = x.view(re_shape) / self.smq_scale
+        else:
+            x = x.view(re_shape)
+
+        x = self.iaddmm(self.bias, x, self.weight.t()).reshape(tar_shape)

         return x.to(org_dtype)

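A note on the correction-term algebra used in from_fms_mo() and iaddmm_int(): for an asymmetric activation quantizer with x ~= s_a * (q_x - zp) and symmetric weights W ~= s_w * q_w, the dequantized product s_a * s_w * (q_x - zp) @ q_w.T equals the pure integer matmul rescaled by s_a * s_w minus corr_term = zp * q_w.sum(dim=1) * s_w * s_a, which is why the term can be pre-computed once and folded into the bias. A minimal, self-contained sketch of that identity (illustrative values only, not the fms_mo API; the int8/uint8 "-128" offset used by the real kernel is omitted):

import torch

torch.manual_seed(0)
s_a, zp = 0.05, 7                                # per-tensor activation scale / zero point
q_x = torch.randint(0, 256, (4, 16))             # quantized activations (integers)
q_w = torch.randint(-127, 128, (8, 16))          # symmetric integer weights, shape [out, in]
s_w = torch.rand(8) * 0.01 + 1e-3                # per-output-channel weight scales

# reference: dequantize first, then matmul
y_ref = (s_a * (q_x - zp).float()) @ (s_w[:, None] * q_w.float()).t()

# integer matmul first, then rescale and subtract the pre-computed correction term
# corr_term = zp * q_w.sum(dim=1) * s_w * s_a, shape [out], same as the bias
corr_term = zp * q_w.sum(dim=1).float() * s_w * s_a
y_int = (q_x.float() @ q_w.float().t()) * (s_a * s_w) - corr_term

assert torch.allclose(y_ref, y_int, atol=1e-3)

The integer part zp * q_w.sum(dim=1) can exceed the fp16 range on wide layers, which is what NOTE 3 in the diff guards against by keeping it in INT32 and casting only the scales to FP32.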
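The smq_scale prepared above follows the usual SmoothQuant recipe s = max|X|**alpha / max|W|**(1 - alpha) per input channel: forward() divides the activations by s and from_fms_mo() multiplies the weight by s before quantization, so X @ W.T is mathematically unchanged while activation outliers are flattened. A rough sketch under those assumptions (random tensors; names are illustrative rather than the fms_mo attributes):

import torch

torch.manual_seed(0)
X = torch.randn(4, 16) * (1.0 + 9.0 * torch.rand(16))   # activations with uneven per-channel magnitudes
W = torch.randn(8, 16)                                   # weight, shape [out, in]
alpha = 0.5                                              # smoothing strength

a_max = X.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)   # per-input-channel activation max
w_max = W.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)   # per-input-channel weight max
smq_scale = (a_max**alpha / w_max ** (1.0 - alpha)).clamp(min=1e-5)

y_ref = X @ W.t()
y_smq = (X / smq_scale) @ (W * smq_scale).t()    # smoother activations, scale folded into W
assert torch.allclose(y_ref, y_smq, atol=1e-4)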