Commit 00d39d8

Update on "make to_edge_transform_and_lower support etrecord generation"
Differential Revision: [D79336982](https://our.internmc.facebook.com/intern/diff/D79336982/)

Umbrella issue: #12961

[ghstack-poisoned]

2 parents 6b37502 + d4a22a1 commit 00d39d8
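
The feature named in the commit title is about letting the lowering entry point also produce an ETRecord for the devtools flow. A hypothetical usage sketch, assuming a `generate_etrecord` keyword inferred from the title (the exir-side change is not among the files shown on this page, so treat the exact signature as illustrative):

# Hypothetical sketch only: the generate_etrecord keyword is inferred from the
# commit title and is not confirmed by the files shown on this page.
import torch

from executorch.exir import to_edge_transform_and_lower


class Mul(torch.nn.Module):
    def forward(self, x, y):
        return x * y


# Export the eager module, then transform and lower it while (assumed) also
# recording an ETRecord for the devtools Inspector flow.
exported = torch.export.export(Mul(), (torch.randn(4), torch.randn(4)))
edge_manager = to_edge_transform_and_lower(
    exported,
    generate_etrecord=True,  # assumed flag name
)
et_program = edge_manager.to_executorch()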

258 files changed: +3184 additions, -7003 deletions


backends/arm/README.md

Lines changed: 2 additions & 2 deletions

@@ -181,8 +181,8 @@ The Arm EthosU Backend should be considered a prototype quality at this point, l
 ## Current flows

 The EthosUBackend has a two stage process,
-- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v0.80 TOSA BI with specific concern to a subset which gives support on Ethos-U55 and Ethos-U85, the target of the initial prototype efforts. This calls into the TOSABackend.
-- Lower via the ethos-u-vela compilation flow which takes TOSA v0.80 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.
+- Compile to TOSA to rationalise the graph into known hardware support profiles. Currently this is to v1.0 TOSA INT with specific concern to a subset which gives support on Ethos-U55 and Ethos-U85, the target of the initial prototype efforts. This calls into the TOSABackend.
+- Lower via the ethos-u-vela compilation flow which takes TOSA v1.0 as an input and produces a low level commandstream for the hardware which is then passed via the delegate to the ethos-u-core-driver for direct execution.

 The EthosUPartitioner is currenly used to ensure the operations converted are Ethos-U compatible, but will be extended to offer spec-correct TOSA Base inference and TOSA Main Inference generation in future.
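
The two-stage flow described in this README change is driven from Python by pairing EthosUPartitioner with an Ethos-U compile spec. A minimal sketch, assuming the ArmCompileSpecBuilder helper and an Ethos-U55 target; the builder arguments and target string are illustrative and may differ between releases, and the quantization step required for the INT profile is omitted:

# Minimal sketch of the two-stage flow above, with assumed builder arguments.
import torch

from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.ethosu import EthosUPartitioner
from executorch.exir import to_edge_transform_and_lower


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


exported = torch.export.export(Add(), (torch.ones(4), torch.ones(4)))

# Stage 1 target: rationalise the graph to the TOSA v1.0 INT subset supported
# by the Ethos-U55 configuration below ("ethos-u55-128" is an example target).
compile_spec = ArmCompileSpecBuilder().ethosu_compile_spec("ethos-u55-128").build()

# Stage 2 happens inside the delegate: EthosUBackend hands the TOSA output to
# ethos-u-vela, which emits the command stream run by ethos-u-core-driver.
edge = to_edge_transform_and_lower(
    exported,
    partitioner=[EthosUPartitioner(compile_spec)],
)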

backends/arm/TARGETS

Lines changed: 12 additions & 2 deletions

@@ -1,10 +1,20 @@
 # @noautodeps
 load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "ethosu_partitioner",
+    srcs = [
+        "ethosu/__init__.py",
+        "ethosu/backend.py",
+        "ethosu/partitioner.py"
+    ],
+    deps = [
+        ":arm_partitioner",
+    ]
+)
 python_library(
     name = "arm_partitioner",
     srcs = [
-        "ethosu_backend.py",
-        "ethosu_partitioner.py",
         "tosa_backend.py",
         "tosa_partitioner.py",
         "vgf_backend.py",

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@
 from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass  # noqa
 from .decompose_addmm_pass import DecomposeAddmmPass  # noqa
 from .decompose_asin_pass import DecomposeAsinPass  # noqa
+from .decompose_asinh_pass import DecomposeAsinhPass  # noqa
 from .decompose_atan_pass import DecomposeAtanPass  # noqa
 from .decompose_atanh_pass import DecomposeAtanhPass  # noqa
 from .decompose_avg_pool2d import DecomposeAvgPool2d  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 7 additions & 17 deletions

@@ -30,6 +30,7 @@
     DecomposeAcoshPass,
     DecomposeAdaptiveAvgPool2dPass,
     DecomposeAddmmPass,
+    DecomposeAsinhPass,
     DecomposeAsinPass,
     DecomposeAtanhPass,
     DecomposeAtanPass,
@@ -105,7 +106,7 @@ def _transform(self, graph_module: GraphModule):
         with TosaLoweringContext(self.tosa_spec):
             return self(graph_module).graph_module

-    def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
+    def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())
@@ -114,7 +115,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(
             DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
         )
-
         self.add_pass(ConvertFullLikeToFullPass())
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
@@ -148,7 +148,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(DecomposeMaxPool2DPass())
         self.add_pass(SizeAdjustInputPass())
         self.add_pass(DecomposeSelectPass())
-
         self.add_pass(ConvertSqueezesToViewPass())

         self.add_pass(FuseViewCopyTransform())
@@ -162,11 +161,12 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

         return self._transform(exported_program.graph_module)

-    def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
+    def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeMaskedFill())
         self.add_pass(DecomposeRoundPass())
         self.add_pass(DecomposeAcoshPass())
         self.add_pass(DecomposeAsinPass())
+        self.add_pass(DecomposeAsinhPass())
         self.add_pass(DecomposeSqrtPass())
         self.add_pass(DecomposeAtanPass())
         self.add_pass(DecomposeAtanhPass())
@@ -235,22 +235,12 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

         return self._transform(exported_program.graph_module)

-    def _tosa_1_0_int_quantized_pipeline(self, exported_program: ExportedProgram):
-        return self._tosa_080_BI_pipeline(exported_program)
-
-    def _tosa_1_0_fp_pipeline(self, exported_program: ExportedProgram):
-        return self._tosa_080_MI_pipeline(exported_program)
-
     def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         """Apply passes before transforming program to backend"""
-        if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
-            return self._tosa_080_BI_pipeline(exported_program)
-        elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
-            return self._tosa_080_MI_pipeline(exported_program)
-        elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
-            return self._tosa_1_0_fp_pipeline(exported_program)
+        if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
+            return self._tosa_FP_pipeline(exported_program)
         elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
-            return self._tosa_1_0_int_quantized_pipeline(exported_program)
+            return self._tosa_INT_pipeline(exported_program)
         else:
             raise NotImplementedError(
                 f"No pass pipeline implemented for {self.tosa_spec=}"
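
With the 0.80 branches removed, only the TOSA v1.0 profile strings select a pipeline; anything else, including the old TOSA-0.80.0+BI / +MI strings, falls through to NotImplementedError. An illustrative sketch of the dispatch follows; the TosaSpecification import path and the helper function are assumptions for illustration:

# Illustrative sketch of the dispatch above; import path and helper are assumed.
from executorch.backends.arm.tosa_specification import TosaSpecification


def pick_pipeline(pass_manager, exported_program):
    spec = pass_manager.tosa_spec
    # TOSA v1.0 FP and INT are the only profiles with a pipeline after this change.
    if spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
        return pass_manager._tosa_FP_pipeline(exported_program)
    if spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
        return pass_manager._tosa_INT_pipeline(exported_program)
    raise NotImplementedError(f"No pass pipeline implemented for {spec=}")
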
backends/arm/_passes/decompose_asinh_pass.py

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+# For MI case
+edge_asinh_op = (exir_ops.edge.aten.asinh.default,)
+
+
+class DecomposeAsinhPass(ArmPass):
+    """
+    Decomposes asinh to supported TOSA-operations.
+    This decomposition is based on the mathematical identity:
+        asinh(x) = log(x + sqrt(x^2 + 1))
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in edge_asinh_op:
+            return super().call_operator(op, args, kwargs, meta)
+
+        log_op, sqrt_op, mul_op, add_op_scalar, add_op = (
+            exir_ops.edge.aten.log.default,
+            exir_ops.edge.aten.sqrt.default,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.add.Scalar,
+            exir_ops.edge.aten.add.Tensor,
+        )
+
+        x = args[0]
+
+        # calculate t1 = x^2 + 1
+        x2 = super().call_operator(mul_op, (x, x), {}, meta, True)
+        t1 = super().call_operator(add_op_scalar, (x2, 1.0), {}, meta, True)
+
+        # t2 = sqrt(t1)
+        t2 = super().call_operator(sqrt_op, (t1,), {}, meta, True)
+
+        # t3 = x + t2
+        t3 = super().call_operator(add_op, (x, t2), {}, meta, True)
+
+        # out = ln(t3)
+        out = super().call_operator(log_op, (t3,), {}, meta, True)
+
+        return out
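
The identity the pass relies on can be sanity-checked numerically in eager PyTorch; a standalone check, not part of the commit:

# Standalone numerical check of asinh(x) = log(x + sqrt(x^2 + 1)); not part of the commit.
import torch

x = torch.linspace(-10.0, 10.0, steps=101, dtype=torch.float64)
decomposed = torch.log(x + torch.sqrt(x * x + 1.0))
assert torch.allclose(decomposed, torch.asinh(x))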

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 21 additions & 16 deletions

@@ -6,6 +6,7 @@
 import logging

 import torch._export.utils
+import torch.fx
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_constant_placeholder_kind,
     get_first_fake_tensor,
@@ -50,22 +51,26 @@ def _fuse_nodes(self, node) -> bool:
         the operations already carried out on the data.
         """

-        # Extract tensors and args from the node
-        data_list = [
-            get_param_tensor(self.exported_program, input_node)
-            for input_node in node.all_input_nodes
-        ]
-
-        args = node.args[len(node.all_input_nodes) :]
-        kwargs = node.kwargs
-
-        if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0:
-            for i in range(len(node.all_input_nodes)):
-                q_params = node.meta["input_qparams"][i]
-                data_list[i] = q_params.dequantize_value(data_list[i])
-
-        # Run the op on the extracted tensor
-        data = node.target(*data_list, *args, **kwargs)
+        input_nodes = list(node.all_input_nodes)
+        qparams = node.meta.get("input_qparams", None)
+
+        def resolve_arg(arg):
+            if isinstance(arg, torch.fx.Node) and arg in input_nodes:
+                idx = input_nodes.index(arg)
+                t = get_param_tensor(self.exported_program, arg)
+                if qparams:
+                    t = qparams[idx].dequantize_value(t)
+                return t
+            if isinstance(arg, tuple):
+                return tuple(resolve_arg(x) for x in arg)
+            if isinstance(arg, list):
+                return [resolve_arg(x) for x in arg]
+            return arg
+
+        new_args = tuple(resolve_arg(a) for a in node.args)
+        new_kwargs = {k: resolve_arg(v) for k, v in node.kwargs.items()}
+
+        data = node.target(*new_args, **new_kwargs)

         # Only fuse if the tensor does not get bigger.
         if data.numel() > get_first_fake_tensor(node).numel():
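
The rewritten block resolves constant placeholder inputs wherever they appear in node.args or node.kwargs, including inside nested tuples and lists, rather than assuming they all come first positionally. The pattern in isolation, as a toy sketch with plain Python values standing in for torch.fx nodes:

# Toy sketch of the recursive resolution pattern above; strings stand in for
# torch.fx.Node placeholders and the dict stands in for the exported program's
# parameter lookup. Illustrative only.
def resolve(arg, params):
    if isinstance(arg, str) and arg in params:
        return params[arg]  # materialise a known placeholder
    if isinstance(arg, tuple):
        return tuple(resolve(a, params) for a in arg)
    if isinstance(arg, list):
        return [resolve(a, params) for a in arg]
    return arg  # scalars and other literals pass through untouched


params = {"w": 2.0, "b": 3.0}
args = (["w", "b"], 10)  # e.g. a cat-like op taking a list of inputs plus a scalar
print(resolve(args, params))  # ([2.0, 3.0], 10)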

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions

@@ -58,6 +58,7 @@ class TableOps:
         exir_ops.edge.aten.sinh.default: torch.sinh,
         exir_ops.edge.aten.acosh.default: torch.acosh,
         exir_ops.edge.aten.asin.default: torch.asin,
+        exir_ops.edge.aten.asinh.default: torch.asinh,
     }

     # Targets that must be treated explicitly
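
In the quantized (INT) flow this entry means asinh is realised as a table lookup whose reference values come from torch.asinh. A rough sketch of tabulating such an int8 table from a reference function; the quantisation parameters are made-up example values and the real pass internals are not shown in this diff:

# Rough sketch of building an int8 lookup table from a reference op (here
# torch.asinh); scales and zero points are made-up example values.
import torch


def build_int8_table(ref_fn, in_scale, in_zp, out_scale, out_zp):
    qmin, qmax = -128, 127
    # Dequantize every representable int8 input, apply the reference op in
    # float, then requantize the result back to int8.
    q_in = torch.arange(qmin, qmax + 1, dtype=torch.int32)
    x = (q_in - in_zp).to(torch.float32) * in_scale
    y = ref_fn(x)
    q_out = torch.clamp(torch.round(y / out_scale) + out_zp, qmin, qmax)
    return q_out.to(torch.int8)


table = build_int8_table(torch.asinh, in_scale=0.05, in_zp=0, out_scale=0.02, out_zp=0)
print(table.shape)  # torch.Size([256])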

backends/arm/ethosu/__init__.py

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# pyre-unsafe
+
+from .backend import EthosUBackend  # noqa: F401
+from .partitioner import EthosUPartitioner  # noqa: F401
+
+__all__ = [
+    "EthosUBackend",
+    "EthosUPartitioner",
+]
backends/arm/ethosu_backend.py renamed to backends/arm/ethosu/backend.py

File renamed without changes.

backends/arm/ethosu_partitioner.py renamed to backends/arm/ethosu/partitioner.py

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 from executorch.backends.arm.arm_backend import (
     is_ethosu,
 )  # usort: skip
-from executorch.backends.arm.ethosu_backend import EthosUBackend
+from executorch.backends.arm.ethosu import EthosUBackend
 from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import DelegationSpec
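
With the backend and partitioner now living in the backends/arm/ethosu package, downstream imports change accordingly; for example:

# Import-path change implied by the renames above (old paths shown for comparison).
# Old paths (before this commit):
#   from executorch.backends.arm.ethosu_backend import EthosUBackend
#   from executorch.backends.arm.ethosu_partitioner import EthosUPartitioner
# New path:
from executorch.backends.arm.ethosu import EthosUBackend, EthosUPartitioner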
