Skip to content

Commit 3605902

Browse files
committed
Add 16A8W support and test for mul operation
Pull Request resolved: #13795 Add 16A8W quantization support and test for the mul operation in ExecutorTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to mul operations. Changes: - Add INT16 dtype validation support in op_mul.py - Add test_mul_tensor_16a8w_tosa_INT test function - Enable test_mul.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. ghstack-source-id: 308058745 @exported-using-ghexport @bypass-github-pytorch-ci-checks @bypass-github-executorch-ci-checks Differential Revision: [D80510628](https://our.internmc.facebook.com/intern/diff/D80510628/)
1 parent 3038914 commit 3605902

File tree

4 files changed

+230
-16
lines changed

4 files changed

+230
-16
lines changed

backends/arm/operators/op_mul.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class MulVisitor_INT(NodeVisitor):
3434

3535
tosa_specs = [
3636
TosaSpecification.create_from_string("TOSA-1.0+INT"),
37+
TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
3738
]
3839

3940
def define_node(
@@ -51,11 +52,11 @@ def define_node(
5152
validate_valid_dtype(
5253
self.target,
5354
[*inputs, output],
54-
[ts.DType.INT8, ts.DType.INT32],
55+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
5556
output.tosa_spec,
5657
)
5758

58-
if inputs[0].dtype == ts.DType.INT8:
59+
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
5960
input_A = inputs[0]
6061
input_B = inputs[1]
6162
input_qparams = get_input_qparams(node)
@@ -80,15 +81,15 @@ def define_node(
8081
tosa_spec=self.tosa_spec,
8182
)
8283
else:
83-
# input[0].dtype == ts.DType.INT32
84+
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
8485
# Non quantized input, natively support by TOSA.MUL
8586
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]
8687

87-
if output.dtype == ts.DType.INT8:
88+
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
8889
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
8990
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
9091
else:
91-
# output.dtype == ts.DType.INT32
92+
# output.dtype == ts.DType.INT32 (non-quantized)
9293
mul_output = output
9394

9495
# Do the INT32 Mul
@@ -110,6 +111,15 @@ def define_node(
110111
tqutils.insert_rescale_op_to_int8(
111112
tosa_graph, mul_output, output_scale, node, self.tosa_spec
112113
)
114+
elif output.dtype == ts.DType.INT16:
115+
# Scale output back to 16 bit
116+
output_scale = (
117+
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
118+
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
119+
)
120+
tqutils.insert_rescale_op_to_int16(
121+
tosa_graph, mul_output, output_scale, node, self.tosa_spec
122+
)
113123

114124

115125
@register_node_visitor

backends/arm/test/ops/test_mul.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,23 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
1217

13-
from executorch.backends.arm.test import common
18+
from executorch.backends.arm.test import common, conftest
1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
1621
EthosU85PipelineINT,
1722
TosaPipelineFP,
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
26+
from executorch.backends.arm.tosa_specification import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2128

2229
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x
2330
aten_op = "torch.ops.aten.mul.Tensor"
@@ -284,3 +291,105 @@ def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor):
284291
)
285292
pipeline.pop_stage("check.quant_nodes")
286293
pipeline.run()
294+
295+
296+
def get_symmetric_a16w8_mul_quantizer(per_channel_quantization=False):
297+
tosa_version = conftest.get_option("tosa_version")
298+
tosa_profiles = {
299+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
300+
}
301+
302+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
303+
quantizer.set_global(
304+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
305+
)
306+
307+
return Quantize(
308+
quantizer,
309+
get_symmetric_a16w8_quantization_config(
310+
is_per_channel=per_channel_quantization
311+
),
312+
)
313+
314+
315+
@common.parametrize("test_data", test_data_suite)
316+
@pytest.mark.xfail(
317+
reason="missing int16 mul ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13947"
318+
)
319+
def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1):
320+
"""Test mul operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
321+
per_channel_quantization = False
322+
323+
pipeline = TosaPipelineINT[input_t1](
324+
Mul(),
325+
test_data(),
326+
aten_op,
327+
exir_op=[],
328+
per_channel_quantization=per_channel_quantization,
329+
use_to_edge_transform_and_lower=True,
330+
tosa_extensions=["int16"],
331+
)
332+
333+
pipeline.change_args(
334+
"quantize",
335+
get_symmetric_a16w8_mul_quantizer(
336+
per_channel_quantization=per_channel_quantization
337+
),
338+
)
339+
pipeline.run()
340+
341+
342+
@common.parametrize("test_data", test_data_suite)
343+
@common.XfailIfNoCorstone300
344+
@pytest.mark.xfail(
345+
reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
346+
)
347+
def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1):
348+
"""Test mul operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
349+
per_channel_quantization = False
350+
351+
pipeline = EthosU55PipelineINT[input_t1](
352+
Mul(),
353+
test_data(),
354+
aten_op,
355+
exir_ops=[],
356+
per_channel_quantization=per_channel_quantization,
357+
use_to_edge_transform_and_lower=True,
358+
run_on_fvp=True,
359+
)
360+
361+
pipeline.change_args(
362+
"quantize",
363+
get_symmetric_a16w8_mul_quantizer(
364+
per_channel_quantization=per_channel_quantization
365+
),
366+
)
367+
pipeline.run()
368+
369+
370+
@common.parametrize("test_data", test_data_suite)
371+
@common.XfailIfNoCorstone320
372+
@pytest.mark.xfail(
373+
reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
374+
)
375+
def test_mul_tensor_16a8w_u85_INT16(test_data: input_t1):
376+
"""Test mul operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
377+
per_channel_quantization = False
378+
379+
pipeline = EthosU85PipelineINT[input_t1](
380+
Mul(),
381+
test_data(),
382+
aten_op,
383+
exir_ops=[],
384+
per_channel_quantization=per_channel_quantization,
385+
use_to_edge_transform_and_lower=True,
386+
run_on_fvp=True,
387+
)
388+
389+
pipeline.change_args(
390+
"quantize",
391+
get_symmetric_a16w8_mul_quantizer(
392+
per_channel_quantization=per_channel_quantization
393+
),
394+
)
395+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def define_arm_tests():
1616
"ops/test_add.py",
1717
"ops/test_avg_pool2d.py",
1818
"ops/test_linear.py",
19+
"ops/test_mul.py",
1920
"ops/test_slice.py",
2021
"ops/test_sigmoid.py",
2122
"ops/test_tanh.py",

backends/arm/tosa/quant_utils.py

Lines changed: 104 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# pyre-unsafe
77

8-
# Utiliy functions for TOSA quantized lowerings
8+
# Utility functions for TOSA quantized lowerings
99

1010
import math
1111

@@ -29,11 +29,11 @@ def insert_rescale_ops_to_int32_maxscale(
2929
tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
3030
) -> tuple[list[Any], float]:
3131
"""For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale))
32-
compared to all the other cases. We also multply the left and right scales by 1<<20 giving us extra precision
32+
compared to all the other cases. We also multiply the left and right scales by 1<<20 giving us extra precision
3333
for the computation without overflowing.
3434
3535
Returns a list of the rescaled nodes and the scale factor used,
36-
needed by rescale_node_back_to_int8.
36+
needed by insert_rescale_op_to_int8.
3737
"""
3838

3939
if len(inputs) > 2:
@@ -88,7 +88,7 @@ def insert_rescale_ops_to_int32(
8888
The scales are adjusted using the smallest scale of all 'nodes'.
8989
9090
Returns a list of the rescaled nodes and the scale factor used,
91-
needed by rescale_node_back_to_int8.
91+
needed by insert_rescale_op_to_int8.
9292
9393
This functions is used in serialization to TOSA for target ops that are
9494
handled by the DQ/D folding pass, which stores the quantization parameters
@@ -136,7 +136,59 @@ def insert_rescale_op_to_int8(
136136
Parameters:
137137
node: The original node that is being handled by the rescales.
138138
last_tensor:the tosa tensor to rescale back.
139-
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
139+
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
140+
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
141+
tosa_graph: the tosa_graph to manipulate.
142+
143+
This functions is used in serialization to TOSA for target ops that are
144+
handled by the DQ/D folding pass, which stores the quantization parameters
145+
in the node meta dict.
146+
"""
147+
_insert_rescale_op_to_dtype(
148+
tosa_graph, last_tensor, scale, node, ts.DType.INT8, compute_rescale, tosa_spec
149+
)
150+
151+
152+
def insert_rescale_op_to_int16(
153+
tosa_graph: Any,
154+
last_tensor: TosaArg,
155+
scale: float,
156+
node: Node,
157+
compute_rescale=True,
158+
tosa_spec=None,
159+
) -> None:
160+
"""Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'.
161+
Parameters:
162+
node: The original node that is being handled by the rescales.
163+
last_tensor:the tosa tensor to rescale back.
164+
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
165+
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
166+
tosa_graph: the tosa_graph to manipulate.
167+
168+
This functions is used in serialization to TOSA for target ops that are
169+
handled by the DQ/D folding pass, which stores the quantization parameters
170+
in the node meta dict.
171+
"""
172+
_insert_rescale_op_to_dtype(
173+
tosa_graph, last_tensor, scale, node, ts.DType.INT16, compute_rescale, tosa_spec
174+
)
175+
176+
177+
def _insert_rescale_op_to_dtype(
178+
tosa_graph: Any,
179+
last_tensor: TosaArg,
180+
scale: float,
181+
node: Node,
182+
output_dtype: Any,
183+
compute_rescale=True,
184+
tosa_spec=None,
185+
) -> None:
186+
"""Common implementation for rescaling nodes back to a specific dtype.
187+
Parameters:
188+
node: The original node that is being handled by the rescales.
189+
last_tensor:the tosa tensor to rescale back.
190+
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
191+
output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16)
140192
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
141193
tosa_graph: the tosa_graph to manipulate.
142194
@@ -158,20 +210,21 @@ def insert_rescale_op_to_int8(
158210
else:
159211
output_rescale_scale = scale
160212

161-
# Rescale Back to INT8
162-
build_rescale_from_int32(
213+
# Rescale Back to the specified dtype
214+
build_rescale_from_int32_to_dtype(
163215
tosa_graph,
164216
last_tensor,
165217
node.name,
166218
qargs_out.get_zp_per_tensor(),
167219
output_rescale_scale,
220+
output_dtype,
168221
tosa_spec=tosa_spec,
169222
)
170223

171224

172225
# TOSA uses the RESCALE operation to scale between values with differing precision.
173226
# The RESCALE operator is defined using an integer multiply, add, and shift.
174-
# This utility function is for calculating the multier and shift given a scale.
227+
# This utility function is for calculating the multiplier and shift given a scale.
175228
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
176229
def compute_multiplier_and_shift(
177230
scales: list[float], scaleWidth: int = 32
@@ -216,7 +269,7 @@ def compute_multiplier_and_shift(
216269
return multipliers, shifts
217270

218271

219-
# For TOSA spec v1.0 RESCALE operator requires multipler, shifts, input_zp and output_zp to be
272+
# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be
220273
# const inputs. Create constant operators from the data already initialized.
221274
def create_const_ops_for_rescale(
222275
tosa_fb,
@@ -337,14 +390,55 @@ def build_rescale_from_int32(
337390
per_channel: bool = False,
338391
tosa_spec=None,
339392
) -> None:
393+
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
394+
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
395+
build_rescale_from_int32_to_dtype(
396+
tosa_fb,
397+
input_node,
398+
output_name,
399+
output_zp,
400+
rescale_scale,
401+
ts.DType.INT8,
402+
is_scale32,
403+
is_double_round,
404+
per_channel,
405+
tosa_spec,
406+
)
407+
408+
return
409+
410+
411+
def build_rescale_from_int32_to_dtype(
412+
tosa_fb: Any,
413+
input_node: TosaArg,
414+
output_name: str,
415+
output_zp: int,
416+
rescale_scale: float,
417+
output_dtype: Any,
418+
is_scale32: bool = True,
419+
is_double_round: bool = False,
420+
per_channel: bool = False,
421+
tosa_spec=None,
422+
) -> None:
423+
"""Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16).
424+
425+
Parameters:
426+
tosa_fb: The TOSA serializer
427+
input_node: Input tensor (should be INT32)
428+
output_name: Name for the output tensor
429+
output_zp: Output zero point
430+
rescale_scale: Rescaling factor
431+
output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16)
432+
Other parameters: Standard rescale parameters
433+
"""
340434
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
341435
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
342436
build_rescale(
343437
tosa_fb,
344438
[rescale_scale],
345439
input_node,
346440
output_name=output_name,
347-
output_type=ts.DType.INT8,
441+
output_type=output_dtype,
348442
input_zp=[0],
349443
output_zp=[output_zp],
350444
rounding_mode=RoundingMode.SINGLE_ROUND,

0 commit comments

Comments
 (0)