Skip to content

Commit b5e7a75

Browse files
authored
Merge branch 'main' into docs/jhelsby/new-contributor-guide-update
2 parents fbb6145 + c38e71f commit b5e7a75

38 files changed

+629
-245
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .decompose_select import DecomposeSelectPass # noqa
2828
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
2929
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
30+
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
3031
from .decompose_var_pass import DecomposeVarPass # noqa
3132
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
3233
FoldAndAnnotateQParamsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
DecomposeSelectPass,
3333
DecomposeSoftmaxPass,
3434
DecomposeSoftmaxUnstablePass,
35+
DecomposeSqrtPass,
3536
DecomposeVarPass,
3637
FoldAndAnnotateQParamsPass,
3738
FuseBatchnorm2DPass,
@@ -115,6 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115116
return self._transform(exported_program.graph_module)
116117

117118
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
119+
self.add_pass(DecomposeSqrtPass())
118120
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
119121
self.add_pass(FuseQuantizedActivationPass())
120122
self.add_pass(RemoveGetItemPass())
@@ -181,6 +183,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
181183
self.add_pass(DecomposeMeanDimPass())
182184
self.add_pass(DecomposeDivPass())
183185
self.add_pass(DecomposeLeakyReLUPass())
186+
self.add_pass(DecomposeSqrtPass())
184187

185188
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
186189
# Numerically stable softmax uses amax which is not supported on Ethos-U55

backends/arm/_passes/cast_int64_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from torch._export.utils import is_buffer
1313

1414
logger = logging.getLogger(__name__)
15-
logger.setLevel(logging.WARNING)
1615

1716

1817
class CastInt64BuffersToInt32Pass(ExportPass):
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
edge_sqrt_ops = (exir_ops.edge.aten.sqrt.default,)
12+
aten_sqrt_ops = (
13+
torch.ops.aten.sqrt.default,
14+
torch.ops.aten.sqrt_.default,
15+
)
16+
17+
18+
def get_sqrt_decomposition(op) -> tuple:
19+
# TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor"
20+
if op in edge_sqrt_ops:
21+
return exir_ops.edge.aten.pow.Tensor_Scalar
22+
if op in aten_sqrt_ops:
23+
return torch.ops.aten.pow.Tensor_Scalar
24+
raise RuntimeError(f"Can't get sqrt decomposition for op {op}")
25+
26+
27+
class DecomposeSqrtPass(ExportPass):
28+
29+
def call_operator(self, op, args, kwargs, meta):
30+
"""
31+
Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support.
32+
"""
33+
34+
if op not in (edge_sqrt_ops + aten_sqrt_ops):
35+
return super().call_operator(op, args, kwargs, meta)
36+
37+
pow_op = get_sqrt_decomposition(op)
38+
39+
return super().call_operator(pow_op, (args[0], 0.5), {}, meta)

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(self, exported_program):
4848
exir_ops.edge.aten.bitwise_right_shift.Tensor,
4949
exir_ops.edge.aten.bitwise_left_shift.Tensor,
5050
exir_ops.edge.aten.eq.Tensor,
51+
exir_ops.edge.aten.gt.Tensor,
52+
exir_ops.edge.aten.lt.Tensor,
5153
exir_ops.edge.aten.pow.Tensor_Tensor,
5254
exir_ops.edge.aten.where.self,
5355
]

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@
2626
exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor,
2727
exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
2828
exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
29+
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
30+
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
2931
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3032
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
3133
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
3234
torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
3335
torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor,
3436
torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
3537
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
38+
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
39+
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
3640
}
3741

3842

backends/arm/arm_backend.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,13 @@
1111
# JIT compiler flows.
1212
#
1313

14-
import logging
15-
1614
from typing import List, Optional
1715

1816
from executorch.backends.arm.tosa_specification import TosaSpecification
1917

2018
from executorch.exir.backend.compile_spec_schema import CompileSpec
2119

2220

23-
logger = logging.getLogger(__name__)
24-
logger.setLevel(logging.WARNING)
25-
26-
2721
class ArmCompileSpecBuilder:
2822
def __init__(self):
2923
self.compile_spec: List[CompileSpec] = []

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,10 @@ class EthosU55NotSupported(OperatorSupportBase):
135135
exir_ops.edge.aten.eq.Scalar,
136136
exir_ops.edge.aten.ge.Tensor,
137137
exir_ops.edge.aten.gt.Tensor,
138+
exir_ops.edge.aten.gt.Scalar,
138139
exir_ops.edge.aten.le.Tensor,
139140
exir_ops.edge.aten.lt.Tensor,
141+
exir_ops.edge.aten.lt.Scalar,
140142
exir_ops.edge.aten.flip.default, # REVERSE
141143
exir_ops.edge.aten.grid_sampler_2d, # GATHER
142144
exir_ops.edge.aten.scatter.src,

backends/arm/operator_support/right_shift_support.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818

1919
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.WARNING)
2120

2221

2322
@register_tosa_support_check

backends/arm/operator_support/slice_copy_support.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from executorch.exir.dialects._ops import ops as exir_ops
1717

1818
logger = logging.getLogger(__name__)
19-
logger.setLevel(logging.WARNING)
2019

2120

2221
@register_tosa_support_check

0 commit comments

Comments
 (0)