Skip to content

Commit 5b038a7

Browse files
committed
Add 16A8W support and test for mul operation
Add 16A8W quantization support and test for the mul operation in the ExecutorTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to mul operations.

Changes:
- Add INT16 dtype validation support in op_mul.py
- Add test_mul_tensor_16a8w_tosa_INT test function
- Enable test_mul.py in the test targets configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency.

Differential Revision: [D80510628](https://our.internmc.facebook.com/intern/diff/D80510628/)

[ghstack-poisoned]
1 parent 15b87c3 commit 5b038a7

File tree

3 files changed

+109
-2
lines changed

3 files changed

+109
-2
lines changed

backends/arm/operators/op_mul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define_node(
5151
validate_valid_dtype(
5252
self.target,
5353
[*inputs, output],
54-
[ts.DType.INT8, ts.DType.INT32],
54+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
5555
output.tosa_spec,
5656
)
5757

backends/arm/test/ops/test_mul.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,23 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
1217

13-
from executorch.backends.arm.test import common
18+
from executorch.backends.arm.test import common, conftest
1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
1621
EthosU85PipelineINT,
1722
TosaPipelineFP,
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
26+
from executorch.backends.arm.tosa_specification import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2128

2229
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x
2330
aten_op = "torch.ops.aten.mul.Tensor"
@@ -268,3 +275,102 @@ def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor):
268275
)
269276
pipeline.pop_stage("check.quant_nodes")
270277
pipeline.run()
278+
279+
280+
def get_symmetric_a16w8_mul_quantizer(per_channel_quantization=False):
    """Build a ``Quantize`` stage configured for 16A8W quantization
    (16-bit activations, 8-bit weights) of mul operations.

    Args:
        per_channel_quantization: If True, use per-channel weight
            quantization; defaults to per-tensor.

    Returns:
        A ``Quantize`` stage wrapping a ``TOSAQuantizer`` with the
        symmetric a16w8 quantization config applied globally.

    Raises:
        KeyError: If the ``tosa_version`` test option is not a mapped
            TOSA profile (only "1.0" is supported here).
    """
    tosa_version = conftest.get_option("tosa_version")
    # 16A8W requires the int16 extension of the TOSA 1.0 INT profile.
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    # Build the config once and reuse it for both the quantizer's global
    # config and the Quantize stage, instead of constructing it twice.
    quantization_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(quantization_config)

    return Quantize(quantizer, quantization_config)
297+
298+
299+
@common.parametrize("test_data", test_data_suite)
def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1):
    """Run mul through the TOSA INT pipeline with 16A8W quantization
    (16-bit activations, 8-bit weights)."""
    per_channel_quantization = False

    # The a16w8 quantize stage replaces the pipeline's default one below.
    a16w8_quantize_stage = get_symmetric_a16w8_mul_quantizer(
        per_channel_quantization=per_channel_quantization
    )

    tosa_pipeline = TosaPipelineINT[input_t1](
        Mul(),
        test_data(),
        aten_op,
        exir_op=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )
    tosa_pipeline.change_args("quantize", a16w8_quantize_stage)
    tosa_pipeline.run()
321+
322+
323+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations"
)
def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1):
    """Run mul on the Ethos-U55 INT pipeline with 16A8W quantization
    (16-bit activations, 8-bit weights)."""
    per_channel_quantization = False

    # The a16w8 quantize stage replaces the pipeline's default one below.
    a16w8_quantize_stage = get_symmetric_a16w8_mul_quantizer(
        per_channel_quantization=per_channel_quantization
    )

    u55_pipeline = EthosU55PipelineINT[input_t1](
        Mul(),
        test_data(),
        aten_op,
        exir_ops=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )
    u55_pipeline.change_args("quantize", a16w8_quantize_stage)
    u55_pipeline.run()
349+
350+
351+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations"
)
def test_mul_tensor_16a8w_u85_INT16(test_data: input_t1):
    """Run mul on the Ethos-U85 INT pipeline with 16A8W quantization
    (16-bit activations, 8-bit weights)."""
    per_channel_quantization = False

    # The a16w8 quantize stage replaces the pipeline's default one below.
    a16w8_quantize_stage = get_symmetric_a16w8_mul_quantizer(
        per_channel_quantization=per_channel_quantization
    )

    u85_pipeline = EthosU85PipelineINT[input_t1](
        Mul(),
        test_data(),
        aten_op,
        exir_ops=[],
        per_channel_quantization=per_channel_quantization,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )
    u85_pipeline.change_args("quantize", a16w8_quantize_stage)
    u85_pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def define_arm_tests():
1616
"ops/test_add.py",
1717
"ops/test_avg_pool2d.py",
1818
"ops/test_linear.py",
19+
"ops/test_mul.py",
1920
"ops/test_slice.py",
2021
"ops/test_sigmoid.py",
2122
"ops/test_tanh.py",

0 commit comments

Comments
 (0)