@@ -712,8 +712,8 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
712712 cls: The class of the QLinearModule to be created.
713713 fms_mo_qlinear: The QLinear module to be converted.
714714 (experimental)
715- useINTkernel: choose from ['cutlass', 'triton', False], "cutlass" kernel is faster,
716- "triton" support chunky truncation, "False" fallback to torch.matmul
715+ use_int_kernel: choose from ['cutlass', 'triton', False], "cutlass" kernel is faster,
716+ "triton" supports chunked truncation, and False falls back to torch.matmul
717717 max_acc_bits: usually INT matmul accumulates in INT32, but some HW could have a different
718718 design, such as an INT24 accumulator, which will saturate at
719719 (-2**(acc_bit-1) + 1, 2**(acc_bit-1))
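For illustration, a minimal sketch (not part of this diff) of how the `max_acc_bits` saturation described above could be emulated; `saturate_acc` is a hypothetical helper, not an fms-mo API.

```python
import torch

def saturate_acc(acc: torch.Tensor, acc_bits: int = 24) -> torch.Tensor:
    # Clamp an INT32 accumulator to the saturation range quoted above:
    # (-2**(acc_bits-1) + 1, 2**(acc_bits-1))
    return acc.clamp(-(2 ** (acc_bits - 1)) + 1, 2 ** (acc_bits - 1))

# Emulated INT8 matmul; integer matmul is supported on CPU (not CUDA) and is
# used here purely for emulation, accumulating in INT32.
a = torch.randint(-128, 128, (4, 64), dtype=torch.int32)
b = torch.randint(-128, 128, (64, 8), dtype=torch.int32)
acc32 = a @ b                      # full INT32 accumulation
acc24 = saturate_acc(acc32, 24)    # what an INT24 accumulator would return
```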
@@ -745,8 +745,8 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
745745 )
746746 # Make sure to register an Op for integer matmul, could be real INT matmul or emulation
747747 qcfg = getattr(fms_mo_qlinear, "qcfg", {})
748- qlin_int.useINTkernel = kwargs.get(
749- "useINTkernel", qcfg.get("useINTkernel", "cutlass")
748+ qlin_int.use_int_kernel = kwargs.get(
749+ "use_int_kernel", qcfg.get("use_int_kernel", "cutlass")
750750 )
751751 qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False)
752752 qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
@@ -772,7 +772,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
772772 ) # Qw.clipval should have been updated after this
773773 qlin_int.weight = nn.Parameter(
774774 w_int8.to(torch.int8), requires_grad=False
775- ) # NOTE: may needs INT W stored as FP in some cases
775+ ) # NOTE: may need INT W stored as FP in some cases
776776
777777 if qlin_int.usePTnativeQfunc:
778778 input_scale = torch.tensor(
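For reference, a minimal sketch of the weight-storage pattern used in this hunk: symmetric per-tensor quantization to INT8 held in a frozen nn.Parameter. The scale formula here is a generic symmetric-quantization assumption, not necessarily the exact fms-mo quantizer.

```python
import torch
import torch.nn as nn

w_fp = torch.randn(128, 256)               # original FP weight
w_cv = w_fp.abs().max()                    # weight clip value (illustrative)
w_scale = w_cv / 127.0                     # symmetric INT8 scale (assumption)
w_int8 = torch.clamp(torch.round(w_fp / w_scale), -127, 127)

# Same storage pattern as above: INT8 weights in a frozen Parameter.
weight = nn.Parameter(w_int8.to(torch.int8), requires_grad=False)
```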
@@ -873,7 +873,7 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):
873873 qlinear_iW.nbits_w = 8
874874 qlinear_iW.acc_dtype = torch.float16
875875 qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
876- qlinear_iW.useINTkernel = True
876+ qlinear_iW.use_int_kernel = True
877877 qlinear_iW.weight = nn.Parameter(
878878 nnlin_iW.weight.to(torch.int8), requires_grad=False
879879 )
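As background for the `a_cv`/`a_cvn` clip values in the `from_torch_iW` signature, here is the textbook asymmetric 8-bit mapping from clip values to a scale and zero point; this is a generic sketch with illustrative numbers, not necessarily the exact formula fms-mo uses.

```python
import torch

a_cv, a_cvn, nbits_a = 6.0, 0.0, 8                       # illustrative clip values
input_scale = torch.tensor((a_cv - a_cvn) / (2**nbits_a - 1))
input_zp = torch.round(torch.tensor(-a_cvn) / input_scale)

x = torch.randn(4, 128).clamp(a_cvn, a_cv)               # clipped activations
x_int = torch.clamp(torch.round(x / input_scale + input_zp), 0, 2**nbits_a - 1)
x_deq = (x_int - input_zp) * input_scale                 # dequantize to verify
```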
@@ -1086,22 +1086,22 @@ def set_matmul_op(self):
10861086 """
10871087 Sets the matmul operator for the quantized linear module.
10881088
1089- If `useINTkernel` is True and CUDA is available, it will use the INT kernel
1089+ If `use_int_kernel` is True and CUDA is available, it will use the INT kernel
10901090 for integer matrix multiplication. Otherwise, it will use the FP kernel.
10911091
10921092 If the operator has already been set, it will do nothing.
10931093 """
1094- if self.useINTkernel and not torch.cuda.is_available():
1094+ if self.use_int_kernel and not torch.cuda.is_available():
10951095 logger.warning(
1096- "Cannot set useINTkernel=True when CUDA is not available. "
1097- "Fallback to useINTkernel=False"
1096+ "Cannot set use_int_kernel=True when CUDA is not available. "
1097+ "Falling back to use_int_kernel=False"
10981098 )
1099- self.useINTkernel = False
1099+ self.use_int_kernel = False
11001100
11011101 if hasattr(torch.ops, "fms_mo") and hasattr(torch.ops.fms_mo, "imatmul"):
11021102 # imatmul already registered, e.g. when swapping the 2nd QLinear
11031103 self.imatmul = torch.ops.fms_mo.imatmul
1104- self.iaddmm = self.iaddmm_int if self.useINTkernel else self.iaddmm_FP
1104+ self.iaddmm = self.iaddmm_int if self.use_int_kernel else self.iaddmm_FP
11051105 else:
11061106 # When swapping the first QLinear, need to register our custom Op and choose the kernel
11071107 # Standard
@@ -1113,14 +1113,16 @@ def set_matmul_op(self):
11131113 imatmul_ops_reg,
11141114 )
11151115
1116- if self.useINTkernel == "triton":  # will use real imatmul written in triton
1116+ if self.use_int_kernel == "triton":
1117+ # will use real imatmul written in triton
11171118 imm_func = partial(
11181119 tl_matmul,
11191120 chunk_trun_bits=self.truncate_lsb,
11201121 chunk_size=self.chunk_size,
11211122 )
11221123
1123- elif self.useINTkernel == "cutlass":
1124+ elif self.use_int_kernel == "cutlass":
1125+ # will use real imatmul written in cutlass
11241126 cutlass_ops_load_and_reg()
11251127 # Third Party
11261128 import cutlass_mm  # this module will only be available after calling reg()
@@ -1129,9 +1131,9 @@ def set_matmul_op(self):
11291131 else:
11301132 imm_func = torch.matmul
11311133
1132- imatmul_ops_reg(self.useINTkernel, imm_func)
1134+ imatmul_ops_reg(self.use_int_kernel, imm_func)
11331135 self.imatmul = torch.ops.fms_mo.imatmul
1134- self.iaddmm = self.iaddmm_int if self.useINTkernel else self.iaddmm_FP
1136+ self.iaddmm = self.iaddmm_int if self.use_int_kernel else self.iaddmm_FP
11351137
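For context, a minimal sketch of how an `imatmul` op could be registered so that `torch.ops.fms_mo.imatmul` resolves. This is an assumption about what `imatmul_ops_reg` roughly does, using the public `torch.library` API; the real helper and kernel wiring in fms-mo may differ, and this would conflict if the actual `fms_mo` op namespace were already registered.

```python
import torch

# Define a "fms_mo::imatmul" op schema (sketch only).
_lib = torch.library.Library("fms_mo", "DEF")
_lib.define("imatmul(Tensor a, Tensor b) -> Tensor")

def _imatmul_emulated(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Emulation fallback: plain FP matmul, standing in for a real INT kernel
    # (cutlass or triton in the code above).
    return torch.matmul(a.float(), b.float())

_lib.impl("imatmul", _imatmul_emulated, "CompositeExplicitAutograd")

out = torch.ops.fms_mo.imatmul(torch.randn(2, 4), torch.randn(4, 3))
```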
11361138 def _get_name(self):
11371139 """
@@ -1145,7 +1147,7 @@ def extra_repr(self) -> str:
11451147 """
11461148 return (
11471149 f"in={ self .in_features } , out={ self .out_features } , bias={ self .bias is not None } , "
1148- f"useINTkernel ={ self .useINTkernel } "
1150+ f"use_int_kernel ={ self .use_int_kernel } "
11491151 )
11501152
11511153 def __getstate__(self):
@@ -1861,7 +1863,7 @@ class LinearFPxAcc(torch.nn.Linear):
18611863 """Linear layer wrapper that can simulate the HW behavior of LSB truncation on FP accumulation.
18621864 Some HW may have options to allow the FP matmul engine to accumulate in a precision lower than FP32,
18631865 such as accumulating in TF32 or even BF16. According to Nvidia docs, ~7-10x speed-up with minor
1864- accuracy trade-off. This support both FWD and BWD.
1866+ accuracy trade-off. This supports both FWD and BWD.
18651867 Ref:
18661868 1. https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/
18671869 2. PyTorch's "torch.backends.cuda.matmul.allow_tf32"
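As a point of reference for Ref. 2, these are the PyTorch switches that let FP32 matmuls use TF32 tensor cores, a coarse analogue of the reduced-precision accumulation this wrapper simulates:

```python
import torch

# Allow FP32 matmuls (and convolutions) to use TF32 tensor cores on Ampere+ GPUs;
# inputs are rounded to TF32, trading a little precision for large speedups.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

if torch.cuda.is_available():
    a = torch.randn(2048, 2048, device="cuda")
    b = torch.randn(2048, 2048, device="cuda")
    c = a @ b  # eligible to run as a TF32 tensor-core matmul
```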