Commit 0fc2dae
Add 16A8W support and test for mul operation
Pull Request resolved: #13795

Add 16A8W quantization support and a test for the mul operation in the ExecutorTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to mul operations.

Changes:
- Add INT16 dtype validation support in op_mul.py
- Add the test_mul_tensor_16a8w_tosa_INT test function
- Enable test_mul.py in the test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency.

ghstack-source-id: 308046859
exported-using-ghexport

Differential Revision: [D80510628](https://our.internmc.facebook.com/intern/diff/D80510628/)
1 parent 7a0f900 commit 0fc2dae
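As context for the 16A8W trade-off described in the message above: keeping weights at 8 bits leaves model size unchanged, while 16-bit activations give a roughly 256x finer quantization grid over the same activation range. A self-contained sketch of that effect (illustrative NumPy only, not ExecuTorch code; the range and values are made up):

```python
import numpy as np

# Compare round-trip quantization error for the same activation range
# at 8-bit vs 16-bit symmetric quantization.
x = np.linspace(-4.0, 4.0, 1001)

def quantize_roundtrip(x, n_bits):
    qmax = 2 ** (n_bits - 1) - 1               # 127 for int8, 32767 for int16
    scale = np.abs(x).max() / qmax
    q = np.clip(np.round(x / scale), -qmax - 1, qmax)
    return q * scale                           # dequantize

err8 = np.abs(x - quantize_roundtrip(x, 8)).max()
err16 = np.abs(x - quantize_roundtrip(x, 16)).max()
print(f"max abs error int8: {err8:.2e}, int16: {err16:.2e}")  # ~256x smaller
```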

File tree

4 files changed: +223 -9 lines changed

backends/arm/operators/op_mul.py
backends/arm/test/ops/test_mul.py
backends/arm/test/targets.bzl
backends/arm/tosa/quant_utils.py

backends/arm/operators/op_mul.py

Lines changed: 15 additions & 5 deletions
@@ -34,6 +34,7 @@ class MulVisitor_INT(NodeVisitor):

     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
     ]

     def define_node(
@@ -51,11 +52,11 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT8, ts.DType.INT32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
             output.tosa_spec,
         )

-        if inputs[0].dtype == ts.DType.INT8:
+        if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
             input_A = inputs[0]
             input_B = inputs[1]
             input_qparams = get_input_qparams(node)
@@ -80,15 +81,15 @@ def define_node(
                 tosa_spec=self.tosa_spec,
             )
         else:
-            # input[0].dtype == ts.DType.INT32
+            # input[0].dtype == ts.DType.INT16 or ts.DType.INT32
             # Non quantized input, natively support by TOSA.MUL
             input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]

-        if output.dtype == ts.DType.INT8:
+        if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
             output_shape = tutils.tosa_shape(output.shape, output.dim_order)
             mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
         else:
-            # output.dtype == ts.DType.INT32
+            # output.dtype == ts.DType.INT32 (non-quantized)
             mul_output = output

         # Do the INT32 Mul
@@ -110,6 +111,15 @@ def define_node(
             tqutils.insert_rescale_op_to_int8(
                 tosa_graph, mul_output, output_scale, node, self.tosa_spec
             )
+        elif output.dtype == ts.DType.INT16:
+            # Scale output back to 16 bit
+            output_scale = (
+                input_A_qargs.get_scale_per_tensor()  # type: ignore[possibly-undefined]
+                * input_B_qargs.get_scale_per_tensor()  # type: ignore[possibly-undefined]
+            )
+            tqutils.insert_rescale_op_to_int16(
+                tosa_graph, mul_output, output_scale, node, self.tosa_spec
+            )


 @register_node_visitor
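For context on the new INT16 branch: as in the INT8 path, the quantized operands are rescaled to INT32, multiplied there, and the product is rescaled back, with the product of the two input scales serving as the combined scale of the INT32 result. A minimal sketch of that arithmetic (plain NumPy with made-up scales; this is not the TOSA serializer path itself):

```python
import numpy as np

# Hypothetical per-tensor scales for two int16-quantized inputs and the output.
scale_a, scale_b, scale_out = 3.0e-4, 5.0e-4, 1.0e-3

a_q = np.array([12000, -8000], dtype=np.int16)   # real values: 3.6, -2.4
b_q = np.array([15000, 20000], dtype=np.int16)   # real values: 7.5, 10.0

# Multiply in int32 so products of int16 values cannot overflow.
prod_i32 = a_q.astype(np.int32) * b_q.astype(np.int32)

# prod_i32 carries the combined scale (scale_a * scale_b) -- the `output_scale`
# computed in define_node above -- so rescaling back to int16 applies
# (scale_a * scale_b) / scale_out, saturating as RESCALE does.
output_scale = scale_a * scale_b
out_q = np.clip(
    np.round(prod_i32 * (output_scale / scale_out)), -32768, 32767
).astype(np.int16)

print(out_q * scale_out)   # ~[27.0, -24.0], i.e. 3.6 * 7.5 and -2.4 * 10.0
```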

backends/arm/test/ops/test_mul.py

Lines changed: 110 additions & 1 deletion
@@ -8,16 +8,23 @@

 from typing import Tuple

+import pytest
 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)

-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
     TosaPipelineFP,
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize

 input_t1 = Tuple[torch.Tensor, torch.Tensor]  # Input x
 aten_op = "torch.ops.aten.mul.Tensor"
@@ -284,3 +291,105 @@ def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor):
     )
     pipeline.pop_stage("check.quant_nodes")
     pipeline.run()
+
+
+def get_symmetric_a16w8_mul_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+@pytest.mark.xfail(
+    reason="missing int16 mul ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13947"
+)
+def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1):
+    """Test mul operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Mul(),
+        test_data(),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_mul_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone300
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
+)
+def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1):
+    """Test mul operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t1](
+        Mul(),
+        test_data(),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_mul_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+@common.XfailIfNoCorstone320
+@pytest.mark.xfail(
+    reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
+)
+def test_mul_tensor_16a8w_u85_INT16(test_data: input_t1):
+    """Test mul operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t1](
+        Mul(),
+        test_data(),
+        aten_op,
+        exir_ops=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        run_on_fvp=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_mul_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
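Assuming the usual pytest setup for the ARM backend suite (the conftest options, such as the TOSA version, are configured per repo), the new TOSA-level test can be run on its own with something like:

    python -m pytest backends/arm/test/ops/test_mul.py -k test_mul_tensor_16a8w_tosa_INT

Note that until the linked issue #13947 is resolved, all three new tests are expected to xfail rather than pass.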

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ def define_arm_tests():
         "ops/test_add.py",
         "ops/test_avg_pool2d.py",
         "ops/test_linear.py",
+        "ops/test_mul.py",
         "ops/test_slice.py",
         "ops/test_sigmoid.py",
         "ops/test_tanh.py",

backends/arm/tosa/quant_utils.py

Lines changed: 97 additions & 3 deletions
@@ -140,6 +140,58 @@ def insert_rescale_op_to_int8(
         compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
         tosa_graph: the tosa_graph to manipulate.
+
+    This functions is used in serialization to TOSA for target ops that are
+    handled by the DQ/D folding pass, which stores the quantization parameters
+    in the node meta dict.
+    """
+    _insert_rescale_op_to_dtype(
+        tosa_graph, last_tensor, scale, node, ts.DType.INT8, compute_rescale, tosa_spec
+    )
+
+
+def insert_rescale_op_to_int16(
+    tosa_graph: Any,
+    last_tensor: TosaArg,
+    scale: float,
+    node: Node,
+    compute_rescale=True,
+    tosa_spec=None,
+) -> None:
+    """Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'.
+    Parameters:
+        node: The original node that is being handled by the rescales.
+        last_tensor:the tosa tensor to rescale back.
+        scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
+        compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
+        tosa_graph: the tosa_graph to manipulate.
+
+    This functions is used in serialization to TOSA for target ops that are
+    handled by the DQ/D folding pass, which stores the quantization parameters
+    in the node meta dict.
+    """
+    _insert_rescale_op_to_dtype(
+        tosa_graph, last_tensor, scale, node, ts.DType.INT16, compute_rescale, tosa_spec
+    )
+
+
+def _insert_rescale_op_to_dtype(
+    tosa_graph: Any,
+    last_tensor: TosaArg,
+    scale: float,
+    node: Node,
+    output_dtype: Any,
+    compute_rescale=True,
+    tosa_spec=None,
+) -> None:
+    """Common implementation for rescaling nodes back to a specific dtype.
+    Parameters:
+        node: The original node that is being handled by the rescales.
+        last_tensor:the tosa tensor to rescale back.
+        scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
+        output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16)
+        compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
+        tosa_graph: the tosa_graph to manipulate.
+
     This functions is used in serialization to TOSA for target ops that are
     handled by the DQ/D folding pass, which stores the quantization parameters
     in the node meta dict.
@@ -158,13 +210,14 @@ def insert_rescale_op_to_int8(
     else:
         output_rescale_scale = scale

-    # Rescale Back to INT8
-    build_rescale_from_int32(
+    # Rescale Back to the specified dtype
+    build_rescale_from_int32_to_dtype(
         tosa_graph,
         last_tensor,
         node.name,
         qargs_out.get_zp_per_tensor(),
         output_rescale_scale,
+        output_dtype,
         tosa_spec=tosa_spec,
     )
@@ -337,14 +390,55 @@ def build_rescale_from_int32(
     per_channel: bool = False,
     tosa_spec=None,
 ) -> None:
+    # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
+    # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
+    build_rescale_from_int32_to_dtype(
+        tosa_fb,
+        input_node,
+        output_name,
+        output_zp,
+        rescale_scale,
+        ts.DType.INT8,
+        is_scale32,
+        is_double_round,
+        per_channel,
+        tosa_spec,
+    )
+
+    return
+
+
+def build_rescale_from_int32_to_dtype(
+    tosa_fb: Any,
+    input_node: TosaArg,
+    output_name: str,
+    output_zp: int,
+    rescale_scale: float,
+    output_dtype: Any,
+    is_scale32: bool = True,
+    is_double_round: bool = False,
+    per_channel: bool = False,
+    tosa_spec=None,
+) -> None:
+    """Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16).
+
+    Parameters:
+        tosa_fb: The TOSA serializer
+        input_node: Input tensor (should be INT32)
+        output_name: Name for the output tensor
+        output_zp: Output zero point
+        rescale_scale: Rescaling factor
+        output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16)
+        Other parameters: Standard rescale parameters
+    """
     # For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
     # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
     build_rescale(
         tosa_fb,
         [rescale_scale],
         input_node,
         output_name=output_name,
-        output_type=ts.DType.INT8,
+        output_type=output_dtype,
         input_zp=[0],
         output_zp=[output_zp],
         rounding_mode=RoundingMode.SINGLE_ROUND,
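A note on what build_rescale ultimately emits: the TOSA RESCALE op (see the spec link in the comments above) does not take a float scale; the factor is expressed as an integer multiplier plus a right shift. A rough sketch of such a decomposition (illustrative only; the backend's actual helper may differ in rounding behavior and legal shift ranges):

```python
import math

def scale_to_multiplier_shift(scale: float, bits: int = 32):
    """Approximate scale as multiplier * 2**-shift, the fixed-point
    form a TOSA-style RESCALE consumes (scale32 variant)."""
    mantissa, exponent = math.frexp(scale)   # scale = mantissa * 2**exponent, mantissa in [0.5, 1)
    multiplier = round(mantissa * (1 << (bits - 1)))
    shift = (bits - 1) - exponent
    if multiplier == 1 << (bits - 1):        # rounding pushed the mantissa up to 1.0
        multiplier >>= 1
        shift -= 1
    return multiplier, shift

# e.g. the combined scale from the int16 mul path: scale_a * scale_b
m, s = scale_to_multiplier_shift(1.5e-7)
print(m, s, m / (1 << s))                    # reconstruction is close to 1.5e-7
```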
