Skip to content

Commit dc8c21a

Browse files
committed
Arm backend: Move ReplaceScalarTensorWithFullPass to transforms
The pass is general and can be used by multiple backends. aten.scalar_tensor is replaced by an aten.full, which is already supported by the Arm backend. Adds a new method to the Arm tester for getting the output, since the nn.Module in the unit test does not take any input. The output is then manually compared within the unit test. Change-Id: I2bf211a2ce561d53e8a6cf683fdbda58e675938e
1 parent d3cca89 commit dc8c21a

File tree

6 files changed

+245
-30
lines changed

6 files changed

+245
-30
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@
7878
UnsqueezeScalarPlaceholdersPass,
7979
)
8080
from executorch.backends.arm.tosa_specification import TosaSpecification
81-
81+
from executorch.backends.transforms.replace_scalar_tensor_with_full import (
82+
ReplaceScalarTensorWithFullPass,
83+
)
8284
from executorch.backends.transforms.replace_scalar_with_tensor import (
8385
ReplaceScalarWithTensorArgPass,
8486
)
@@ -133,6 +135,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
133135
return self._transform(exported_program.graph_module)
134136

135137
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
138+
self.add_pass(ReplaceScalarTensorWithFullPass())
136139
self.add_pass(ReplaceScalarWithTensorArgPass())
137140
self.add_pass(FuseQuantizedActivationPass())
138141
self.add_pass(RemoveGetItemPass())
@@ -194,4 +197,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
194197
self.add_pass(DecomposeDivPass())
195198
self.add_pass(DecomposeSoftmaxesPass())
196199
self.add_pass(ConvertMinMaxPass())
200+
self.add_pass(ReplaceScalarTensorWithFullPass())
197201
return self._transform(graph_module)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def is_node_supported(
171171
exir_ops.edge.aten.constant_pad_nd.default,
172172
exir_ops.edge.aten.amax.default,
173173
exir_ops.edge.aten.amin.default,
174+
torch.ops.aten.scalar_tensor.default,
174175
]
175176

176177
return supported
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
10+
import torch
11+
from executorch.backends.arm.quantizer.arm_quantizer import (
12+
get_symmetric_quantization_config,
13+
TOSAQuantizer,
14+
)
15+
from executorch.backends.arm.test import common
16+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17+
from executorch.backends.arm.tosa_specification import TosaSpecification
18+
from executorch.backends.xnnpack.test.tester.tester import Quantize
19+
from parameterized import parameterized
20+
21+
22+
# Each entry is a tuple of (test_name, scalar input, scalar input type).
float_test_data_suite = [
    ("scalar_tensor_float_1", 3.7, torch.float32),
    ("scalar_tensor_float_2", 66, torch.float32),
]

# Each entry is a tuple of (test_name, scalar input, scalar input type).
int_test_data_suite = [
    ("scalar_tensor_int32", 33, torch.int32),
    ("scalar_tensor_int8", 8, torch.int8),
    ("scalar_tensor_int16", 16**3, torch.int16),
]
54+
55+
56+
class ScalarTensor(torch.nn.Module):
    """Input-free module whose forward pass returns torch.scalar_tensor(scalar, dtype=dtype)."""

    def __init__(self, scalar, dtype=torch.float32):
        super().__init__()
        # Stored as plain attributes and read back in forward().
        self.dtype = dtype
        self.scalar = scalar

    def forward(self):
        value, value_dtype = self.scalar, self.dtype
        return torch.scalar_tensor(value, dtype=value_dtype)
