Skip to content

Commit 9930b36

Browse files
Merge branch 'main' into experimental/clone_support
2 parents 4e03115 + e38734e commit 9930b36

37 files changed

+1058
-283
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from .decompose_int16_activation_conv2d_pass import ( # noqa
5353
DecomposeConv2dWithInt16ActivationPass,
5454
)
55+
from .decompose_int32_clamp_pass import DecomposeInt32ClampPass # noqa
5556
from .decompose_int_pow_pass import DecomposeIntPowPass # noqa
5657
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
5758
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
DecomposeGluPass,
5656
DecomposeGroupedConvPass,
5757
DecomposeGroupNormPass,
58+
DecomposeInt32ClampPass,
5859
DecomposeIntPowPass,
5960
DecomposeLayerNormPass,
6061
DecomposeLeakyReLUPass,
@@ -122,7 +123,6 @@
122123

123124

124125
class ArmPassManager(PassManager):
125-
126126
def __init__(self, tosa_spec: TosaSpecification) -> None:
127127
self.tosa_spec = tosa_spec
128128
super().__init__()
@@ -174,6 +174,7 @@ def _tosa_pipeline(
174174
FuseQuantizedActivationPass(),
175175
RemoveGetItemPass(),
176176
ConvertToClampPass(),
177+
DecomposeInt32ClampPass(),
177178
DecomposeGroupNormPass(),
178179
DecomposeLayerNormPass(),
179180
DecomposeBatchNormNoStatsPass(),
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass
12+
13+
14+
class DecomposeInt32ClampPass(ArmPass):
15+
"""Rewrite int32 clamp into min/max chain since TOSA lacks int32 clamp support."""
16+
17+
_passes_required_after: Set[Type[ExportPass]] = set()
18+
_supported_ops = {
19+
exir_ops.edge.aten.clamp.default,
20+
torch.ops.aten.clamp.default,
21+
}
22+
23+
def _ensure_tensor(
24+
self,
25+
value,
26+
ref_tensor,
27+
dtype,
28+
rank,
29+
meta,
30+
):
31+
if value is None:
32+
return None
33+
return super().call_operator(
34+
exir_ops.edge.aten.full.default,
35+
((1,) * rank, value),
36+
{"dtype": dtype},
37+
meta,
38+
updated=True,
39+
)
40+
41+
def call_operator(self, op, args, kwargs, meta):
42+
val = meta["val"]
43+
if op not in self._supported_ops or val.dtype != torch.int32:
44+
return super().call_operator(op, args, kwargs, meta)
45+
46+
input_tensor = args[0]
47+
min_arg = args[1] if len(args) > 1 else None
48+
max_arg = args[2] if len(args) > 2 else None
49+
dtype = val.dtype
50+
rank = len(val.shape)
51+
52+
min_arg = self._ensure_tensor(min_arg, input_tensor, dtype, rank, meta)
53+
max_arg = self._ensure_tensor(max_arg, input_tensor, dtype, rank, meta)
54+
55+
current = input_tensor
56+
if max_arg is not None:
57+
current = super().call_operator(
58+
exir_ops.edge.aten.minimum.default,
59+
(current, max_arg),
60+
{},
61+
meta,
62+
updated=True,
63+
)
64+
if min_arg is not None:
65+
current = super().call_operator(
66+
exir_ops.edge.aten.maximum.default,
67+
(current, min_arg),
68+
{},
69+
meta,
70+
updated=True,
71+
)
72+
return current

backends/arm/operators/node_visitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _serialize_operator(
8787
None: Mutates ``tosa_graph`` in place.
8888
8989
"""
90-
op_location = ts.TosaOpLocation()
90+
op_location = None
9191
if self.debug_hook:
9292
debug_info = self.debug_hook.add(
9393
node,
@@ -96,7 +96,7 @@ def _serialize_operator(
9696
)
9797

9898
if self.debug_hook.mode == ArmCompileSpec.DebugMode.TOSA:
99-
op_location.text = json.dumps(debug_info.to_dict())
99+
op_location = json.dumps(debug_info.to_dict())
100100

101101
tosa_graph.addOperator(
102102
tosa_op,

backends/arm/operators/op_clamp.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(self, *args):
4040
def _get_min_max_arguments(
4141
self, node: Node, dtype: torch.dtype
4242
) -> Tuple[int | float, int | float]:
43-
4443
def cast_type(value: Any) -> int | float:
4544
if isinstance(value, int):
4645
return value
@@ -91,7 +90,12 @@ def define_node(
9190
validate_valid_dtype(
9291
self.target,
9392
[inputs[0], output],
94-
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP16, ts.DType.FP32],
93+
[
94+
ts.DType.INT8,
95+
ts.DType.INT16,
96+
ts.DType.FP16,
97+
ts.DType.FP32,
98+
],
9599
output.tosa_spec,
96100
)
97101

0 commit comments

Comments
 (0)