5
5
6
6
# pyre-unsafe
7
7
8
- # Utiliy functions for TOSA quantized lowerings
8
+ # Utility functions for TOSA quantized lowerings
9
9
10
10
import math
11
11
@@ -27,11 +27,11 @@ def insert_rescale_ops_to_int32_maxscale(
27
27
tosa_graph : Any , inputs : list [TosaArg ], node : Node , tosa_spec = None
28
28
) -> tuple [list [Any ], float ]:
29
29
"""For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale))
30
- compared to all the other cases. We also multply the left and right scales by 1<<20 giving us extra precision
30
+ compared to all the other cases. We also multiply the left and right scales by 1<<20 giving us extra precision
31
31
for the computation without overflowing.
32
32
33
33
Returns a list of the rescaled nodes and the scale factor used,
34
- needed by rescale_node_back_to_int8 .
34
+ needed by insert_rescale_op_to_int8 .
35
35
"""
36
36
37
37
if len (inputs ) > 2 :
@@ -86,7 +86,7 @@ def insert_rescale_ops_to_int32(
86
86
The scales are adjusted using the smallest scale of all 'nodes'.
87
87
88
88
Returns a list of the rescaled nodes and the scale factor used,
89
- needed by rescale_node_back_to_int8 .
89
+ needed by insert_rescale_op_to_int8 .
90
90
91
91
This function is used in serialization to TOSA for target ops that are
92
92
handled by the DQ/D folding pass, which stores the quantization parameters
@@ -134,7 +134,59 @@ def insert_rescale_op_to_int8(
134
134
Parameters:
135
135
node: The original node that is being handled by the rescales.
136
136
last_tensor:the tosa tensor to rescale back.
137
- scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
137
+ scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
138
+ compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
139
+ tosa_graph: the tosa_graph to manipulate.
140
+
141
+ This function is used in serialization to TOSA for target ops that are
142
+ handled by the DQ/D folding pass, which stores the quantization parameters
143
+ in the node meta dict.
144
+ """
145
+ _insert_rescale_op_to_dtype (
146
+ tosa_graph , last_tensor , scale , node , ts .DType .INT8 , compute_rescale , tosa_spec
147
+ )
148
+
149
+
150
def insert_rescale_op_to_int16(
    tosa_graph: Any,
    last_tensor: TosaArg,
    scale: float,
    node: Node,
    compute_rescale=True,
    tosa_spec=None,
) -> None:
    """Add a RESCALE op to 'tosa_graph' that brings 'last_tensor' back to int16.

    Thin wrapper: every argument is forwarded unchanged to
    _insert_rescale_op_to_dtype, with the output dtype fixed to ts.DType.INT16.

    Parameters:
        tosa_graph: the tosa_graph to manipulate.
        last_tensor: the tosa tensor to rescale back.
        scale: the scaling factor used to rescale to int32, from the function
            'insert_rescale_ops_to_int32'.
        node: the original node that is being handled by the rescales.
        compute_rescale: boolean indicating whether we need to divide the
            output scale by the original scale.
        tosa_spec: forwarded as-is to _insert_rescale_op_to_dtype.

    This function is used in serialization to TOSA for target ops that are
    handled by the DQ/D folding pass, which stores the quantization parameters
    in the node meta dict.
    """
    # Keyword forwarding keeps the mapping onto the shared helper explicit;
    # only the target dtype distinguishes this wrapper from the int8 variant.
    _insert_rescale_op_to_dtype(
        tosa_graph=tosa_graph,
        last_tensor=last_tensor,
        scale=scale,
        node=node,
        output_dtype=ts.DType.INT16,
        compute_rescale=compute_rescale,
        tosa_spec=tosa_spec,
    )
173
+
174
+
175
+ def _insert_rescale_op_to_dtype (
176
+ tosa_graph : Any ,
177
+ last_tensor : TosaArg ,
178
+ scale : float ,
179
+ node : Node ,
180
+ output_dtype : Any ,
181
+ compute_rescale = True ,
182
+ tosa_spec = None ,
183
+ ) -> None :
184
+ """Common implementation for rescaling nodes back to a specific dtype.
185
+ Parameters:
186
+ node: The original node that is being handled by the rescales.
187
+ last_tensor:the tosa tensor to rescale back.
188
+ scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
189
+ output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16)
138
190
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
139
191
tosa_graph: the tosa_graph to manipulate.
140
192
@@ -156,20 +208,21 @@ def insert_rescale_op_to_int8(
156
208
else :
157
209
output_rescale_scale = scale
158
210
159
- # Rescale Back to INT8
160
- build_rescale_from_int32 (
211
+ # Rescale Back to the specified dtype
212
+ build_rescale_from_int32_to_dtype (
161
213
tosa_graph ,
162
214
last_tensor ,
163
215
node .name ,
164
216
qargs_out .get_zp_per_tensor (),
165
217
output_rescale_scale ,
218
+ output_dtype ,
166
219
tosa_spec = tosa_spec ,
167
220
)
168
221
169
222
170
223
# TOSA uses the RESCALE operation to scale between values with differing precision.
171
224
# The RESCALE operator is defined using an integer multiply, add, and shift.
172
- # This utility function is for calculating the multier and shift given a scale.
225
+ # This utility function is for calculating the multiplier and shift given a scale.
173
226
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
174
227
def compute_multiplier_and_shift (
175
228
scales : list [float ], scaleWidth : int = 32
@@ -214,7 +267,7 @@ def compute_multiplier_and_shift(
214
267
return multipliers , shifts
215
268
216
269
217
- # For TOSA spec v1.0 RESCALE operator requires multipler , shifts, input_zp and output_zp to be
270
+ # For TOSA spec v1.0 RESCALE operator requires multiplier , shifts, input_zp and output_zp to be
218
271
# const inputs. Create constant operators from the data already initialized.
219
272
def create_const_ops_for_rescale (
220
273
tosa_fb ,
@@ -335,14 +388,55 @@ def build_rescale_from_int32(
335
388
per_channel : bool = False ,
336
389
tosa_spec = None ,
337
390
) -> None :
391
+ # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
392
+ # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
393
+ build_rescale_from_int32_to_dtype (
394
+ tosa_fb ,
395
+ input_node ,
396
+ output_name ,
397
+ output_zp ,
398
+ rescale_scale ,
399
+ ts .DType .INT8 ,
400
+ is_scale32 ,
401
+ is_double_round ,
402
+ per_channel ,
403
+ tosa_spec ,
404
+ )
405
+
406
+ return
407
+
408
+
409
+ def build_rescale_from_int32_to_dtype (
410
+ tosa_fb : Any ,
411
+ input_node : TosaArg ,
412
+ output_name : str ,
413
+ output_zp : int ,
414
+ rescale_scale : float ,
415
+ output_dtype : Any ,
416
+ is_scale32 : bool = True ,
417
+ is_double_round : bool = False ,
418
+ per_channel : bool = False ,
419
+ tosa_spec = None ,
420
+ ) -> None :
421
+ """Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16).
422
+
423
+ Parameters:
424
+ tosa_fb: The TOSA serializer
425
+ input_node: Input tensor (should be INT32)
426
+ output_name: Name for the output tensor
427
+ output_zp: Output zero point
428
+ rescale_scale: Rescaling factor
429
+ output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16)
430
+ Other parameters: Standard rescale parameters
431
+ """
338
432
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
339
433
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
340
434
build_rescale (
341
435
tosa_fb ,
342
436
[rescale_scale ],
343
437
input_node ,
344
438
output_name = output_name ,
345
- output_type = ts . DType . INT8 ,
439
+ output_type = output_dtype ,
346
440
input_zp = [0 ],
347
441
output_zp = [output_zp ],
348
442
rounding_mode = RoundingMode .SINGLE_ROUND ,
0 commit comments