 
 logger = logging.getLogger(__name__)
 
+# pylint: disable=unused-argument
+# the i8i8 op must be registered with its full I/O signature, even if unused by the op function
+
+# pylint: disable=not-callable
+# torch.nn.functional.linear not recognized as callable
+# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
+
 
 def register_aiu_i8i8_op():
     """Register AIU-specific op to enable torch compile without graph break.
@@ -64,7 +71,8 @@ def i8i8_aiu(
     dtype = x.dtype
     out_feat, in_feat = weight.size()
 
-    w_cv, w_cvn, a_cv, a_cvn, zshift, sq = extract_qdata(
+    # unused returns are w_cvn and zshift
+    w_cv, _, a_cv, a_cvn, _, sq = extract_qdata(
         qdata,
         weight_quant_type,
         activ_quant_type,
@@ -88,6 +96,8 @@ def i8i8_aiu_abstract(
     activ_quant_type,
     smoothquant,
 ):
+    """Abstract op template defining the I/O sizes."""
+
     outshape = x.size()[:-1] + (weight.size(0),)
     return torch.empty(
         outshape, dtype=x.dtype, device=x.device, requires_grad=False
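
The abstract implementation above only reports the output shape and dtype so torch.compile can trace the op without executing it. A minimal sketch of the same pattern using the torch.library API (assumes PyTorch >= 2.4; the op name mylib::i8i8_demo and its signature are illustrative, not the actual registration in this file):

    import torch

    # hypothetical op name, for illustration only
    @torch.library.custom_op("mylib::i8i8_demo", mutates_args=())
    def i8i8_demo(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # eager implementation: a plain linear stands in for the simulated i8i8 matmul
        return torch.nn.functional.linear(x, weight)

    @i8i8_demo.register_fake
    def _(x, weight):
        # shape/dtype-only template, mirroring i8i8_aiu_abstract above
        outshape = x.size()[:-1] + (weight.size(0),)
        return torch.empty(outshape, dtype=x.dtype, device=x.device, requires_grad=False)
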
@@ -153,18 +163,19 @@ def dequant_weights(
     w_cv: torch.Tensor,
     sq: torch.Tensor,
     weight_quant_type: str,
-):
+) -> torch.Tensor:
+    """Dequantize integer weights based on quantizer type."""
+
     if weight_quant_type == "per_tensor":  # assume 8-bit symmetric W quantization
         # w size: (out_feat, in_feat)
         # sq size: (in_feat) or (1), no need to unsqueeze
         return (weight * w_cv / 127) / sq
-    elif weight_quant_type == "per_channel":
+    if weight_quant_type == "per_channel":
         # w_cv is (out_feat), need to unsqueeze to broadcast mul to weight
         return (weight * w_cv.unsqueeze(dim=1) / 127) / sq
-    else:
-        raise NotImplementedError(
-            f"weight quantizantion type {weight_quant_type} is not supported"
-        )
+    raise NotImplementedError(
+        f"weight quantization type {weight_quant_type} is not supported"
+    )
 
 
 def quant_dequant_activ(
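
For a quick sanity check of the broadcasting in dequant_weights, a toy sketch (shapes and values are made up; float tensors stand in for int8 weight codes):

    import torch

    weight = torch.tensor([[64.0, -32.0, 127.0],
                           [10.0, 20.0, -30.0]])  # (out_feat=2, in_feat=3)
    sq = torch.ones(3)                            # smoothquant scale, (in_feat)

    # per_tensor: a scalar clip value broadcasts directly
    w_cv = torch.tensor(0.5)
    print(((weight * w_cv / 127) / sq).shape)     # torch.Size([2, 3])

    # per_channel: (out_feat) clip values need unsqueeze to (out_feat, 1)
    w_cv = torch.tensor([0.5, 0.25])
    print(((weight * w_cv.unsqueeze(dim=1) / 127) / sq).shape)  # torch.Size([2, 3])
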
@@ -173,8 +184,10 @@ def quant_dequant_activ(
     a_cvn: torch.Tensor,
     sq: torch.Tensor,
     activ_quant_type: str,
-):
+) -> torch.Tensor:
     """
+    Quantize and dequantize activations based on quantizer type.
+
     x size (*, hid_dim)
     sq size (hid_dim) or (1)
     => no need to unsqueeze to perform x / sq
@@ -183,18 +196,17 @@ def quant_dequant_activ(
         scale_x = 127 / a_cv
         x_int = torch.round(x / sq * scale_x).clamp(-127, 127)
         return x_int / scale_x * sq
-    elif activ_quant_type == "per_tensor_asymm":
+    if activ_quant_type == "per_tensor_asymm":
         scale_x = 255 / (a_cv - a_cvn)
         zp_x = a_cvn * scale_x
         x_int = torch.round(x / sq * scale_x - zp_x).clamp(0, 255)
         return (x_int + zp_x) / scale_x * sq
-    elif activ_quant_type == "per_token":
+    if activ_quant_type == "per_token":
         x_sq = x / sq
         a_cv_per_token = x_sq.abs().max(dim=-1, keepdim=True)[0]
         scale_x = 127 / a_cv_per_token
         x_int = torch.round(x_sq * scale_x).clamp(-127, 127)
         return x_int / scale_x * sq
-    else:
-        raise NotImplementedError(
-            f"activation quantizantion type {activ_quant_type} is not supported"
-        )
+    raise NotImplementedError(
+        f"activation quantization type {activ_quant_type} is not supported"
+    )
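
To make the per_token branch concrete, a hypothetical round-trip (random inputs; sq of ones stands in for a real smoothquant scale):

    import torch

    x = torch.randn(4, 8)  # (*, hid_dim)
    sq = torch.ones(8)     # smoothquant scale, (hid_dim)

    x_sq = x / sq
    a_cv_per_token = x_sq.abs().max(dim=-1, keepdim=True)[0]  # one clip value per row
    scale_x = 127 / a_cv_per_token
    x_int = torch.round(x_sq * scale_x).clamp(-127, 127)
    x_fq = x_int / scale_x * sq

    # fake-quant error is at most half a quantization step per element
    print((x - x_fq).abs().max())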