64+
65+
66+
class TestScalarTensor(unittest.TestCase):
    """Runs the input-free ScalarTensor module through the TOSA MI and BI pipelines."""

    def _test_scalar_tensor_tosa_MI_pipeline(
        self, module: torch.nn.Module, expected_output
    ):
        """
        Lower 'module' with the TOSA-0.80+MI compile spec and verify its output.

        Args:
            module: The module under test; it is run without any inputs.
            expected_output: Tensor the first output is compared against.
        """
        test_outputs = []
        in_data = ()

        (
            ArmTester(
                module,
                example_inputs=in_data,
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80+MI",
                ),
            )
            .export()
            .check_count({"torch.ops.aten.scalar_tensor.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_get_output(test_outputs, inputs=in_data)
        )
        self._verify_output(test_outputs, expected_output)

    def _test_scalar_tensor_tosa_BI_pipeline(
        self, module: torch.nn.Module, expected_output
    ):
        """
        Quantize and lower 'module' with the TOSA-0.80+BI spec and verify its output.

        Args:
            module: The module under test; it is run without any inputs.
            expected_output: Tensor the first output is compared against.
        """
        test_outputs = []
        in_data = ()
        tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")
        compile_spec = common.get_tosa_compile_spec(tosa_spec)
        quantizer = TOSAQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())

        (
            ArmTester(
                module,
                example_inputs=in_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.full.default": 1})  # Already replaced
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_get_output(test_outputs, inputs=in_data)
        )
        self._verify_output(test_outputs, expected_output)

    def _verify_output(self, test_outputs, expected_output):
        """
        Compare the first collected output against 'expected_output'.

        Uses TestCase assertions instead of bare 'assert' statements so the
        checks are not stripped under 'python -O' and failures carry a message.

        Args:
            test_outputs: List with one entry per run, each a list of outputs.
            expected_output: Tensor with the expected value and dtype.
        """
        self.assertTrue(test_outputs, "No outputs were collected from the run.")
        # The collected output is squeezed before comparing it with the
        # expected tensor.
        out_data = torch.squeeze(test_outputs[0][0])
        self.assertEqual(
            out_data.shape,
            expected_output.shape,
            f"Expected shape {expected_output.shape}, got {out_data.shape}.",
        )
        self.assertTrue(
            bool((out_data == expected_output).all()),
            f"Expected value {expected_output}, got {out_data}.",
        )
        self.assertEqual(
            out_data.dtype,
            expected_output.dtype,
            f"Expected dtype {expected_output.dtype}, got {out_data.dtype}.",
        )

    @parameterized.expand(int_test_data_suite + float_test_data_suite)
    def test_scalar_tensor_tosa_MI(  # Note TOSA MI supports all types
        self, test_name: str, scalar_value, scalar_type
    ):
        self._test_scalar_tensor_tosa_MI_pipeline(
            ScalarTensor(scalar_value, scalar_type),
            torch.scalar_tensor(scalar_value, dtype=scalar_type),
        )

    @parameterized.expand(float_test_data_suite)
    def test_scalar_tensor_tosa_BI(self, test_name: str, scalar_value, scalar_type):
        self._test_scalar_tensor_tosa_BI_pipeline(
            ScalarTensor(scalar_value, scalar_type),
            torch.scalar_tensor(scalar_value, dtype=scalar_type),
        )

backends/arm/test/tester/arm_tester.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,60 @@ def serialize(
338338
def is_quantized(self) -> bool:
339339
return self.stages[self.stage_name(tester.Quantize)] is not None
340340

341+
def run_method_and_get_output(
    self,
    test_outputs: List,
    inputs: Optional[Tuple[torch.Tensor]] = None,
    stage: Optional[str] = None,
    num_runs=1,
):
    """
    Runs the artifact of 'stage' and appends its flattened outputs to 'test_outputs'.
    Returns self to allow the function to be run in a test chain.

    Args:
        test_outputs: List that receives one entry per run; each entry is the
            flattened list of outputs from that run.
        inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data.
            The default (None) is random data. An empty tuple is used as-is,
            which is what a module without inputs needs.
        stage: (Optional[str]): The name of the stage to run.
            The default is the latest run stage.
        num_runs: Number of times the stage is run.
    """
    edge_stage = self.stages[self.stage_name(tester.ToEdge)]
    if edge_stage is None:
        edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]
    assert (
        edge_stage is not None
    ), "To get outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."

    stage = stage or self.cur
    test_stage = self.stages[stage]

    # Loop inputs and get outputs of the test stage.
    for run_iteration in range(num_runs):
        # Compare against None rather than truthiness: an explicitly passed
        # empty tuple is a valid input set and must not be replaced.
        reference_input = (
            inputs if inputs is not None else next(self.generate_random_inputs())
        )

        input_shapes = [
            generated_input.shape if hasattr(generated_input, "shape") else (1,)
            for generated_input in reference_input
        ]
        input_shape_str = ", ".join([str(list(i)) for i in input_shapes])
        logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")

        test_output, _ = pytree.tree_flatten(
            test_stage.run_artifact(reference_input)
        )
        test_outputs.append(test_output)

    return self
394+
341395
def run_method_and_compare_outputs(
342396
self,
343397
inputs: Optional[Tuple[torch.Tensor]] = None,

backends/cadence/aot/replace_ops.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
)
3939
from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
4040
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
41+
from executorch.backends.transforms.replace_scalar_tensor_with_full import (
42+
ReplaceScalarTensorWithFullPass,
43+
)
4144
from executorch.backends.transforms.replace_scalar_with_tensor import (
4245
ReplaceScalarWithTensorArgPass,
4346
)
@@ -1722,35 +1725,9 @@ def call_operator(self, op, args, kwargs, meta):
17221725
register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass)
17231726

