Skip to content

Commit 436e7ac

Browse files
authored
Merge branch 'main' into add_inception_v3_test
2 parents 73cab60 + ec4228c commit 436e7ac

File tree

244 files changed

+2438
-6472
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

244 files changed

+2438
-6472
lines changed

backends/arm/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ The Arm EthosU Backend should be considered a prototype quality at this point, l
181181
## Current flows
182182

183183
The EthosUBackend has a two stage process,
184-
- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v0.80 TOSA BI with specific concern to a subset which gives support on Ethos-U55 and Ethos-U85, the target of the initial prototype efforts. This calls into the TOSABackend.
185-
- Lower via the ethos-u-vela compilation flow which takes TOSA v0.80 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.
184+
- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v1.0 TOSA INT with specific concern to a subset which gives support on Ethos-U55 and Ethos-U85, the target of the initial prototype efforts. This calls into the TOSABackend.
185+
- Lower via the ethos-u-vela compilation flow which takes TOSA v1.0 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.
186186

187187
The EthosUPartitioner is currenly used to ensure the operations converted are Ethos-U compatible, but will be extended to offer spec-correct TOSA Base inference and TOSA Main Inference generation in future.
188188

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
2727
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
2828
from .decompose_asin_pass import DecomposeAsinPass # noqa
29+
from .decompose_asinh_pass import DecomposeAsinhPass # noqa
2930
from .decompose_atan_pass import DecomposeAtanPass # noqa
3031
from .decompose_atanh_pass import DecomposeAtanhPass # noqa
3132
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
DecomposeAcoshPass,
3131
DecomposeAdaptiveAvgPool2dPass,
3232
DecomposeAddmmPass,
33+
DecomposeAsinhPass,
3334
DecomposeAsinPass,
3435
DecomposeAtanhPass,
3536
DecomposeAtanPass,
@@ -105,7 +106,7 @@ def _transform(self, graph_module: GraphModule):
105106
with TosaLoweringContext(self.tosa_spec):
106107
return self(graph_module).graph_module
107108

108-
def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
109+
def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
109110
self.add_pass(FuseQuantizedActivationPass())
110111
self.add_pass(RemoveGetItemPass())
111112
self.add_pass(ConvertSplitToSlicePass())
@@ -114,7 +115,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
114115
self.add_pass(
115116
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
116117
)
117-
118118
self.add_pass(ConvertFullLikeToFullPass())
119119
self.add_pass(ConvertToClampPass())
120120
self.add_pass(ConvertMinMaxPass())
@@ -148,7 +148,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
148148
self.add_pass(DecomposeMaxPool2DPass())
149149
self.add_pass(SizeAdjustInputPass())
150150
self.add_pass(DecomposeSelectPass())
151-
152151
self.add_pass(ConvertSqueezesToViewPass())
153152

154153
self.add_pass(FuseViewCopyTransform())
@@ -162,11 +161,12 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
162161

163162
return self._transform(exported_program.graph_module)
164163

165-
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
164+
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
166165
self.add_pass(DecomposeMaskedFill())
167166
self.add_pass(DecomposeRoundPass())
168167
self.add_pass(DecomposeAcoshPass())
169168
self.add_pass(DecomposeAsinPass())
169+
self.add_pass(DecomposeAsinhPass())
170170
self.add_pass(DecomposeSqrtPass())
171171
self.add_pass(DecomposeAtanPass())
172172
self.add_pass(DecomposeAtanhPass())
@@ -235,22 +235,12 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
235235

236236
return self._transform(exported_program.graph_module)
237237

