@@ -269,7 +269,6 @@ def forward(self, x):
             if self.calib_counter == 0:
                 self.quantize_calib_feature = None
                 self.quantize_calib_weight = None
-                self.calib_counter = None  # [optional] this should release the memory

         elif self.ptqmode == "fp32_out":
             if self.W_fp is None:
@@ -1101,7 +1100,6 @@ def qa_dyn_max_fake_zero_shift(self, x):
         to multiply input_scale. (Assuming per-tensor, can shift left or right)
         """
         amax = x.abs().max()
-        shift_dir = 1 if amax == x.max() else -1
         levels = 2 ** (self.nbits_a - 1) - 1 - self.input_zp
         self.cvs[0] = amax
         self.cvs[1] = -amax
@@ -2061,6 +2059,167 @@ def extra_repr(self) -> str:
         )


+class LinearFuncINT8FwdFP32Bwd(torch.autograd.Function):
+    """[Experimental] Linear autograd function that uses INT matmul/accumulation to simulate HW
+    behavior during QAT, in order to adjust the weights for a specific HW design.
+
+    Args:
+        x: FP activation tensor; needs to be reshaped to 2D and quantized to INT8.
+        weight: FP 2D tensor, W.shape = [out, in].
+        bias: bias from the original Linear; does not include the INT ZP correction term yet.
+
+    NOTE:
+        1. The main purpose is to utilize the triton INT kernel to simulate MSB/LSB truncation
+           in the forward pass.
+        2. The backward pass simply uses torch.matmul.
+        3. *Max per-channel* for weights and *dynamic max per-token* for activations.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        lsb_trun_bits=0,
+        chunk_size=64,
+        max_acc_bits=32,
+    ):
+        assert x.dtype in [torch.float, torch.bfloat16, torch.float16]
+        # input can be 2D or 3D, need to reshape before tl_matmul
+        org_dtype = x.dtype
+        target_shape_output = x.shape[:-1] + (weight.shape[0],)
+        x = x.reshape(-1, x.shape[-1])
+
+        if bias is not None:
+            ctx.has_bias = True
+            ctx.bias_dtype = bias.dtype
+        else:
+            ctx.has_bias = False
+
+        ctx.save_for_backward(x, weight)  # x, W are saved in their original dtype
+
+        # max per_token for input -> reduce_dim = -1
+        # per_channel for W, but W.shape = [out, in] -> reduce_dim = -1
+        # sym activation -> correction term = 0
+        x_scale = x.abs().amax(dim=-1, keepdim=True) / 127
+        w_scale = weight.abs().amax(dim=-1, keepdim=True) / 127
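+        # e.g. (assumed values) a token whose abs-max is 2.54 gets x_scale = 2.54 / 127 = 0.02,
+        # so its values map symmetrically onto the INT8 range [-127, 127]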
+
+        x_i8 = torch.round(x / x_scale).to(torch.int8)
+        w_i8 = torch.round(weight / w_scale).to(torch.int8)
+
+        # triton kernel accepts 2D int8 and returns int32
+        output = tl_matmul(
+            x_i8,
+            w_i8.t(),
+            chunk_trun_bits=lsb_trun_bits,
+            chunk_size=chunk_size,
+            max_acc_bits=max_acc_bits,
+        )
+        output = (
+            (output.to(torch.float) * x_scale * w_scale.t())
+            .reshape(target_shape_output)
+            .to(org_dtype)
+        )
+        if bias is not None:
+            output = output + bias.to(org_dtype)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # load x and W from context; x is 2D already, no quant in backward.
+        # option 1: use compute dtype = x.dtype
+        # option 2: compute in fp32 for best results.
+        x, weight = ctx.saved_tensors  # x, W are saved in original dtype
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1]
+        dtype_grad = x.dtype  # torch.float
+        # grad_input and grad_output could be 3D, like x
+        target_shape_grad_input = grad_output.shape[:-1] + (in_dim,)
+        grad_output_2D = grad_output.reshape(-1, out_dim).to(dtype_grad)
+
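+        # For reference: with y = x @ W.T + b, the standard FP gradients computed below are
+        #   dL/dW = (dL/dy).T @ x,  dL/dx = dL/dy @ W,  dL/db = sum over the rows of dL/dy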
+        # Compute grad_weight, shape = [out, in]
+        grad_weight = torch.matmul(
+            grad_output_2D.transpose(0, 1).contiguous(),
+            x.to(dtype_grad),
+        ).to(weight.dtype)
+        # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D
+        grad_input = (
+            torch.matmul(
+                grad_output_2D,
+                weight.to(dtype_grad),
+            )
+            .reshape(target_shape_grad_input)
+            .to(x.dtype)
+        )
+
+        if not ctx.has_bias:
+            grad_bias = None
+        else:
+            grad_bias = grad_output_2D.sum(0).to(ctx.bias_dtype)
+
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+
+class QLinearINT8Train(torch.nn.Linear):
+    """QLinear layer wrapper that simulates INT8 HW behavior (e.g. MSB/LSB truncation) in the
+    forward pass and uses FP32 in the backward pass.
+    """
+
+    @classmethod
+    def from_fms_mo(cls, nnlin, lsb_trun_bits=0, **kwargs):
+        """Converts a torch.nn.Linear or QLinear module to QLinearINT8Train.
+
+        Args:
+            cls (class): The class to be created.
+            nnlin (torch.nn.Linear or QLinear): The Linear module to be converted.
+            lsb_trun_bits (int): INT8 LSB truncation, [0 to 16].
+            chunk_size (int): usually >= 64 for INT8, based on HW design.
+            max_acc_bits (int): accumulator max bits, <= 32, based on HW design.
+
+        Returns:
+            QLinearINT8Train: The converted linear layer.
+        """
+
+        target_device = kwargs.get(
+            "target_device", kwargs.get("device", next(nnlin.parameters()).device)
+        )
+
+        lin_int8fwd = cls(
+            nnlin.in_features,
+            nnlin.out_features,
+            bias=nnlin.bias is not None,
+            device="meta",  # target_device,
+        )
+
+        lin_int8fwd.weight = nnlin.weight
+        lin_int8fwd.lsb_trun_bits = lsb_trun_bits
+        lin_int8fwd.chunk_size = kwargs.get("chunk_size", 64)
+        lin_int8fwd.max_acc_bits = kwargs.get("max_acc_bits", 32)
+
+        if nnlin.bias is not None:
+            lin_int8fwd.bias = nnlin.bias
+        return lin_int8fwd.to(target_device)
+
+    def forward(self, inputs):
+        return LinearFuncINT8FwdFP32Bwd.apply(
+            inputs,
+            self.weight,
+            self.bias,
+            self.lsb_trun_bits,
+            self.chunk_size,
+            self.max_acc_bits,
+        )
+
+    def extra_repr(self) -> str:
+        """Returns an alternative string representation of the object."""
+        repr_str = f"{self.in_features},{self.out_features}"
+        repr_str += f",bias={self.bias is not None},chunk_size={self.chunk_size}"
+        if self.lsb_trun_bits > 0:
+            repr_str += f",lsb_trun={self.lsb_trun_bits}"
+        if self.max_acc_bits < 32:
+            repr_str += f",max_acc_bits={self.max_acc_bits}"
+        return repr_str
+
+
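+# Sanity-check sketch for QLinearINT8Train (illustrative; assumes a CUDA device and that the
+# triton tl_matmul kernel is importable):
+#   lin = torch.nn.Linear(64, 32).cuda()
+#   qlin = QLinearINT8Train.from_fms_mo(lin, lsb_trun_bits=0)
+#   x = torch.randn(2, 4, 64, device="cuda")
+#   (qlin(x) - lin(x)).abs().max()  # expected to be on the order of the INT8 quantization error
+
+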
 if available_packages["mx"]:
     # Third Party
     # pylint: disable=import-error