Skip to content

Commit c11e1bd

Browse files
committed
Update
[ghstack-poisoned]
2 parents abef683 + f673a4b commit c11e1bd

File tree

75 files changed

+1785
-510
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+1785
-510
lines changed

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ exclude_patterns = [
7676
'examples/demo-apps/apple_ios/**',
7777
'examples/demo-apps/react-native/rnllama/ios/**',
7878
'extension/apple/**',
79+
'extension/llm/apple/**',
7980
# File contains @generated
8081
'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h',
8182
'extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h',

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2323
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2424
from .convert_to_clamp import ConvertToClampPass # noqa
25+
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
2526
from .decompose_atan_pass import DecomposeAtanPass # noqa
2627
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
2728
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ConvertSplitToSlicePass,
2626
ConvertSqueezesToViewPass,
2727
ConvertToClampPass,
28+
DecomposeAcoshPass,
2829
DecomposeAtanPass,
2930
DecomposeAvgPool2d,
3031
DecomposeBatchNormNoStatsPass,
@@ -151,6 +152,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
151152

152153
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
153154
self.add_pass(DecomposeRoundPass())
155+
self.add_pass(DecomposeAcoshPass())
154156
self.add_pass(DecomposeSqrtPass())
155157
self.add_pass(DecomposeAtanPass())
156158
self.add_pass(ConvertIntPowToMuls())
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from executorch.backends.arm._passes import ArmPass
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
11+
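# Edge-dialect acosh overload that DecomposeAcoshPass rewrites; the pass is
# registered in the TOSA MI pipeline (_tosa_080_MI_pipeline in arm_pass_manager.py).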
12+
edge_acosh_op = exir_ops.edge.aten.acosh.default
13+
14+
15+
class DecomposeAcoshPass(ArmPass):
    """
    Decomposes acosh into sub, add, mul, sqrt and log edge operators.

    The rewrite follows the mathematical identity:
        acosh(x) = log(x + sqrt((x - 1) * (x + 1)))
    """

    def call_operator(self, op, args, kwargs, meta, updated=False):
        # Bound parent implementation, used both for the pass-through case and
        # for emitting every node of the decomposition.
        emit = super().call_operator

        # Anything that is not the edge acosh overload is forwarded untouched.
        if op is not edge_acosh_op:
            return emit(op, args, kwargs, meta, updated)

        # Resolve all replacement operators up front.
        aten = exir_ops.edge.aten
        log_op = aten.log.default
        sqrt_op = aten.sqrt.default
        mul_op = aten.mul.Tensor
        sub_scalar_op = aten.sub.Scalar
        add_tensor_op = aten.add.Tensor
        add_scalar_op = aten.add.Scalar

        operand = args[0]

        # (x - 1) * (x + 1)
        x_minus_one = emit(sub_scalar_op, (operand, 1.0), {}, meta, True)
        x_plus_one = emit(add_scalar_op, (operand, 1.0), {}, meta, True)
        radicand = emit(mul_op, (x_minus_one, x_plus_one), {}, meta, True)

        # sqrt((x - 1) * (x + 1))
        root = emit(sqrt_op, (radicand,), {}, meta, True)

        # x + sqrt((x - 1) * (x + 1))
        log_argument = emit(add_tensor_op, (operand, root), {}, meta, True)

        # log(x + sqrt((x - 1) * (x + 1)))
        return emit(log_op, (log_argument,), {}, meta, True)

backends/arm/_passes/decompose_sqrt_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
from typing import Tuple, Union
8+
79
import torch
810
from executorch.exir.dialects._ops import ops as exir_ops
911
from executorch.exir.pass_base import ExportPass
@@ -15,7 +17,7 @@
1517
)
1618

1719

18-
def get_sqrt_decomposition(op) -> tuple:
20+
def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]:
1921
# TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor"
2022
if op in edge_sqrt_ops:
2123
return exir_ops.edge.aten.pow.Tensor_Scalar

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class TableOps:
5555
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
5656
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
5757
exir_ops.edge.aten.sinh.default: torch.sinh,
58+
exir_ops.edge.aten.acosh.default: torch.acosh,
5859
}
5960

6061
# Targets that must be treated explicitly

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self, exported_program):
5151
exir_ops.edge.aten.gt.Tensor,
5252
exir_ops.edge.aten.ge.Tensor,
5353
exir_ops.edge.aten.lt.Tensor,
54+
exir_ops.edge.aten.le.Tensor,
5455
exir_ops.edge.aten.pow.Tensor_Tensor,
5556
exir_ops.edge.aten.where.self,
5657
]

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# pyre-unsafe
77

88

9-
from typing import Dict
9+
from typing import Dict, Union
1010

1111
import torch
1212
from executorch.backends.transforms.replace_scalar_with_tensor import (
@@ -18,7 +18,10 @@
1818

1919

2020
# Operators that are included for both TOSA profiles
21-
_common_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
21+
_common_ops: Dict[
22+
Union[EdgeOpOverload, torch._ops.OpOverload],
23+
Union[EdgeOpOverload, torch._ops.OpOverload],
24+
] = {
2225
exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
2326
exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
2427
exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
@@ -29,6 +32,7 @@
2932
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
3033
exir_ops.edge.aten.ge.Scalar: exir_ops.edge.aten.ge.Tensor,
3134
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
35+
exir_ops.edge.aten.le.Scalar: exir_ops.edge.aten.le.Tensor,
3236
exir_ops.edge.aten.ne.Scalar: exir_ops.edge.aten.ne.Tensor,
3337
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3438
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
@@ -40,6 +44,7 @@
4044
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
4145
torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
4246
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
47+
torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
4348
torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
4449
}
4550

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class EthosU55NotSupported(OperatorSupportBase):
138138
exir_ops.edge.aten.gt.Tensor,
139139
exir_ops.edge.aten.gt.Scalar,
140140
exir_ops.edge.aten.le.Tensor,
141+
exir_ops.edge.aten.le.Scalar,
141142
exir_ops.edge.aten.lt.Tensor,
142143
exir_ops.edge.aten.lt.Scalar,
143144
exir_ops.edge.aten.ne.Tensor,
@@ -174,6 +175,69 @@ def is_node_supported(
174175
shape_t = list[int]
175176

176177

178+
class EthosU55ViewCheck(OperatorSupportBase):
    """Operator-support check for ``view_copy`` nodes when targeting Ethos-U55.

    Every node that is not ``aten.view_copy.default`` is accepted. View nodes
    are rejected, with a reason reported through ``reporter``, when their
    output rank/dtype or the product of their axes is outside the limits
    checked in ``is_node_supported``.
    """

    # Largest product of axes accepted for int8/int16 views (inclusive bound).
    _MAX_AXES_PRODUCT = 65536

    def __init__(self, reporter: WhyNoPartitionReporter):
        super().__init__()
        self.reporter = reporter

    def axes_product(self, nhwc_shape: shape_t) -> int:
        """Return the product of all dimensions in ``nhwc_shape`` (1 if empty)."""
        product = 1
        for dim in nhwc_shape:
            product *= dim
        return product

    # TODO: Extend this check to comply with u55 restrictions
    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """
        Check whether a given view node is supported on U55.

        Currently only checks dtypes and product of axes.

        It is not the view operator itself that is not supported on U55. In order for the
        view operator to be compatible with the channels-last format of TosaBackend,
        transposes may need to be inserted before and after the view op. If that happens
        and that transpose operator does not adhere to the limitations then it will
        result in the following error:

            CPU performance estimation for "Transpose" not implemented.
            ...
            CPU operations are not supported for GraphAPI input

        Args:
            submodules: Unused by this check.
            node: The FX node representing the view_copy operator.

        Returns:
            False if the operator is not supported and True if it is supported.
        """
        if node.target != exir_ops.edge.aten.view_copy.default:
            return True

        # Output shape of the view. The rejection messages report this instead
        # of node.args[1], which is the requested size and not a permutation.
        shape = list(get_first_fake_tensor(node).shape)
        dtype = _try_determine_dtype(node)

        rank = len(shape)
        if rank > 4 and dtype == torch.int32:
            self.reporter.report_reject(
                node, f"No support for {shape=} in int32. Rank must be <=4"
            )
            return False

        if dtype in (torch.int8, torch.int16):
            if self.axes_product(shape) > self._MAX_AXES_PRODUCT:
                # The bound is inclusive, so the message states <= rather than <.
                self.reporter.report_reject(
                    node,
                    f"No support for {shape=}, {dtype=}. "
                    f"Product of axes must be <={self._MAX_AXES_PRODUCT}",
                )
                return False

        return True
239+
240+
177241
class EthosU55TransposeCheck(OperatorSupportBase):
178242

179243
def __init__(self, reporter: WhyNoPartitionReporter):

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
EthosU55DtypeSupport,
2424
EthosU55NotSupported,
2525
EthosU55TransposeCheck,
26+
EthosU55ViewCheck,
2627
)
2728
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
2829
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -133,6 +134,7 @@ def tosa_support_factory(
133134
negative_checks.append(EthosU55NotSupported(reporter))
134135
negative_checks.append(EthosU55DtypeSupport(reporter))
135136
negative_checks.append(EthosU55TransposeCheck(reporter))
137+
negative_checks.append(EthosU55ViewCheck(reporter))
136138

137139
return chain(
138140
reporter.wrap_check(
@@ -187,6 +189,7 @@ def is_node_supported(
187189
exir_ops.edge.aten.gt.Tensor,
188190
exir_ops.edge.aten.gt.Scalar,
189191
exir_ops.edge.aten.le.Tensor,
192+
exir_ops.edge.aten.le.Scalar,
190193
exir_ops.edge.aten.lt.Tensor,
191194
exir_ops.edge.aten.lt.Scalar,
192195
exir_ops.edge.aten.mul.Tensor,
@@ -245,6 +248,7 @@ def is_node_supported(
245248
exir_ops.edge.aten.alias_copy.default,
246249
exir_ops.edge.aten.sinh.default,
247250
exir_ops.edge.aten.atan.default,
251+
exir_ops.edge.aten.acosh.default,
248252
]
249253

250254
return supported

0 commit comments

Comments
 (0)