238-
def _tosa_1_0_int_quantized_pipeline(self, exported_program: ExportedProgram):
239-
return self._tosa_080_BI_pipeline(exported_program)
240-
241-
def _tosa_1_0_fp_pipeline(self, exported_program: ExportedProgram):
242-
return self._tosa_080_MI_pipeline(exported_program)
243-
244238
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
245239
"""Apply passes before transforming program to backend"""
246-
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
247-
return self._tosa_080_BI_pipeline(exported_program)
248-
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
249-
return self._tosa_080_MI_pipeline(exported_program)
250-
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
251-
return self._tosa_1_0_fp_pipeline(exported_program)
240+
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
241+
return self._tosa_FP_pipeline(exported_program)
252242
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
253-
return self._tosa_1_0_int_quantized_pipeline(exported_program)
243+
return self._tosa_INT_pipeline(exported_program)
254244
else:
255245
raise NotImplementedError(
256246
f"No pass pipeline implemented for {self.tosa_spec=}"
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
12+
# For MI case
13+
edge_asinh_op = (exir_ops.edge.aten.asinh.default,)
14+
15+
16+
class DecomposeAsinhPass(ArmPass):
17+
"""
18+
Decomposes asinh to supported TOSA-operations.
19+
This decomposition is based on the mathematical identity:
20+
asinh(x) = log(x + sqrt(x^2 + 1))
21+
"""
22+
23+
def call_operator(self, op, args, kwargs, meta):
24+
if op not in edge_asinh_op:
25+
return super().call_operator(op, args, kwargs, meta)
26+
27+
log_op, sqrt_op, mul_op, add_op_scalar, add_op = (
28+
exir_ops.edge.aten.log.default,
29+
exir_ops.edge.aten.sqrt.default,
30+
exir_ops.edge.aten.mul.Tensor,
31+
exir_ops.edge.aten.add.Scalar,
32+
exir_ops.edge.aten.add.Tensor,
33+
)
34+
35+
x = args[0]
36+
37+
# calculate t1 = x^2 + 1
38+
x2 = super().call_operator(mul_op, (x, x), {}, meta, True)
39+
t1 = super().call_operator(add_op_scalar, (x2, 1.0), {}, meta, True)
40+
41+
# t2 = sqrt(t1)
42+
t2 = super().call_operator(sqrt_op, (t1,), {}, meta, True)
43+
44+
# t3 = x + t2
45+
t3 = super().call_operator(add_op, (x, t2), {}, meta, True)
46+
47+
# out = ln(t3)
48+
out = super().call_operator(log_op, (t3,), {}, meta, True)
49+
50+
return out

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77

88
import torch._export.utils
9+
import torch.fx
910
from executorch.backends.arm._passes.arm_pass_utils import (
1011
get_constant_placeholder_kind,
1112
get_first_fake_tensor,
@@ -50,22 +51,26 @@ def _fuse_nodes(self, node) -> bool:
5051
the operations already carried out on the data.
5152
"""
5253

53-
# Extract tensors and args from the node
54-
data_list = [
55-
get_param_tensor(self.exported_program, input_node)
56-
for input_node in node.all_input_nodes
57-
]
58-
59-
args = node.args[len(node.all_input_nodes) :]
60-
kwargs = node.kwargs
61-
62-
if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0:
63-
for i in range(len(node.all_input_nodes)):
64-
q_params = node.meta["input_qparams"][i]
65-
data_list[i] = q_params.dequantize_value(data_list[i])
66-
67-
# Run the op on the extracted tensor
68-
data = node.target(*data_list, *args, **kwargs)
54+
input_nodes = list(node.all_input_nodes)
55+
qparams = node.meta.get("input_qparams", None)
56+
57+
def resolve_arg(arg):
58+
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
59+
idx = input_nodes.index(arg)
60+
t = get_param_tensor(self.exported_program, arg)
61+
if qparams:
62+
t = qparams[idx].dequantize_value(t)
63+
return t
64+
if isinstance(arg, tuple):
65+
return tuple(resolve_arg(x) for x in arg)
66+
if isinstance(arg, list):
67+
return [resolve_arg(x) for x in arg]
68+
return arg
69+
70+
new_args = tuple(resolve_arg(a) for a in node.args)
71+
new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()}
72+
73+
data = node.target(*new_args, **new_kwargs)
6974

7075
# Only fuse if the tensor does not get bigger.
7176
if data.numel() > get_first_fake_tensor(node).numel():

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class TableOps:
5858
exir_ops.edge.aten.sinh.default: torch.sinh,
5959
exir_ops.edge.aten.acosh.default: torch.acosh,
6060
exir_ops.edge.aten.asin.default: torch.asin,
61+
exir_ops.edge.aten.asinh.default: torch.asinh,
6162
}
6263

6364
# Targets that must be treated explicitly

backends/arm/operator_support/convolution_support.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck):
2121
targets = [exir_ops.edge.aten.convolution.default]
2222

2323
tosa_specs = [
24-
TosaSpecification.create_from_string("TOSA-0.80+BI"),
25-
TosaSpecification.create_from_string("TOSA-0.80+MI"),
2624
TosaSpecification.create_from_string("TOSA-1.0+INT"),
2725
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2826
]

backends/arm/operator_support/embedding_support.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck):
2020
targets = [exir_ops.edge.aten.embedding.default]
2121

2222
tosa_specs = [
23-
TosaSpecification.create_from_string("TOSA-0.80+BI"),
24-
TosaSpecification.create_from_string("TOSA-0.80+MI"),
2523
TosaSpecification.create_from_string("TOSA-1.0+INT"),
2624
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2725
]

backends/arm/operator_support/index_select_support.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ class IndexSelectSupported(SupportedTOSAOperatorCheck):
1818
targets = [exir_ops.edge.aten.index_select.default]
1919

2020
tosa_specs = [
21-
TosaSpecification.create_from_string("TOSA-0.80+BI"),
22-
TosaSpecification.create_from_string("TOSA-0.80+MI"),
2321
TosaSpecification.create_from_string("TOSA-1.0+INT"),
2422
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2523
]

backends/arm/operator_support/index_tensor_support.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
100100
targets = [exir_ops.edge.aten.index.Tensor]
101101

102102
tosa_specs = [
103-
TosaSpecification.create_from_string("TOSA-0.80+BI"),
104-
TosaSpecification.create_from_string("TOSA-0.80+MI"),
105103
TosaSpecification.create_from_string("TOSA-1.0+INT"),
106104
TosaSpecification.create_from_string("TOSA-1.0+FP"),
107105
]

0 commit comments

Comments
 (0)