@@ -760,6 +760,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         )
         qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False)
         qlin_int.useDynMaxQfunc = kwargs.get("use_dynamic_max_act_Qfunc", False)
+        qlin_int.useSymAct = "sym" in fms_mo_qlinear.qa_mode
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
@@ -770,6 +771,8 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.acc_dtype = torch.float16
         qlin_int.nbits_a = fms_mo_qlinear.num_bits_feature  # only support INT8 for now
         qlin_int.nbits_w = fms_mo_qlinear.num_bits_weight
+        w_levels = 2**qlin_int.nbits_w - 2
+        a_levels = 2**qlin_int.nbits_a - 1 - qlin_int.useSymAct

         with torch.no_grad():
             Qa = fms_mo_qlinear.quantize_feature
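For reference, a minimal sketch (not part of the diff) of the level counts those two new lines produce at INT8: the weight grid is symmetric with 254 steps, and symmetric activation modes drop one level (the -128 code) so the activation grid stays centered at zero.

# Sketch only; assumes INT8, as the comment in the diff above notes.
nbits_w, nbits_a = 8, 8
for use_sym_act in (False, True):
    w_levels = 2**nbits_w - 2                 # 254 steps, symmetric weights
    a_levels = 2**nbits_a - 1 - use_sym_act   # 255 asym, 254 sym
    print(use_sym_act, w_levels, a_levels)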
@@ -794,29 +797,19 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             if qlin_int.useDynMaxQfunc in [-1, -2]:
                 input_scale = torch.tensor(1.0, device=tar_dev)
                 input_zero_point = torch.tensor(128, dtype=torch.int, device=tar_dev)
-                w_scale = torch.tensor(
-                    [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev
-                )
+                w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev)
             elif qlin_int.usePTnativeQfunc:
-                input_scale = torch.tensor(
-                    [(a_cv - a_cvn) / (2**qlin_int.nbits_a - 1)], device=tar_dev
-                )
+                input_scale = torch.tensor([(a_cv - a_cvn) / a_levels], device=tar_dev)
                 input_zero_point = torch.round(-a_cvn / input_scale).to(torch.int)
-                w_scale = torch.tensor(
-                    [w_cv * 2 / (2**qlin_int.nbits_w - 2)], device=tar_dev
-                )
+                w_scale = torch.tensor([w_cv * 2 / w_levels], device=tar_dev)
             else:
                 # fms_mo formula is a bit different from conventional PT formula
-                quant_scale = (2**qlin_int.nbits_a - 1) / torch.tensor(
-                    [a_cv - a_cvn], device=tar_dev
-                )
+                quant_scale = a_levels / torch.tensor([a_cv - a_cvn], device=tar_dev)
                 quant_stepsize = 1.0 / quant_scale
                 quant_zero_point = torch.round(a_cvn * quant_scale)
                 input_scale = quant_stepsize
                 input_zero_point = -quant_zero_point
-                quant_w_scale = (2**qlin_int.nbits_a - 2) / torch.tensor(
-                    [w_cv * 2], device=tar_dev
-                )
+                quant_w_scale = w_levels / torch.tensor([w_cv * 2], device=tar_dev)
                 w_scale = 1.0 / quant_w_scale
             qlin_int.register_buffer("quant_scale", quant_scale)
             qlin_int.register_buffer("quant_stepsize", quant_stepsize)
@@ -829,7 +822,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
             qlin_int.register_buffer("w_zp", w_zp)

             corr_term = (
-                (input_zero_point - 128)
+                (input_zero_point - 128 + qlin_int.useSymAct)
                 * (w_int8.sum(dim=1))
                 * w_scale.float()
                 * input_scale.float()
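The extra `+ qlin_int.useSymAct` keeps this correction consistent with the one-level shift the quantizers below apply in symmetric mode. A minimal numeric check of the underlying identity (asymmetric case, hypothetical scales and zero-point; not the library's code):

import torch

s_a, s_w, zp = 0.02, 0.01, 120
x_u8 = torch.randint(0, 256, (4, 16))        # uint8-range activations
w_i8 = torch.randint(-127, 128, (8, 16))     # symmetric int8 weights
ref = ((x_u8 - zp) * s_a).float() @ (w_i8 * s_w).float().t()

x_i8 = x_u8 - 128                            # stored shifted into int8 range
corr = (zp - 128) * w_i8.sum(dim=1) * s_w * s_a
out = (x_i8.float() @ w_i8.float().t()) * s_a * s_w - corr
print(torch.allclose(ref, out, atol=1e-4))   # True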
@@ -975,7 +968,7 @@ def qa_pt_qfunc_wrapped(self, x):
             Tensor: Quantized tensor with values in the range [-128, 127].
         """
         return torch.ops.fms_mo.q_per_t_sym(
-            x.float(), self.input_scale, self.input_zp - 128
+            x.float(), self.input_scale, self.input_zp - 128 + self.useSymAct
         )

     def qa_pt_quant_func(self, x):
@@ -990,15 +983,22 @@ def qa_pt_quant_func(self, x):
             Tensor: Quantized tensor with values in the range [-128, 127].
         """
         return torch.quantize_per_tensor(
-            x.float(), self.input_scale, self.input_zp - 128, torch.qint8
+            x.float(),
+            self.input_scale,
+            self.input_zp - 128 + self.useSymAct,
+            torch.qint8,
         ).int_repr()

     def qa_raw_qfunc(self, x):
         """
         Quantizes the input tensor x to 8-bit integer values using raw formula, slower if not
         torch.compiled
         """
-        x = torch.clamp((x / self.input_scale + self.input_zp - 128).round(), -128, 127)
+        x = torch.clamp(
+            (x / self.input_scale + self.input_zp - 128 + self.useSymAct).round(),
+            -128,
+            127,
+        )
         return x.to(torch.int8)

     def qa_fmo_mo_qfunc(self, x):
@@ -1007,13 +1007,10 @@ def qa_fmo_mo_qfunc(self, x):
         before rounds, as opposed to typical torch formula that rounds before clamps.
         (See qa_raw_qfunc() above.)
         """
-        x = (
-            torch.round(
-                x.clamp(self.cvs[1], self.cvs[0]) / self.quant_stepsize
-                - self.quant_zero_point
-            )
-            - 128
-        )
+        x = torch.round(
+            x.clamp(self.cvs[1], self.cvs[0]) / self.quant_stepsize
+            - self.quant_zero_point
+        ) - (128 - self.useSymAct)
         return x.to(torch.int8)

     def qa_dynamic_max_qfunc(self, x):
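The clamp-then-round ordering the docstring describes matters near the clip values. A small sketch with made-up stepsize and clip values (zero-point of 0 assumed) showing the two orderings can disagree by one code:

import torch

step, cv = 0.1, 12.64                 # hypothetical quant_stepsize and clip_val
x = torch.tensor([12.71])             # just above the clip value

fms_style = torch.round(x.clamp(-cv, cv) / step) - 128            # -> -2
torch_style = torch.clamp((x / step - 128).round(), -128, 127)    # -> -1
print(fms_style.item(), torch_style.item())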
@@ -1060,7 +1057,9 @@ def iaddmm_int(self, bias, m1, m2):
         Nchunk = len(idx)
         idx.append(m1.shape[1])
         accumulator = torch.zeros(
-            (m1.shape[0], m2.shape[1]), dtype=torch.float16, device=m1.device
+            (m1.shape[0], m2.shape[1]),
+            dtype=torch.int,
+            device=m1.device,  # cast float16 if needed
         )
         trun_scale = 1
         if self.truncate_lsb > 0:
@@ -1080,7 +1079,7 @@ def iaddmm_int(self, bias, m1, m2):
             # could cast to smaller data type to further simulate HW behavior, for example,
             # if HW truncates 8b from both sides of i32 accumulator, the remaining data can
             # be cast to i16 to be more realistic. pay attention to overflow handling
-            accumulator += imm_out.to(torch.float16)
+            accumulator += imm_out  # .to(torch.float16) if needed

         return (
             accumulator
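Keeping the accumulator in integer dtype makes the LSB-truncation path above a closer match to fixed-point hardware. A rough sketch of the intended behavior, assuming trun_scale = 2**truncate_lsb (illustrative only, not the library's exact chunking logic):

import torch

acc_bits, truncate_lsb = 24, 4
acc_min, acc_max = -(1 << (acc_bits - 1)), (1 << (acc_bits - 1)) - 1
trun_scale = 1 << truncate_lsb

m1 = torch.randint(-128, 128, (2, 64))
m2 = torch.randint(-128, 128, (64, 3))
imm_out = (m1 @ m2).div(trun_scale, rounding_mode="floor")  # drop LSBs
acc = imm_out.clamp(acc_min, acc_max)                       # saturate to acc_bits
print(acc * trun_scale)                                     # undo scaling at the end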
@@ -1107,8 +1106,14 @@ def iaddmm_FP(self, bias, m1, m2):
11071106 Returns:
11081107 Tensor: the result of the matrix multiplication with addition of bias
11091108 """
1110- m2 = m2 .to (m1 .dtype )
1111- return torch .addmm (bias , m1 , m2 )
1109+ if self .useDynMaxQfunc in [- 1 , - 2 ]:
1110+ m1 = self .qa_dynamic_max_qfunc (m1 )
1111+ elif self .usePTnativeQfunc :
1112+ m1 = self .qa_raw_qfunc (m1 )
1113+ else :
1114+ m1 = self .qa_fmo_mo_qfunc (m1 )
1115+
1116+ return torch .matmul (m1 * self .input_scale , m2 * self .w_scale ) + bias
11121117
11131118 def set_matmul_op (self ):
11141119 """