17241727

1725-
@register_cadence_pass(CadencePassAttribute(opt_level=0))
1726-
class ReplaceScalarTensorWithFullPass(ExportPass):
1727-
"""
1728-
aten.scalar_tensor can be replaced by aten.full with a shape of [1].
1729-
scalar_tensor is not supported, so this is an opt_level=0 pass.
1730-
"""
1731-
1732-
def call_operator(
1733-
self,
1734-
op,
1735-
args: Tuple[Argument, ...],
1736-
kwargs: Dict[str, Argument],
1737-
meta: NodeMetadata,
1738-
) -> ProxyValue:
1739-
if op not in {
1740-
exir_ops.edge.aten.scalar_tensor.default,
1741-
torch.ops.aten.scalar_tensor.default,
1742-
}:
1743-
return super().call_operator(op, args, kwargs, meta)
1744-
1745-
return super().call_operator(
1746-
exir_ops.edge.aten.full.default,
1747-
(
1748-
[1],
1749-
args[0],
1750-
),
1751-
{"dtype": torch.float32},
1752-
meta,
1753-
)
1728+
register_cadence_pass(CadencePassAttribute(opt_level=0))(
1729+
ReplaceScalarTensorWithFullPass
1730+
)
17541731

17551732

17561733
@register_cadence_pass(CadencePassAttribute(opt_level=0))
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from typing import Dict, Tuple
9+
10+
import torch
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
13+
from torch.fx.node import Argument
14+
15+
16+
class ReplaceScalarTensorWithFullPass(ExportPass):
    """
    aten.scalar_tensor can be replaced by aten.full with a shape of [1].

    Only the scalar value (first positional argument) and the dtype are
    forwarded to aten.full; any other keyword arguments on the original node
    are not passed on.
    NOTE(review): the replacement has shape [1]; confirm that consumers which
    expect a 0-dim tensor are fine with this.
    """

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Replace scalar_tensor calls with full; pass every other op through.

        Args:
            op: The operator being visited.
            args: Positional arguments of the call; args[0] is used as the fill value.
            kwargs: Keyword arguments of the call; only "dtype" is read.
            meta: Node metadata, forwarded unchanged.

        Returns:
            The ProxyValue produced by the parent class for the resulting call.
        """
        if op not in {
            exir_ops.edge.aten.scalar_tensor.default,
            torch.ops.aten.scalar_tensor.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # The dtype kwarg may be absent or None. Fall back to the default
        # dtype instead of raising KeyError, and instead of passing None on to
        # full, which would pick the dtype from the fill value.
        dtype = kwargs.get("dtype")
        if dtype is None:
            dtype = torch.get_default_dtype()

        return super().call_operator(
            exir_ops.edge.aten.full.default,
            (
                [1],
                args[0],
            ),
            {"dtype": dtype},
            meta,
        )

0 commit comments

Comments
 (0)