Skip to content

Commit 7e1a002

Browse files
committed
Update
[ghstack-poisoned]
2 parents 1697cbc + b66072c commit 7e1a002

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+2791
-337
lines changed

backends/arm/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,19 @@ python_library(
2121
"//executorch/exir/dialects:lib",
2222
],
2323
)
24+
python_library(
25+
name = "common",
26+
srcs = [
27+
"common/__init__.py",
28+
"common/debug.py",
29+
],
30+
deps = [
31+
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer",
32+
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/serializer:serializer",
33+
"//caffe2:torch",
34+
"//executorch/exir:lib",
35+
],
36+
)
2437
python_library(
2538
name = "arm_partitioner",
2639
srcs = [

backends/arm/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ python_library(
44
name = "passes",
55
srcs = glob(["*.py"]),
66
deps = [
7+
"//executorch/backends/arm:common",
78
"//executorch/backends/arm:constants",
89
"//executorch/backends/arm:tosa_quant_utils",
910
"//executorch/backends/arm:tosa_utils",

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
3636
from .decompose_div_pass import DecomposeDivPass # noqa
3737
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
38+
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa
3839
from .decompose_gelu_pass import DecomposeGeluPass # noqa
3940
from .decompose_glu_pass import DecomposeGluPass # noqa
4041
from .decompose_grouped_conv import DecomposeGroupedConv # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
DecomposeCosineSimilarityPass,
4141
DecomposeDivPass,
4242
DecomposeEmbeddingPass,
43+
DecomposeExpm1Pass,
4344
DecomposeGeluPass,
4445
DecomposeGluPass,
4546
DecomposeGroupedConv,
@@ -164,6 +165,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
164165
return self._transform(exported_program.graph_module)
165166

166167
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
168+
self.add_pass(DecomposeExpm1Pass())
167169
self.add_pass(DecomposeMaskedFill())
168170
self.add_pass(DecomposeRoundPass())
169171
self.add_pass(DecomposeAcoshPass())
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.arm._passes import ArmPass
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
9+
10+
edge_expm1_ops = (exir_ops.edge.aten.expm1.default,) # MI case
11+
12+
13+
def _get_expm1_decomposition(op) -> tuple:
14+
"""
15+
Returns the decomposition of the given aten.expm1 operation into
16+
its equivalent TOSA-supported operations
17+
18+
This handles both edge dialect ops and core PyTorch ops. The decomposition strategy
19+
is:
20+
expm1(x) → where(and(ge(x, -0.35), le(x, 0.35)), {taylor_series_expansion}, (exp(x)-1))
21+
22+
where {taylor_series_expansion} = x + (x^2/2) + (x^3/6) + (x^4/24)
23+
24+
Returns:
25+
A tuple (op_pow, op_div, op_add, op_exp, op_sub, op_ge, op_where, op_le, op_and)
26+
corresponding to the appropriate operator overloads for the input op.
27+
28+
Raises:
29+
RuntimeError: If the provided operator is not a supported elu variant.
30+
"""
31+
if op in edge_expm1_ops:
32+
return (
33+
exir_ops.edge.aten.pow.Tensor_Scalar,
34+
exir_ops.edge.aten.div.Scalar,
35+
exir_ops.edge.aten.add.Tensor,
36+
exir_ops.edge.aten.exp.default,
37+
exir_ops.edge.aten.sub.Scalar,
38+
exir_ops.edge.aten.ge.Scalar,
39+
exir_ops.edge.aten.where.self,
40+
exir_ops.edge.aten.le.Scalar,
41+
exir_ops.edge.aten.logical_and.default,
42+
)
43+
44+
raise RuntimeError(f"Can't get expm1 decomposition for op {op}")
45+
46+
47+
class DecomposeExpm1Pass(ArmPass):
48+
"""
49+
A transformation pass that decomposes unsupported 'aten.expm1' operations
50+
into a combination of supported TOSA-equivalent operations.
51+
52+
Since TOSA does not provide a native expm1 operator, this pass rewrites:
53+
expm1(x) → where(and(ge(x, -0.35), le(x, 0.35)), {taylor_series_expansion}, (exp(x)-1))
54+
where {taylor_series_expansion} = x + (x^2/2) + (x^3/6) + (x^4/24)
55+
56+
Supported input ops:
57+
- exir_ops.edge.aten.expm1.default(x)
58+
59+
These are replaced with:
60+
- exir_ops.edge.aten.pow.Tensor_Scalar,
61+
- exir_ops.edge.aten.div.Scalar,
62+
- exir_ops.edge.aten.add.Tensor,
63+
- exir_ops.edge.aten.exp.default,
64+
- exir_ops.edge.aten.sub.Scalar,
65+
- exir_ops.edge.aten.ge.Scalar,
66+
- exir_ops.edge.aten.where.self,
67+
- exir_ops.edge.aten.le.Scalar,
68+
- exir_ops.edge.aten.logical_and.default
69+
"""
70+
71+
def call_operator(self, op, args, kwargs, meta):
72+
if op not in edge_expm1_ops:
73+
return super().call_operator(op, args, kwargs, meta, updated=False)
74+
75+
(
76+
op_pow,
77+
op_div,
78+
op_add,
79+
op_exp,
80+
op_sub,
81+
op_ge,
82+
op_where,
83+
op_le,
84+
op_and,
85+
) = _get_expm1_decomposition(op)
86+
87+
input = args[0]
88+
89+
cutlo = -0.35
90+
cuthi = 0.35
91+
92+
taylor_term_2_numerator = super().call_operator(
93+
op_pow, (input, 2), {}, meta, updated=False
94+
)
95+
taylor_term_3_numerator = super().call_operator(
96+
op_pow, (input, 3), {}, meta, updated=False
97+
)
98+
taylor_term_4_numerator = super().call_operator(
99+
op_pow, (input, 4), {}, meta, updated=False
100+
)
101+
102+
taylor_term_2 = super().call_operator(
103+
op_div, (taylor_term_2_numerator, 2), {}, meta, updated=False
104+
)
105+
taylor_term_3 = super().call_operator(
106+
op_div, (taylor_term_3_numerator, 6), {}, meta, updated=False
107+
)
108+
taylor_term_4 = super().call_operator(
109+
op_div, (taylor_term_4_numerator, 24), {}, meta, updated=False
110+
)
111+
112+
add_terms_1_2 = super().call_operator(
113+
op_add, (input, taylor_term_2), {}, meta, updated=False
114+
)
115+
add_term_3 = super().call_operator(
116+
op_add, (add_terms_1_2, taylor_term_3), {}, meta, updated=False
117+
)
118+
taylor_expansion = super().call_operator(
119+
op_add, (add_term_3, taylor_term_4), {}, meta, updated=False
120+
)
121+
122+
decomp_exp = super().call_operator(op_exp, (input,), {}, meta, updated=False)
123+
decomp_sub = super().call_operator(
124+
op_sub, (decomp_exp, 1.0), {}, meta, updated=False
125+
)
126+
127+
ge = super().call_operator(op_ge, (input, cutlo), {}, meta, updated=False)
128+
le = super().call_operator(op_le, (input, cuthi), {}, meta, updated=False)
129+
130+
cond_and = super().call_operator(op_and, (ge, le), {}, meta, updated=False)
131+
where = super().call_operator(
132+
op_where, (cond_and, taylor_expansion, decomp_sub), {}, meta, updated=True
133+
)
134+
135+
return where

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class TableOps:
4343
exir_ops.edge.aten.ceil.default: torch.ceil,
4444
exir_ops.edge.aten.erf.default: torch.erf,
4545
exir_ops.edge.aten.exp.default: torch.exp,
46+
exir_ops.edge.aten.expm1.default: torch.expm1,
4647
exir_ops.edge.aten.floor.default: torch.floor,
4748
exir_ops.edge.aten.log.default: torch.log,
4849
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,

backends/arm/arm_backend.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,6 @@ def is_vgf(compile_spec: List[CompileSpec]) -> bool:
217217
return False
218218

219219

220-
def get_tosa_spec(compile_spec: List[CompileSpec]) -> TosaSpecification:
221-
for spec in compile_spec:
222-
if spec.key == "tosa_spec":
223-
return TosaSpecification.create_from_string(spec.value.decode())
224-
raise ValueError("Could not find TOSA version in CompileSpec")
225-
226-
227220
def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
228221
for spec in compile_spec:
229222
if spec.key == "debug_artifact_path":

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def is_node_supported(
179179
exir_ops.edge.aten.eq.Scalar,
180180
exir_ops.edge.aten.erf.default,
181181
exir_ops.edge.aten.exp.default,
182+
exir_ops.edge.aten.expm1.default,
182183
exir_ops.edge.aten.log.default,
183184
exir_ops.edge.aten.linear.default,
184185
exir_ops.edge.aten.split_with_sizes_copy.default,

backends/arm/process_node.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from executorch.backends.arm.operators.node_visitor import NodeVisitor
1515
from executorch.backends.arm.tosa_mapping import TosaArg
1616
from executorch.backends.arm.tosa_specification import TosaSpecification
17-
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
17+
from executorch.backends.arm.tosa_utils import tosa_shape
1818
from torch._export.utils import (
1919
get_buffer,
2020
get_lifted_tensor_constant,
@@ -33,7 +33,10 @@ def process_call_function(
3333
tosa_spec: TosaSpecification,
3434
):
3535
# Unpack arguments and convert
36-
inputs = getNodeArgs(node, tosa_spec)
36+
try:
37+
inputs = [TosaArg(arg, tosa_spec) for arg in node.args]
38+
except ValueError as e:
39+
raise ValueError(f"Failed processing args to op:\n{node}") from e
3740

3841
# Convert output (this node itself)
3942
try:

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
from executorch.backends.arm._passes import ArmPassManager
2121

2222
from executorch.backends.arm.quantizer import QuantizationConfig
23-
from executorch.backends.arm.tosa_specification import TosaSpecification
23+
from executorch.backends.arm.tosa_specification import get_tosa_spec, TosaSpecification
2424

2525
from .arm_quantizer_utils import is_annotated, mark_node_as_annotated
2626
from .quantization_annotator import annotate_graph
2727
from executorch.backends.arm.arm_backend import (
28-
get_tosa_spec,
2928
is_ethosu,
3029
is_vgf,
3130
) # usort: skip

0 commit comments

Comments
 (0)