Commit 8f3578a

Update on "Arm backend: Add 16A8W support and test for mul operation"
Add 16A8W quantization support and test for the mul operation in the ExecuTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to mul operations.

Changes:
- Add INT16 dtype validation support in op_mul.py
- Add test_mul_tensor_16a8w_tosa_INT test function
- Enable test_mul.py in test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency.

Differential Revision: [D80510628](https://our.internmc.facebook.com/intern/diff/D80510628/)

cc digantdesai freddan80 per zingo oscarandersson8218

[ghstack-poisoned]
2 parents 337c463 + 91d4363 commit 8f3578a

File tree

1 file changed: +9 -9 lines changed

backends/arm/tosa/quant_utils.py

Lines changed: 9 additions & 9 deletions
@@ -5,7 +5,7 @@

 # pyre-unsafe

-# Utiliy functions for TOSA quantized lowerings
+# Utility functions for TOSA quantized lowerings

 import math

@@ -29,11 +29,11 @@ def insert_rescale_ops_to_int32_maxscale(
 tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
 ) -> tuple[list[Any], float]:
 """For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale))
-compared to all the other cases. We also multply the left and right scales by 1<<20 giving us extra precision
+compared to all the other cases. We also multiply the left and right scales by 1<<20 giving us extra precision
 for the computation without overflowing.

 Returns a list of the rescaled nodes and the scale factor used,
-needed by rescale_node_back_to_int8.
+needed by insert_rescale_op_to_int8.
 """

 if len(inputs) > 2:
@@ -88,7 +88,7 @@ def insert_rescale_ops_to_int32(
 The scales are adjusted using the smallest scale of all 'nodes'.

 Returns a list of the rescaled nodes and the scale factor used,
-needed by rescale_node_back_to_int8.
+needed by insert_rescale_op_to_int8.

 This functions is used in serialization to TOSA for target ops that are
 handled by the DQ/D folding pass, which stores the quantization parameters
@@ -136,7 +136,7 @@ def insert_rescale_op_to_int8(
 Parameters:
 node: The original node that is being handled by the rescales.
 last_tensor:the tosa tensor to rescale back.
-scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
+scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
 compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
 tosa_graph: the tosa_graph to manipulate.

@@ -161,7 +161,7 @@ def insert_rescale_op_to_int16(
 Parameters:
 node: The original node that is being handled by the rescales.
 last_tensor:the tosa tensor to rescale back.
-scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
+scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
 compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
 tosa_graph: the tosa_graph to manipulate.

@@ -187,7 +187,7 @@ def _insert_rescale_op_to_dtype(
 Parameters:
 node: The original node that is being handled by the rescales.
 last_tensor:the tosa tensor to rescale back.
-scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
+scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
 output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16)
 compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
 tosa_graph: the tosa_graph to manipulate.
@@ -224,7 +224,7 @@ def _insert_rescale_op_to_dtype(

 # TOSA uses the RESCALE operation to scale between values with differing precision.
 # The RESCALE operator is defined using an integer multiply, add, and shift.
-# This utility function is for calculating the multier and shift given a scale.
+# This utility function is for calculating the multiplier and shift given a scale.
 # Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
 def compute_multiplier_and_shift(
 scales: list[float], scaleWidth: int = 32
@@ -269,7 +269,7 @@ def compute_multiplier_and_shift(
 return multipliers, shifts


-# For TOSA spec v1.0 RESCALE operator requires multipler, shifts, input_zp and output_zp to be
+# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be
 # const inputs. Create constant operators from the data already initialized.
 def create_const_ops_for_rescale(
 tosa_fb,
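
Note on the RESCALE comment touched above: the diff only shows the signature of compute_multiplier_and_shift, not its body, so the following is a minimal sketch of the general idea (approximating a float scale as an integer multiplier and a right shift), not the code in quant_utils.py. The names sketch_multiplier_and_shift and apply_rescale are hypothetical, and the 2..62 shift bound is my reading of the TOSA precision-scaling section, so treat it as an assumption.

```python
import math


def sketch_multiplier_and_shift(scale: float, scale_width: int = 32) -> tuple[int, int]:
    """Hypothetical helper: approximate `scale` as multiplier * 2**-shift."""
    if scale <= 0:
        raise ValueError("scale must be positive")
    offset = scale_width - 1  # 31 fractional bits for a 32-bit multiplier
    # frexp gives scale == mantissa * 2**exponent with 0.5 <= mantissa < 1.
    mantissa, exponent = math.frexp(scale)
    multiplier = round(mantissa * (1 << offset))
    # Rounding can push the multiplier up to 2**offset; renormalize if so.
    if multiplier == (1 << offset):
        multiplier //= 2
        exponent += 1
    shift = offset - exponent
    # Assumed bound on the shift; scales outside it are rejected in this sketch.
    if not 2 <= shift <= 62:
        raise ValueError(f"shift {shift} is outside the assumed 2..62 range")
    return multiplier, shift


def apply_rescale(value: int, multiplier: int, shift: int) -> int:
    """Integer multiply, add (the rounding term), and arithmetic right shift."""
    return (value * multiplier + (1 << (shift - 1))) >> shift
```

With this sketch, sketch_multiplier_and_shift(0.5) yields a multiplier of 1 << 30 and a shift of 31, and apply_rescale(1000, *sketch_multiplier_and_shift(0.3)) evaluates to 300. Python integers are arbitrary precision, so the overflow handling a real lowering needs is deliberately left out here.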
