@@ -752,7 +752,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
-            1 << (qlin_int.max_acc_bits - 1) - 1,
+            (1 << (qlin_int.max_acc_bits - 1)) - 1,
         )
         qlin_int.truncate_lsb = kwargs.get("truncate_lsb", 0)
         qlin_int.chunk_size = kwargs.get("chunk_size", 100000)
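
For reference, a quick standalone check (not part of the diff) of the operator-precedence bug this hunk and the matching one in from_torch_iW fix: in Python, binary `-` binds tighter than `<<`, so the old upper bound evaluated to roughly half of the intended accumulator maximum.

bits = 32
old_max = 1 << (bits - 1) - 1      # parses as 1 << 30 -> 1073741824 (wrong)
new_max = (1 << (bits - 1)) - 1    # (2**31) - 1   -> 2147483647 (intended i32 max)
acc_min = -(1 << (bits - 1))       # -2**31, unchanged by this fix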
@@ -871,16 +871,16 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):

         qlinear_iW.nbits_a = 8  # Only support INT8 for now
         qlinear_iW.nbits_w = 8
-        qlinear_iW.acc_dtype = torch.float16
+        qlinear_iW.acc_dtype = kwargs.get("acc_dtype", torch.float)
         qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
-        qlinear_iW.use_int_kernel = True
+        qlinear_iW.use_int_kernel = kwargs.get("use_int_kernel", "triton")
         qlinear_iW.weight = nn.Parameter(
             nnlin_iW.weight.to(torch.int8), requires_grad=False
         )
         qlinear_iW.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlinear_iW.accminmax = (
             -(1 << (qlinear_iW.max_acc_bits - 1)),
-            1 << (qlinear_iW.max_acc_bits - 1) - 1,
+            (1 << (qlinear_iW.max_acc_bits - 1)) - 1,
         )
         qlinear_iW.truncate_lsb = kwargs.get("truncate_lsb", False)
         qlinear_iW.chunk_size = kwargs.get("chunk_size", 100000)
@@ -1027,11 +1027,11 @@ def iaddmm_int(self, bias, m1, m2):
10271027 else :
10281028 m1 = self .qa_fmo_mo_qfunc (m1 )
10291029
1030- if m1 .shape [1 ] > self .chunk_size :
1030+ if m1 .shape [1 ] > self .chunk_size and self . use_int_kernel != "triton" :
10311031 idx = list (range (0 , m1 .shape [1 ], self .chunk_size ))
10321032 Nchunk = len (idx )
10331033 idx .append (m1 .shape [1 ])
1034- fp16_out = torch .zeros (
1034+ accumulator = torch .zeros (
10351035 (m1 .shape [0 ], m2 .shape [1 ]), dtype = torch .float16 , device = m1 .device
10361036 )
10371037 trun_scale = 1
@@ -1052,11 +1052,11 @@ def iaddmm_int(self, bias, m1, m2):
                 # could cast to smaller data type to further simulate HW behavior, for example,
                 # if HW truncates 8b from both sides of i32 accumulator, the remaining data can
                 # be cast to i16 to be more realistic. pay attention to overflow handling
-                fp16_out += imm_out.to(torch.float16)
+                accumulator += imm_out.to(torch.float16)

             return (
-                fp16_out
-                * (trun_scale * self.input_scale * self.w_scale).to(torch.float16)
+                accumulator
+                * (trun_scale * self.input_scale * self.w_scale)  # .to(torch.float16)
                 + bias
             ).to(self.acc_dtype)
             # The safest casting, i32 -> f32
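
A minimal, self-contained sketch of the chunked fallback these two hunks touch, under assumed names (chunked_int_addmm and its arguments are illustrative, not the repository code): each K-chunk is accumulated in int32, clamped to the simulated accumulator range, optionally shifted to drop LSBs, and the partial products are summed in float16 before the final rescale.

import torch

def chunked_int_addmm(bias, m1_int8, m2_int8, chunk_size=100000,
                      max_acc_bits=32, truncate_lsb=0):
    acc_min, acc_max = -(1 << (max_acc_bits - 1)), (1 << (max_acc_bits - 1)) - 1
    trun_scale = 1 << truncate_lsb          # compensates the dropped LSBs later
    accumulator = torch.zeros(
        m1_int8.shape[0], m2_int8.shape[1], dtype=torch.float16, device=m1_int8.device
    )
    for k0 in range(0, m1_int8.shape[1], chunk_size):
        k1 = min(k0 + chunk_size, m1_int8.shape[1])
        imm = torch.matmul(                 # integer matmul; CPU path in stock PyTorch
            m1_int8[:, k0:k1].to(torch.int32), m2_int8[k0:k1, :].to(torch.int32)
        ).clamp(acc_min, acc_max)           # emulate a narrow HW accumulator
        if truncate_lsb:
            imm = imm >> truncate_lsb       # emulate dropping low-order bits
        accumulator += imm.to(torch.float16)
    # the real iaddmm_int also folds input_scale and w_scale into this rescale
    return accumulator * trun_scale + bias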
@@ -1145,10 +1145,13 @@ def extra_repr(self) -> str:
         """
         Returns an alternative string representation of the object
         """
-        return (
+        repr_str = (
             f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
-            f"use_int_kernel={self.use_int_kernel}"
+            f"int_kernel={self.use_int_kernel}"
         )
+        if self.truncate_lsb > 0 or self.max_acc_bits < 32:
+            repr_str += f", acc_bits={self.max_acc_bits}, trun_lsb={self.truncate_lsb}"
+        return repr_str

     def __getstate__(self):
         """
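
A standalone illustration (field values are made up) of what the reworked extra_repr prints: the acc_bits/trun_lsb suffix only appears when LSB truncation or a narrowed accumulator is actually configured.

class _Demo:
    in_features, out_features, bias = 4096, 4096, None
    use_int_kernel, max_acc_bits, truncate_lsb = "triton", 24, 8

    def extra_repr(self) -> str:
        repr_str = (
            f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
            f"int_kernel={self.use_int_kernel}"
        )
        if self.truncate_lsb > 0 or self.max_acc_bits < 32:
            repr_str += f", acc_bits={self.max_acc_bits}, trun_lsb={self.truncate_lsb}"
        return repr_str

print(_Demo().extra_repr())
# -> in=4096, out=4096, bias=False, int_kernel=triton, acc_bits=24, trun_lsb=8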