Skip to content

Commit fe9447e

Browse files
pytorchbot and Ninja91 authored
Arm backend: Add 16A8W support and test for mul operation (#14196)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #13795 by @Ninja91
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/Ninja91/8/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/8/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/Ninja91/8/orig
@diff-train-skip-merge

Co-authored-by: Nitin Jain <[email protected]>
1 parent 068d341 commit fe9447e

File tree

4 files changed

+231
-16
lines changed

4 files changed

+231
-16
lines changed

backends/arm/operators/op_mul.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class MulVisitor_INT(NodeVisitor):
3434

3535
tosa_specs = [
3636
TosaSpecification.create_from_string("TOSA-1.0+INT"),
37+
TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
3738
]
3839

3940
def define_node(
@@ -51,11 +52,11 @@ def define_node(
5152
validate_valid_dtype(
5253
self.target,
5354
[*inputs, output],
54-
[ts.DType.INT8, ts.DType.INT32],
55+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
5556
output.tosa_spec,
5657
)
5758

58-
if inputs[0].dtype == ts.DType.INT8:
59+
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
5960
input_A = inputs[0]
6061
input_B = inputs[1]
6162
input_qparams = get_input_qparams(node)
@@ -80,15 +81,15 @@ def define_node(
8081
tosa_spec=self.tosa_spec,
8182
)
8283
else:
83-
# input[0].dtype == ts.DType.INT32
84+
# input[0].dtype == ts.DType.INT32 (INT8/INT16 are handled by the quantized branch above)
8485
# Non quantized input, natively supported by TOSA.MUL
8586
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]
8687

87-
if output.dtype == ts.DType.INT8:
88+
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
8889
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
8990
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
9091
else:
91-
# output.dtype == ts.DType.INT32
92+
# output.dtype == ts.DType.INT32 (non-quantized)
9293
mul_output = output
9394

9495
# Do the INT32 Mul
@@ -110,6 +111,15 @@ def define_node(
110111
tqutils.insert_rescale_op_to_int8(
111112
tosa_graph, mul_output, output_scale, node, self.tosa_spec
112113
)
114+
elif output.dtype == ts.DType.INT16:
115+
# Scale output back to 16 bit
116+
output_scale = (
117+
input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
118+
* input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined]
119+
)
120+
tqutils.insert_rescale_op_to_int16(
121+
tosa_graph, mul_output, output_scale, node, self.tosa_spec
122+
)
113123

114124

115125
@register_node_visitor

backends/arm/test/ops/test_mul.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,24 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
1217

13-
from executorch.backends.arm.test import common
18+
from executorch.backends.arm.test import common, conftest
1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
1621
EthosU85PipelineINT,
1722
TosaPipelineFP,
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
26+
from executorch.backends.arm.tosa.specification import TosaSpecification
27+
28+
from executorch.backends.xnnpack.test.tester import Quantize
2129

2230
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x
2331
aten_op = "torch.ops.aten.mul.Tensor"
@@ -284,3 +292,105 @@ def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor):
284292
)
285293
pipeline.pop_stage("check.quant_nodes")
286294
pipeline.run()
295+
296+
297+
def get_symmetric_a16w8_mul_quantizer(per_channel_quantization=False):
298+
tosa_version = conftest.get_option("tosa_version")
299+
tosa_profiles = {
300+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
301+
}
302+
303+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
304+
quantizer.set_global(
305+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
306+
)
307+
308+
return Quantize(
309+
quantizer,
310+
get_symmetric_a16w8_quantization_config(
311+
is_per_channel=per_channel_quantization
312+
),
313+
)
314+
315+
316+
@common.parametrize("test_data", test_data_suite)
317+
@pytest.mark.xfail(
318+
reason="missing int16 mul ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13947"
319+
)
320+
def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1):
321+
"""Test mul operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
322+
per_channel_quantization = False
323+
324+
pipeline = TosaPipelineINT[input_t1](
325+
Mul(),
326+
test_data(),
327+
aten_op,
328+
exir_op=[],
329+
per_channel_quantization=per_channel_quantization,
330+
use_to_edge_transform_and_lower=True,
331+
tosa_extensions=["int16"],
332+
)
333+
334+
pipeline.change_args(
335+
"quantize",
336+
get_symmetric_a16w8_mul_quantizer(
337+
per_channel_quantization=per_channel_quantization
338+
),
339+
)
340+
pipeline.run()
341+
342+
343+
@common.parametrize("test_data", test_data_suite)
344+
@common.XfailIfNoCorstone300
345+
@pytest.mark.xfail(
346+
reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
347+
)
348+
def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1):
349+
"""Test mul operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
350+
per_channel_quantization = False
351+
352+
pipeline = EthosU55PipelineINT[input_t1](
353+
Mul(),
354+
test_data(),
355+
aten_op,
356+
exir_ops=[],
357+
per_channel_quantization=per_channel_quantization,
358+
use_to_edge_transform_and_lower=True,
359+
run_on_fvp=True,
360+
)
361+
362+
pipeline.change_args(
363+
"quantize",
364+
get_symmetric_a16w8_mul_quantizer(
365+
per_channel_quantization=per_channel_quantization
366+
),
367+
)
368+
pipeline.run()
369+
370+
371+
@common.parametrize("test_data", test_data_suite)
372+
@common.XfailIfNoCorstone320
373+
@pytest.mark.xfail(
374+
reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
375+
)
376+
def test_mul_tensor_16a8w_u85_INT16(test_data: input_t1):
377+
"""Test mul operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
378+
per_channel_quantization = False
379+
380+
pipeline = EthosU85PipelineINT[input_t1](
381+
Mul(),
382+
test_data(),
383+
aten_op,
384+
exir_ops=[],
385+
per_channel_quantization=per_channel_quantization,
386+
use_to_edge_transform_and_lower=True,
387+
run_on_fvp=True,
388+
)
389+
390+
pipeline.change_args(
391+
"quantize",
392+
get_symmetric_a16w8_mul_quantizer(
393+
per_channel_quantization=per_channel_quantization
394+
),
395+
)
396+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def define_arm_tests():
1616
"ops/test_add.py",
1717
"ops/test_avg_pool2d.py",
1818
"ops/test_linear.py",
19+
"ops/test_mul.py",
1920
"ops/test_slice.py",
2021
"ops/test_sigmoid.py",
2122
"ops/test_tanh.py",

backends/arm/tosa/quant_utils.py

Lines changed: 104 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# pyre-unsafe
77

8-
# Utiliy functions for TOSA quantized lowerings
8+
# Utility functions for TOSA quantized lowerings
99

1010
import math
1111

@@ -27,11 +27,11 @@ def insert_rescale_ops_to_int32_maxscale(
2727
tosa_graph: Any, inputs: list[TosaArg], node: Node, tosa_spec=None
2828
) -> tuple[list[Any], float]:
2929
"""For ADD and SUB, we rescale to int32 using a different common scale(2*max(left scale,right scale))
30-
compared to all the other cases. We also multply the left and right scales by 1<<20 giving us extra precision
30+
compared to all the other cases. We also multiply the left and right scales by 1<<20 giving us extra precision
3131
for the computation without overflowing.
3232
3333
Returns a list of the rescaled nodes and the scale factor used,
34-
needed by rescale_node_back_to_int8.
34+
needed by insert_rescale_op_to_int8.
3535
"""
3636

3737
if len(inputs) > 2:
@@ -86,7 +86,7 @@ def insert_rescale_ops_to_int32(
8686
The scales are adjusted using the smallest scale of all 'nodes'.
8787
8888
Returns a list of the rescaled nodes and the scale factor used,
89-
needed by rescale_node_back_to_int8.
89+
needed by insert_rescale_op_to_int8.
9090
9191
This function is used in serialization to TOSA for target ops that are
9292
handled by the DQ/D folding pass, which stores the quantization parameters
@@ -134,7 +134,59 @@ def insert_rescale_op_to_int8(
134134
Parameters:
135135
node: The original node that is being handled by the rescales.
136136
last_tensor: the tosa tensor to rescale back.
137-
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32'
137+
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
138+
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
139+
tosa_graph: the tosa_graph to manipulate.
140+
141+
This function is used in serialization to TOSA for target ops that are
142+
handled by the DQ/D folding pass, which stores the quantization parameters
143+
in the node meta dict.
144+
"""
145+
_insert_rescale_op_to_dtype(
146+
tosa_graph, last_tensor, scale, node, ts.DType.INT8, compute_rescale, tosa_spec
147+
)
148+
149+
150+
def insert_rescale_op_to_int16(
151+
tosa_graph: Any,
152+
last_tensor: TosaArg,
153+
scale: float,
154+
node: Node,
155+
compute_rescale=True,
156+
tosa_spec=None,
157+
) -> None:
158+
"""Rescales the node back to int16, adding a suitable RESCALE op to 'tosa_graph'.
159+
Parameters:
160+
node: The original node that is being handled by the rescales.
161+
last_tensor: the tosa tensor to rescale back.
162+
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
163+
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
164+
tosa_graph: the tosa_graph to manipulate.
165+
166+
This function is used in serialization to TOSA for target ops that are
167+
handled by the DQ/D folding pass, which stores the quantization parameters
168+
in the node meta dict.
169+
"""
170+
_insert_rescale_op_to_dtype(
171+
tosa_graph, last_tensor, scale, node, ts.DType.INT16, compute_rescale, tosa_spec
172+
)
173+
174+
175+
def _insert_rescale_op_to_dtype(
176+
tosa_graph: Any,
177+
last_tensor: TosaArg,
178+
scale: float,
179+
node: Node,
180+
output_dtype: Any,
181+
compute_rescale=True,
182+
tosa_spec=None,
183+
) -> None:
184+
"""Common implementation for rescaling nodes back to a specific dtype.
185+
Parameters:
186+
node: The original node that is being handled by the rescales.
187+
last_tensor: the tosa tensor to rescale back.
188+
scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'
189+
output_dtype: The target dtype (ts.DType.INT8 or ts.DType.INT16)
138190
compute_rescale: boolean indicating whether we need to divide the output scale by the original scale.
139191
tosa_graph: the tosa_graph to manipulate.
140192
@@ -156,20 +208,21 @@ def insert_rescale_op_to_int8(
156208
else:
157209
output_rescale_scale = scale
158210

159-
# Rescale Back to INT8
160-
build_rescale_from_int32(
211+
# Rescale Back to the specified dtype
212+
build_rescale_from_int32_to_dtype(
161213
tosa_graph,
162214
last_tensor,
163215
node.name,
164216
qargs_out.get_zp_per_tensor(),
165217
output_rescale_scale,
218+
output_dtype,
166219
tosa_spec=tosa_spec,
167220
)
168221

169222

170223
# TOSA uses the RESCALE operation to scale between values with differing precision.
171224
# The RESCALE operator is defined using an integer multiply, add, and shift.
172-
# This utility function is for calculating the multier and shift given a scale.
225+
# This utility function is for calculating the multiplier and shift given a scale.
173226
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
174227
def compute_multiplier_and_shift(
175228
scales: list[float], scaleWidth: int = 32
@@ -214,7 +267,7 @@ def compute_multiplier_and_shift(
214267
return multipliers, shifts
215268

216269

217-
# For TOSA spec v1.0 RESCALE operator requires multipler, shifts, input_zp and output_zp to be
270+
# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be
218271
# const inputs. Create constant operators from the data already initialized.
219272
def create_const_ops_for_rescale(
220273
tosa_fb,
@@ -335,14 +388,55 @@ def build_rescale_from_int32(
335388
per_channel: bool = False,
336389
tosa_spec=None,
337390
) -> None:
391+
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
392+
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
393+
build_rescale_from_int32_to_dtype(
394+
tosa_fb,
395+
input_node,
396+
output_name,
397+
output_zp,
398+
rescale_scale,
399+
ts.DType.INT8,
400+
is_scale32,
401+
is_double_round,
402+
per_channel,
403+
tosa_spec,
404+
)
405+
406+
return
407+
408+
409+
def build_rescale_from_int32_to_dtype(
410+
tosa_fb: Any,
411+
input_node: TosaArg,
412+
output_name: str,
413+
output_zp: int,
414+
rescale_scale: float,
415+
output_dtype: Any,
416+
is_scale32: bool = True,
417+
is_double_round: bool = False,
418+
per_channel: bool = False,
419+
tosa_spec=None,
420+
) -> None:
421+
"""Common implementation for rescaling from INT32 to a specific dtype (INT8 or INT16).
422+
423+
Parameters:
424+
tosa_fb: The TOSA serializer
425+
input_node: Input tensor (should be INT32)
426+
output_name: Name for the output tensor
427+
output_zp: Output zero point
428+
rescale_scale: Rescaling factor
429+
output_dtype: Target dtype (ts.DType.INT8 or ts.DType.INT16)
430+
Other parameters: Standard rescale parameters
431+
"""
338432
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
339433
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
340434
build_rescale(
341435
tosa_fb,
342436
[rescale_scale],
343437
input_node,
344438
output_name=output_name,
345-
output_type=ts.DType.INT8,
439+
output_type=output_dtype,
346440
input_zp=[0],
347441
output_zp=[output_zp],
348442
rounding_mode=RoundingMode.SINGLE_ROUND,

0 commit comments

Comments
 (0)