Commit 8f6f9b4

Ninja91 authored and facebook-github-bot committed
Add 16A8W support and test for mul operation
Summary: Add 16A8W quantization support and a test for the mul operation in the ExecuTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to mul.

Changes:
- Add INT16 dtype validation support in op_mul.py
- Add a test_mul_tensor_16a8w_tosa_INT test function
- Enable test_mul.py in the test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights, giving activations higher precision while keeping weights at the more compact 8-bit width.

Differential Revision: D80510628
1 parent bb0518c · commit 8f6f9b4
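
For intuition on the precision claim in the summary: with symmetric quantization, scale = max|x| / qmax, so a 16-bit grid slices the same activation range about 256x more finely than an 8-bit one, while weights stay at 8 bits. A minimal sketch (illustrative only, not part of this commit):

import torch

def symmetric_scale(t: torch.Tensor, qmax: int) -> float:
    # One step of the integer grid, in real units: max|t| / qmax.
    return t.abs().max().item() / qmax

x = torch.randn(1024)                    # stand-in activation tensor
scale_a8 = symmetric_scale(x, 127)       # int8 activations
scale_a16 = symmetric_scale(x, 32767)    # int16 activations (the "16A" in 16A8W)
print(scale_a8 / scale_a16)              # ~258: int16 steps are that much finer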

File tree (3 files changed: +52 −2 lines changed)

- backends/arm/operators/op_mul.py
- backends/arm/test/ops/test_mul.py
- backends/arm/test/targets.bzl

backends/arm/operators/op_mul.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def define_node(
         validate_valid_dtype(
             self.target,
             [*inputs, output],
-            [ts.DType.INT8, ts.DType.INT32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
             output.tosa_spec,
         )
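
The one-line change above adds ts.DType.INT16 to the dtypes mul's inputs and output may carry. A hypothetical stand-in for the check (validate_valid_dtype's real implementation lives elsewhere in backends/arm/operators; names and behavior here are assumptions):

ALLOWED = ["INT8", "INT16", "INT32"]  # INT16 is the newly permitted entry

def check_dtypes(target, tensors, allowed=ALLOWED):
    # Raise if any input/output tensor's dtype falls outside the allow-list.
    for t in tensors:
        if t.dtype not in allowed:
            raise ValueError(f"{target}: unsupported dtype {t.dtype}; expected one of {allowed}")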

backends/arm/test/ops/test_mul.py

Lines changed: 50 additions & 1 deletion
@@ -9,15 +9,21 @@
 from typing import Tuple
 
 import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
 
-from executorch.backends.arm.test import common
+from executorch.backends.arm.test import common, conftest
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
     TosaPipelineFP,
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 input_t1 = Tuple[torch.Tensor, torch.Tensor]  # Input x
 aten_op = "torch.ops.aten.mul.Tensor"
@@ -284,3 +290,46 @@ def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor):
     )
     pipeline.pop_stage("check.quant_nodes")
     pipeline.run()
+
+
+def get_symmetric_a16w8_mul_quantizer(u55_config=False, per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1):
+    """Test mul operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t1](
+        Mul(),
+        test_data(),
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_mul_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
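
The new test reuses the file's existing Mul module and the standard TosaPipelineINT, then swaps the pipeline's default int8 quantize stage for the 16A8W quantizer via pipeline.change_args. A minimal stand-in for the module under test (the real Mul class is defined earlier in test_mul.py; this sketch is an assumption):

import torch

class Mul(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Traces to torch.ops.aten.mul.Tensor, the aten_op the pipeline asserts on.
        return x * y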

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ def define_arm_tests():
         "ops/test_add.py",
         "ops/test_avg_pool2d.py",
         "ops/test_linear.py",
+        "ops/test_mul.py",
         "ops/test_slice.py",
         "ops/test_sigmoid.py",
         "ops/test_tanh.py",
