Skip to content

Commit 222f96f

Browse files
committed
Update on "Arm backend: Add 16A8W support and test for mul operation"
Add 16A8W quantization support and test for the mul operation in ExecutorTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to mul operations. Changes: - Add INT16 dtype validation support in op_mul.py - Add test_mul_tensor_16a8w_tosa_INT test function - Enable test_mul.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. Differential Revision: [D80510628](https://our.internmc.facebook.com/intern/diff/D80510628/) cc digantdesai freddan80 per zingo oscarandersson8218 [ghstack-poisoned]
2 parents e1268e0 + 346cd5d commit 222f96f

File tree

3 files changed

+111
-7
lines changed

3 files changed

+111
-7
lines changed

backends/arm/operators/op_mul.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class MulVisitor_INT(NodeVisitor):
3434

3535
tosa_specs = [
3636
TosaSpecification.create_from_string("TOSA-1.0+INT"),
37+
TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
3738
]
3839

3940
def define_node(
@@ -55,7 +56,7 @@ def define_node(
5556
output.tosa_spec,
5657
)
5758

58-
if inputs[0].dtype == ts.DType.INT8:
59+
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
5960
input_A = inputs[0]
6061
input_B = inputs[1]
6162
input_qparams = get_input_qparams(node)
@@ -84,11 +85,11 @@ def define_node(
8485
# Non quantized input, natively supported by TOSA.MUL
8586
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]
8687

87-
if output.dtype == ts.DType.INT8:
88+
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
8889
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
8990
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
9091
else:
91-
# output.dtype == ts.DType.INT16 or ts.DType.INT32
92+
# output.dtype == ts.DType.INT32 (non-quantized)
9293
mul_output = output
9394

9495
# Do the INT32 Mul
@@ -110,6 +111,15 @@ def define_node(
110111
tqutils.insert_rescale_op_to_int8(
111112
tosa_graph, mul_output, output_scale, node, self.tosa_spec
112113
)
114+
elif output.dtype == ts.DType.INT16:
115+
# Scale output back to 16 bit
116+
output_scale = (
117+
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
118+
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
119+
)
120+
tqutils.insert_rescale_op_to_int16(
121+
tosa_graph, mul_output, output_scale, node, self.tosa_spec
122+
)
113123

114124

115125
@register_node_visitor

backends/arm/test/ops/test_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def get_symmetric_a16w8_add_quantizer(per_channel_quantization=False):
262262

263263
@common.parametrize("test_data", Add.test_data)
264264
@pytest.mark.xfail(
265-
reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13969"
265+
reason="missing int16 add ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13730"
266266
)
267267
def test_add_tensor_16a8w_tosa_INT(test_data: input_t1):
268268
"""Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)"""

backends/arm/tosa/quant_utils.py

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,58 @@ def insert_rescale_op_to_int8(
140140
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
141141
tosa_graph: the tosa_graph to manipulate.
142142
143+
This function is used in serialization to TOSA for target ops that are
144+
handled by the DQ/D folding pass, which stores the quantization parameters
145+
in the node meta dict.
146+
"""
147+
_insert_rescale_op_to_dtype(
148+
tosa_graph, last_tensor, scale, node, ts.DType.INT8, compute_rescale, tosa_spec
149+
)
150+
151+
152+
def insert_rescale_op_to_int16(
153+
tosa_graph: Any,
154+
last_tensor: TosaArg,
155+
scale: float,
156+
node: Node,
157+
compute_rescale=True,
158+
tosa_spec=None,
159+
) -> None:
160+
"""Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'.
161+
Parameters:
162+
node: The original node that is being handled by the rescales.
163+
last_tensor: the tosa tensor to rescale back.
164+
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
165+
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
166+
tosa_graph: the tosa_graph to manipulate.
167+
168+
This function is used in serialization to TOSA for target ops that are
169+
handled by the DQ/D folding pass, which stores the quantization parameters
170+
in the node meta dict.
171+
"""
172+
_insert_rescale_op_to_dtype(
173+
tosa_graph, last_tensor, scale, node, ts.DType.INT16, compute_rescale, tosa_spec
174+
)
175+
176+
177+
def _insert_rescale_op_to_dtype(
178+
tosa_graph: Any,
179+
last_tensor: TosaArg,
180+
scale: float,
181+
node: Node,
182+
output_dtype: Any,
183+
compute_rescale=True,
184+
tosa_spec=None,
185+
) -> None:
186+
"""Common implementation for rescaling nodes back to a specific dtype.
187+
Parameters:
188+
node: The original node that is being handled by the rescales.
189+
last_tensor: the tosa tensor to rescale back.
190+
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
191+
output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16)
192+
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
193+
tosa_graph: the tosa_graph to manipulate.
194+
143195
This function is used in serialization to TOSA for target ops that are
144196
handled by the DQ/D folding pass, which stores the quantization parameters
145197
in the node meta dict.
@@ -158,13 +210,14 @@ def insert_rescale_op_to_int8(
158210
else:
159211
output_rescale_scale = scale
160212

161-
# Rescale Back to INT8
162-
build_rescale_from_int32(
213+
# Rescale Back to the specified dtype
214+
build_rescale_from_int32_to_dtype(
163215
tosa_graph,
164216
last_tensor,
165217
node.name,
166218
qargs_out.get_zp_per_tensor(),
167219
output_rescale_scale,
220+
output_dtype,
168221
tosa_spec=tosa_spec,
169222
)
170223

@@ -337,14 +390,55 @@ def build_rescale_from_int32(
337390
per_channel: bool = False,
338391
tosa_spec=None,
339392
) -> None:
393+
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
394+
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
395+
build_rescale_from_int32_to_dtype(
396+
tosa_fb,
397+
input_node,
398+
output_name,
399+
output_zp,
400+
rescale_scale,
401+
ts.DType.INT8,
402+
is_scale32,
403+
is_double_round,
404+
per_channel,
405+
tosa_spec,
406+
)
407+
408+
return
409+
410+
411+
def build_rescale_from_int32_to_dtype(
412+
tosa_fb: Any,
413+
input_node: TosaArg,
414+
output_name: str,
415+
output_zp: int,
416+
rescale_scale: float,
417+
output_dtype: Any,
418+
is_scale32: bool = True,
419+
is_double_round: bool = False,
420+
per_channel: bool = False,
421+
tosa_spec=None,
422+
) -> None:
423+
"""Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16).
424+
425+
Parameters:
426+
tosa_fb: The TOSA serializer
427+
input_node: Input tensor (should be INT32)
428+
output_name: Name for the output tensor
429+
output_zp: Output zero point
430+
rescale_scale: Rescaling factor
431+
output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16)
432+
Other parameters: Standard rescale parameters
433+
"""
340434
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
341435
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
342436
build_rescale(
343437
tosa_fb,
344438
[rescale_scale],
345439
input_node,
346440
output_name=output_name,
347-
output_type=ts.DType.INT8,
441+
output_type=output_dtype,
348442
input_zp=[0],
349443
output_zp=[output_zp],
350444
rounding_mode=RoundingMode.SINGLE_ROUND,

0 commit comments

Comments
 (0)