
Commit cea5f03

Merge branch 'main' into jz/fix-rpath-apple
2 parents ec0d1f1 + 0d0769a commit cea5f03

Note: this is a large commit; only a subset of the changed files is shown below.

50 files changed: +1445, -956 lines

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass  # noqa
 from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass  # noqa
+from .convert_elu_params import ConvertELUParamsPass  # noqa
 from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass  # noqa
 from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass  # noqa
 from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass  # noqa
@@ -36,6 +37,7 @@
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_cumsum_pass import DecomposeCumsumPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
+from .decompose_elu_pass import DecomposeEluPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa # noqa
 from .decompose_expm1_pass import DecomposeExpm1Pass  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,7 @@
     ComputeConstantOpsAOT,
     Conv1dUnsqueezePass,
     ConvertAnyDefaultDimDimsPass,
+    ConvertELUParamsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
     ConvertInt64ConstOpsToInt32Pass,
@@ -41,6 +42,7 @@
     DecomposeCosineSimilarityPass,
     DecomposeCumsumPass,
     DecomposeDivPass,
+    DecomposeEluPass,
     DecomposeEmbeddingPass,
     DecomposeExpm1Pass,
     DecomposeGeluPass,
@@ -135,6 +137,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
+        self.add_pass(ConvertELUParamsPass())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
@@ -183,6 +186,8 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeAtanPass())
         self.add_pass(DecomposeAtanhPass())
         self.add_pass(DecomposeAddmmPass())
+        self.add_pass(DecomposeEluPass())
+        self.add_pass(DecomposeExpm1Pass())
         self.add_pass(ConvertIntPowToMuls())
         self.add_pass(CastBoolToInt8Pass())
         self.add_pass(DecomposeSinhPass())
backends/arm/_passes/convert_elu_params.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ConvertELUParamsPass(ExportPass):
+    """
+    Pass to convert the input_scale kwarg of ELU operator from float to
+    int.
+
+    It has been set to 2 as the outputs seem to stay the same regardless of what
+    the value of input_scale is, as long as that value is not 1.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified_graph = False
+        graph = graph_module.graph
+        node_list = graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.elu.default
+        )
+        for node in node_list:
+            with graph.inserting_after(node):
+                replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
+                old_args = list(node.args)
+
+                alpha = old_args[1] if len(old_args) > 1 else 1.0
+                scale = 1.0
+                input_scale = 2.0
+
+                replace_node.args = (old_args[0],)
+
+                updated_kwargs = dict(node.kwargs)
+                updated_kwargs["alpha"] = int(alpha)
+                updated_kwargs["scale"] = int(scale)
+                updated_kwargs["input_scale"] = int(input_scale)
+
+                replace_node.kwargs = updated_kwargs
+
+                node.replace_all_uses_with(replace_node)
+                graph.erase_node(node)
+
+                modified_graph = True
+        if modified_graph:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified_graph)
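
For reference, a minimal numerical sketch (not part of this commit) of the ATen ELU signature this pass rewrites. It assumes the standard ATen formula elu(x, alpha, scale, input_scale) = scale * where(x > 0, x, alpha * (exp(x * input_scale) - 1)):

import torch

x = torch.linspace(-3, 3, steps=7)
alpha, scale, input_scale = 1.0, 1.0, 2.0

# Manual reference for aten.elu's three-parameter form.
reference = scale * torch.where(x > 0, x, alpha * torch.expm1(x * input_scale))
actual = torch.ops.aten.elu(x, alpha, scale, input_scale)

assert torch.allclose(actual, reference)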
backends/arm/_passes/decompose_elu_pass.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+edge_elu_ops = (exir_ops.edge.aten.elu.default,)
+
+
+def get_elu_decomposition(op) -> tuple:
+    """
+    Returns the decomposition of the given aten.elu operation into
+    its equivalent TOSA-supported operations
+
+    This handles both edge dialect ops and core PyTorch ops. The decomposition strategy
+    is:
+        elu(x, y) → where(greater_or_eq(x, 0), (exp(x)-1), x)
+
+    Returns:
+        A tuple (expm1_op, ge_op, where_op, mul_op) corresponding to the appropriate operator
+        overloads for the input op.
+
+    Raises:
+        RuntimeError: If the provided operator is not a supported elu variant.
+    """
+
+    if op in edge_elu_ops:
+        return (
+            exir_ops.edge.aten.expm1.default,
+            exir_ops.edge.aten.ge.Scalar,
+            exir_ops.edge.aten.where.self,
+            exir_ops.edge.aten.mul.Scalar,
+        )
+
+    raise RuntimeError(f"Can't get elu decomposition for op {op}")
+
+
+class DecomposeEluPass(ArmPass):
+    """
+    A transformation pass that decomposes unsupported 'aten.elu' operations
+    into a combination of supported TOSA-equivalent operations.
+
+    Since TOSA does not provide a native ELU operator, this pass rewrites:
+        elu(x) → where(greater_or_eq(x, 0), (alpha*(exp(x)-1)), x)
+
+    Supported input ops:
+    - exir_ops.edge.aten.elu.Tensor(x)
+
+    These are replaced with:
+    - exir_ops.edge.aten.expm1.default
+    - exir_ops.edge.aten.ge.Scalar
+    - exir_ops.edge.aten.where.self
+    - exir_ops.edge.aten.mul.Scalar
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in edge_elu_ops:
+            return super().call_operator(op, args, kwargs, meta, updated=False)
+
+        (
+            expm1_op,
+            ge_op,
+            where_op,
+            mul_op,
+        ) = get_elu_decomposition(op)
+
+        input = args[0]
+        alpha = args[1] if len(args) > 1 else 1.0
+
+        if alpha == 0:
+            relu_op = exir_ops.edge.aten.relu.default
+            return super().call_operator(relu_op, (input,), {}, meta, updated=True)
+
+        expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True)
+        mul_node = super().call_operator(
+            mul_op, (expm1_node, alpha), {}, meta, updated=True
+        )
+        ge_node = super().call_operator(ge_op, (input, 0.0), {}, meta, updated=True)
+        where_node = super().call_operator(
+            where_op, (ge_node, input, mul_node), {}, meta, updated=True
+        )
+
+        return where_node
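
A quick numerical check (illustrative only, not part of the commit) that the decomposition emitted above, where(x >= 0, x, alpha * expm1(x)), matches PyTorch's reference ELU:

import torch

x = torch.randn(1024)
alpha = 0.5

# Decomposition produced by DecomposeEluPass: ge -> expm1 -> mul -> where.
decomposed = torch.where(x >= 0, x, alpha * torch.expm1(x))
reference = torch.nn.functional.elu(x, alpha=alpha)

assert torch.allclose(decomposed, reference)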

backends/arm/_passes/insert_table_ops.py

Lines changed: 6 additions & 0 deletions
@@ -59,6 +59,7 @@ class TableOps:
     special_table_ops: Set[EdgeOpOverload] = {
         exir_ops.edge.aten.pow.Tensor_Scalar,
         exir_ops.edge.aten.gelu.default,
+        exir_ops.edge.aten.elu.default,
     }

     def __init__(self, exported_program: ExportedProgram):
@@ -92,6 +93,11 @@ def __getitem__(self, node: Node):
                 return lambda x: torch.nn.functional.gelu(
                     x, approximate=approximate
                 ).flatten()
+            case exir_ops.edge.aten.elu.default:
+                input_alpha = cast(int, node.kwargs["alpha"])
+                return lambda x: torch.nn.functional.elu(
+                    x, alpha=input_alpha
+                ).flatten()
             case _:
                 # Op must be handled if it's inside self.special_ops
                 raise AssertionError("Unhandled table operation")
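
Illustrative sketch (not from the diff) of the callable registered above: it applies ELU with the node's integer alpha and flattens the result, mirroring the existing gelu entry. The int8-range sample input here is only an assumption about how the table machinery evaluates the function.

import torch

input_alpha = 1  # ConvertELUParamsPass has already cast alpha to int at this point
table_fn = lambda x: torch.nn.functional.elu(x, alpha=input_alpha).flatten()

# Example evaluation over a dequantized int8 value range (assumed for illustration).
x = torch.arange(-128, 128, dtype=torch.float32)
print(table_fn(x).shape)  # torch.Size([256])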

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 14 additions & 2 deletions
@@ -12,7 +12,6 @@
     get_first_fake_tensor,
     is_param_node,
 )
-from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -43,6 +42,19 @@ def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()

+    @staticmethod
+    def _is_consumer_node_depthwise_conv2d(node: torch.fx.Node):
+        consumer_node = list(node.users)[0]
+        if consumer_node.target == exir_ops.edge.aten.convolution.default:
+            consumer_node_inputs = consumer_node.all_input_nodes
+            groups = consumer_node.args[-1]
+            in_channels = consumer_node_inputs[0].meta["val"].shape[1]
+            out_channels = consumer_node_inputs[1].meta["val"].shape[0]
+            if (in_channels == groups) and (out_channels % in_channels) == 0:
+                return True
+
+        return False
+
     def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         """
         returns True for w in the following sequence;
@@ -53,7 +65,7 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         consumer_node = list(node.users)[0]
         if self.is_weight_node_for_depthwise_conv2d(consumer_node):
             return True
-        if is_consumer_node_depthwise_conv2d(node):
+        if self._is_consumer_node_depthwise_conv2d(node):
            # Check that node is the weight-argument and not input or bias
            return consumer_node.args[1] == node

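
For context, a small sketch (not part of the commit) of the condition the new _is_consumer_node_depthwise_conv2d helper encodes: a convolution is treated as depthwise when groups equals the input channel count and the output channel count is a multiple of it.

import torch

in_channels, channel_multiplier = 8, 2
dw = torch.nn.Conv2d(
    in_channels,
    in_channels * channel_multiplier,  # out_channels % in_channels == 0
    kernel_size=3,
    groups=in_channels,  # groups == in_channels marks the conv as depthwise
)
assert dw.groups == dw.in_channels and dw.out_channels % dw.in_channels == 0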

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
@@ -263,6 +263,7 @@ def is_node_supported(
             exir_ops.edge.aten.glu.default,
             exir_ops.edge.aten.logit.default,
             exir_ops.edge.aten.acos.default,
+            exir_ops.edge.aten.elu.default,
         ]

         return supported

backends/arm/operators/op_add.py

Lines changed: 7 additions & 3 deletions
@@ -53,10 +53,9 @@ def define_node(
             [ts.DType.INT8, ts.DType.INT32],
             output.tosa_spec,
         )
-
         scale_back = 1.0
         if inputs[0].dtype == ts.DType.INT8:
-            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
                 tosa_graph, inputs, node, self.tosa_spec
             )
         else:
@@ -85,7 +84,12 @@ def define_node(
         # Scale output back to 8 bit
         # pyre-ignore
         tqutils.insert_rescale_op_to_int8(
-            tosa_graph, add_output, scale_back, node, self.tosa_spec
+            tosa_graph,
+            add_output,
+            scale_back,
+            node,
+            compute_rescale=False,
+            tosa_spec=self.tosa_spec,
         )  # type: ignore[possibly-undefined]


backends/arm/operators/op_sub.py

Lines changed: 7 additions & 2 deletions
@@ -56,7 +56,7 @@ def define_node(

         scale_back = 1.0
         if inputs[0].dtype == ts.DType.INT8:
-            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
                 tosa_graph, inputs, node, self.tosa_spec
             )
         else:
@@ -86,7 +86,12 @@ def define_node(
         # Scale output back to 8 bit
         # pyre-ignore
         tqutils.insert_rescale_op_to_int8(
-            tosa_graph, sub_output, scale_back, node, self.tosa_spec
+            tosa_graph,
+            sub_output,
+            scale_back,
+            node,
+            compute_rescale=False,
+            tosa_spec=self.tosa_spec,
        )  # type: ignore[possibly-undefined]

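
Both op_add.py and op_sub.py above follow the same int8 scheme: rescale the int8 operands into int32, accumulate, then rescale the result back to the output's int8 scale. The following self-contained sketch illustrates that general pattern only; the helper names and scale handling are simplifications and do not reflect the tqutils implementation:

import torch

def quantize(x, scale):
    # Symmetric int8 quantization with zero-point 0, for illustration only.
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

a = torch.tensor([0.5, -1.25])
b = torch.tensor([0.75, 0.25])
scale_a, scale_b, scale_out = 0.02, 0.03, 0.04

qa, qb = quantize(a, scale_a), quantize(b, scale_b)

# Bring both operands to a shared scale in int32 before the add.
common_scale = max(scale_a, scale_b)
ia = torch.round(qa.to(torch.int32) * (scale_a / common_scale)).to(torch.int32)
ib = torch.round(qb.to(torch.int32) * (scale_b / common_scale)).to(torch.int32)

acc = ia + ib  # int32 accumulation avoids int8 overflow

# Rescale the int32 result back to the output's int8 quantization scale.
out_q = torch.clamp(torch.round(acc * (common_scale / scale_out)), -128, 127).to(torch.int8)
print(out_q.to(torch.float32) * scale_out)  # approximately a + b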

backends/arm/quantizer/quantization_annotator.py

Lines changed: 5 additions & 4 deletions
@@ -266,6 +266,7 @@ def _match_pattern(
    torch.ops.aten.erf.default,
    torch.ops.aten.exp.default,
    torch.ops.aten.expm1.default,
+   torch.ops.aten.elu.default,
    torch.ops.aten.floor.default,
    torch.ops.aten.log.default,
    torch.ops.aten.reciprocal.default,
@@ -472,6 +473,10 @@ def any_or_hardtanh_min_zero(n: Node):
        ]
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    elif node.target in (
+       torch.ops.aten.add.Tensor,
+       torch.ops.aten.add_.Tensor,
+       torch.ops.aten.sub.Tensor,
+       torch.ops.aten.sub_.Tensor,
        torch.ops.aten.matmul.default,
        torch.ops.aten.mm.default,
        torch.ops.aten.bmm.default,
@@ -484,10 +489,6 @@ def any_or_hardtanh_min_zero(n: Node):
        ]
        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
    elif node.target in (
-       torch.ops.aten.add.Tensor,
-       torch.ops.aten.add_.Tensor,
-       torch.ops.aten.sub.Tensor,
-       torch.ops.aten.sub_.Tensor,
        torch.ops.aten.minimum.default,
        torch.ops.aten.maximum.default,
    ):
