@@ -69,7 +69,7 @@ def insert_rescale_ops_to_int32(
6969 tosa_graph ,
7070 tensor ,
7171 qarg .zp ,
72- scale ,
72+ [ scale ] ,
7373 )
7474 )
7575 return rescaled_nodes , min_scale
@@ -109,7 +109,7 @@ def insert_rescale_op_to_int8(
109109 last_tensor .name ,
110110 node .name ,
111111 qargs_out .zp ,
112- output_rescale_scale ,
112+ [ output_rescale_scale ] ,
113113 )
114114
115115
@@ -156,65 +156,73 @@ def is_scale32(type: int) -> ts.DType:
156156# The RESCALE operator is defined using an integer multiply, add, and shift.
157157# This utility function is for calculating the multier and shift given a scale.
158158# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
159- def compute_multiplier_and_shift (scale : float , scaleWidth : int = 32 ) -> Tuple [int , int ]:
159+ def compute_multiplier_and_shift (
160+ scales : list [float ], scaleWidth : int = 32
161+ ) -> Tuple [list [int ], list [int ]]:
160162 if scaleWidth == 16 :
161163 offset = 15
162164 elif scaleWidth == 32 :
163165 offset = 31
164166 else :
165- raise AssertionError ( "unsupported scale width" )
166-
167- assert isinstance ( scale , float )
167+ raise ValueError (
168+ f"Unsupported scale width: { scaleWidth } , only 16 and 32 are valid values."
169+ )
168170
169- mantissa , exponent = math .frexp (scale )
170- shift = exponent
171+ multipliers = []
172+ shifts = []
173+ for scale in scales :
174+ mantissa , exponent = math .frexp (scale )
175+ shift = exponent
171176
172- const_2_power_15_or_31 = 1 << offset
173- shifted_mantissa = int ( round (mantissa * const_2_power_15_or_31 ) )
177+ const_2_power_15_or_31 = 1 << offset
178+ shifted_mantissa = round (mantissa * const_2_power_15_or_31 )
174179
175- assert shifted_mantissa <= const_2_power_15_or_31
180+ assert shifted_mantissa <= const_2_power_15_or_31
176181
177- if shifted_mantissa == const_2_power_15_or_31 :
178- shifted_mantissa = int ( shifted_mantissa / 2 )
179- shift += 1
182+ if shifted_mantissa == const_2_power_15_or_31 :
183+ shifted_mantissa = shifted_mantissa // 2
184+ shift += 1
180185
181- # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
182- shift = offset - shift
186+ # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
187+ shift = offset - shift
183188
184- # INT32_MAX, 2^31 - 1
185- assert shifted_mantissa <= (const_2_power_15_or_31 - 1 )
189+ # INT32_MAX, 2^31 - 1
190+ assert shifted_mantissa <= (const_2_power_15_or_31 - 1 )
186191
187- multiplier = shifted_mantissa
192+ multiplier = shifted_mantissa
188193
189- if shift > 62 :
190- multiplier = multiplier >> min (31 , shift - 62 )
191- shift = 62
192- return multiplier , shift
194+ if shift > 62 :
195+ multiplier = multiplier >> min (31 , shift - 62 )
196+ shift = 62
197+ multipliers .append (multiplier )
198+ shifts .append (shift )
199+ return multipliers , shifts
193200
194201
195202def build_rescale (
196203 tosa_fb : TosaSerializer ,
197- scale : float ,
204+ scale : list [ float ] ,
198205 input_node : TosaSerializerTensor ,
199206 output_name : str ,
200207 output_type : ts .DType ,
201208 output_shape : List [int ],
202209 input_zp : int ,
203210 output_zp : int ,
204211 is_double_round : bool = False ,
212+ per_channel = False ,
205213):
206214 scale_width = 32 if is_scale32 (output_type ) else 16
207- multiplier , shift = compute_multiplier_and_shift (scale , scale_width )
215+ multipliers , shifts = compute_multiplier_and_shift (scale , scale_width )
208216
209217 attr_rescale = ts .TosaSerializerAttribute ()
210218 attr_rescale .RescaleAttribute (
211219 input_zp = input_zp ,
212220 output_zp = output_zp ,
213- multiplier = [ multiplier ] ,
214- shift = [ shift ] ,
221+ multiplier = multipliers ,
222+ shift = shifts ,
215223 scale32 = is_scale32 (output_type ),
216224 double_round = is_double_round ,
217- per_channel = False ,
225+ per_channel = per_channel ,
218226 input_unsigned = False ,
219227 output_unsigned = False ,
220228 )
@@ -230,20 +238,21 @@ def build_rescale_to_int32(
230238 tosa_fb : TosaSerializer ,
231239 input_arg : executorch .backends .arm .tosa_mapping .TosaArg ,
232240 input_zp : int ,
233- rescale_scale : float ,
241+ rescale_scale : list [ float ] ,
234242 is_scale32 : bool = True ,
235243 is_double_round : bool = False ,
244+ per_channel : bool = False ,
236245) -> TosaSerializerTensor :
237- multiplier , shift = compute_multiplier_and_shift (rescale_scale )
246+ multipliers , shifts = compute_multiplier_and_shift (rescale_scale )
238247 attr_rescale = ts .TosaSerializerAttribute ()
239248 attr_rescale .RescaleAttribute (
240249 input_zp = input_zp ,
241250 output_zp = 0 ,
242- multiplier = [ multiplier ] ,
243- shift = [ shift ] ,
251+ multiplier = multipliers ,
252+ shift = shifts ,
244253 scale32 = is_scale32 ,
245254 double_round = is_double_round ,
246- per_channel = False ,
255+ per_channel = per_channel ,
247256 input_unsigned = False ,
248257 output_unsigned = False ,
249258 )
@@ -263,20 +272,21 @@ def build_rescale_from_int32(
263272 input_name : str ,
264273 output_name : str ,
265274 output_zp : int ,
266- rescale_scale : float ,
275+ rescale_scale : list [ float ] ,
267276 is_scale32 : bool = True ,
268277 is_double_round : bool = False ,
278+ per_channel : bool = False ,
269279) -> None :
270- multiplier , shift = compute_multiplier_and_shift (rescale_scale )
280+ multipliers , shifts = compute_multiplier_and_shift (rescale_scale )
271281 attr_rescale_output = ts .TosaSerializerAttribute ()
272282 attr_rescale_output .RescaleAttribute (
273283 input_zp = 0 ,
274284 output_zp = output_zp ,
275- multiplier = [ multiplier ] ,
276- shift = [ shift ] ,
285+ multiplier = multipliers ,
286+ shift = shifts ,
277287 scale32 = is_scale32 ,
278288 double_round = is_double_round ,
279- per_channel = False ,
289+ per_channel = per_channel ,
280290 input_unsigned = False ,
281291 output_unsigned = False ,
282292 )
@@ -296,13 +306,15 @@ def build_rescale_conv_output(
296306 op : TosaSerializerTensor ,
297307 output_name : str ,
298308 output_type : ts .DType ,
299- input_scale : float ,
300- weight_scale : float ,
301- output_scale : float ,
309+ input_scale : list [ float ] ,
310+ weight_scale : list [ float ] ,
311+ output_scale : list [ float ] ,
302312 output_zp : int ,
303313):
304314 # TODO add check to verify if this is a Per-channel quantization.
305- post_conv2d_scale = (input_scale * weight_scale ) / output_scale
315+ post_conv2d_scale = [
316+ (inp * w ) / out for inp , w , out in zip (input_scale , weight_scale , output_scale )
317+ ]
306318
307319 # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0.
308320 build_rescale (
@@ -314,5 +326,7 @@ def build_rescale_conv_output(
314326 op .shape ,
315327 0 ,
316328 output_zp ,
329+ False ,
330+ isinstance (weight_scale , torch .Tensor ),
317331 )
318332 return
0 commit comments