
Commit 87ac448

Merge branch 'main' into Meta_fusing_conv1d_relu_residual_add

2 parents cad772c + 71a7806


42 files changed (+1985, −806 lines)

backends/arm/README.md

Lines changed: 33 additions & 1 deletion
@@ -209,6 +209,38 @@ List of model specific and optional passes:
 - InsertCastForOpsWithInt64InputPass
   - Functionality:
     - For LLMs such as Llama, some operators like aten.embedding have int64 inputs. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
-  - Example usage: backends/arm/test/models/test_llama.py
   - Supported Ops:
     - aten.embedding.default, aten.slice_copy.Tensor
+  - Example usage:
+    - backends/arm/test/models/test_llama.py
+
+- ConvertInt64ConstOpsToInt32Pass
+  - Functionalities:
+    - Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
+  - Supported Ops:
+    - `torch.full`, `torch.arange`, `torch.eye`, `torch.linspace`, `torch.tensor`
+  - Example usage:
+    - backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+    - backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+
+- ConvertInt64OutputOpsToInt32Pass
+  - Overview:
+    - Rewrites or removes operations that produce int64 outputs, converting them to int32 where possible.
+    - Overflow checks are applied selectively; for ops without such checks, users need to ensure values fit within the int32 range.
+  - Functionalities:
+    1. Handling casts to int64:
+       - (1) int32 -> int64:
+         - Removes the cast and redirects uses of the int64 value to the int32 value
+       - (2) other types -> int64:
+         - Rewrites the cast to other types -> int32
+       - Supported Ops:
+         - torch.ops.aten.to.\[dtype|dtype_layout\]
+         - exir_ops.edge.dim_order_ops._to_dim_order_copy.default
+    2. Post-processing argmax outputs:
+       - Inserts an int64 -> int32 cast after argmax operations that produce int64 outputs
+       - Supported Ops:
+         - torch.ops.aten.argmax.default
+         - exir_ops.edge.aten.argmax.default
+  - Example usage:
+    - (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+    - (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
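
The behaviours listed above can be pictured on a toy module. A minimal sketch (the module and shapes are hypothetical, not taken from the test files referenced above):

import torch

class Int64Toy(torch.nn.Module):
    def forward(self, x):
        # "other type -> int64" cast: ConvertInt64OutputOpsToInt32Pass rewrites
        # this to cast to int32 instead.
        mask = (x > 0).to(torch.int64)
        # argmax always yields int64: an int64 -> int32 cast is inserted after it.
        idx = torch.argmax(x, dim=-1)
        return mask, idx

After the passes run, no int64 values remain on these paths, so the graph can be lowered to TOSA.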

backends/arm/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -14,8 +14,11 @@
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass  # noqa
 from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass  # noqa
+from .convert_elu_params import ConvertELUParamsPass  # noqa
 from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass  # noqa
 from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass  # noqa
+from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass  # noqa
+from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass  # noqa
 from .convert_int_pow_to_mul import ConvertIntPowToMuls  # noqa
 from .convert_minmax_pass import ConvertMinMaxPass  # noqa
 from .convert_split_to_slice import ConvertSplitToSlicePass  # noqa
@@ -34,6 +37,7 @@
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_cumsum_pass import DecomposeCumsumPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
+from .decompose_elu_pass import DecomposeEluPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa
 from .decompose_expm1_pass import DecomposeExpm1Pass  # noqa
 from .decompose_gelu_pass import DecomposeGeluPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 13 additions & 0 deletions
@@ -18,8 +18,11 @@
     ComputeConstantOpsAOT,
     Conv1dUnsqueezePass,
     ConvertAnyDefaultDimDimsPass,
+    ConvertELUParamsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
+    ConvertInt64ConstOpsToInt32Pass,
+    ConvertInt64OutputOpsToInt32Pass,
     ConvertIntPowToMuls,
     ConvertMinMaxPass,
     ConvertMmToBmmPass,
@@ -39,6 +42,7 @@
     DecomposeCosineSimilarityPass,
     DecomposeCumsumPass,
     DecomposeDivPass,
+    DecomposeEluPass,
    DecomposeEmbeddingPass,
     DecomposeExpm1Pass,
     DecomposeGeluPass,
@@ -98,6 +102,7 @@
 from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
+from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
 from torch.fx import GraphModule


@@ -132,6 +137,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
+        self.add_pass(ConvertELUParamsPass())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
@@ -180,6 +186,8 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(DecomposeAtanPass())
         self.add_pass(DecomposeAtanhPass())
         self.add_pass(DecomposeAddmmPass())
+        self.add_pass(DecomposeEluPass())
+        self.add_pass(DecomposeExpm1Pass())
         self.add_pass(ConvertIntPowToMuls())
         self.add_pass(CastBoolToInt8Pass())
         self.add_pass(DecomposeSinhPass())
@@ -258,6 +266,11 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )

     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(
+            RemoveGraphAssertsPass()
+        )  # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertions from the graph
+        self.add_pass(ConvertInt64ConstOpsToInt32Pass())
+        self.add_pass(ConvertInt64OutputOpsToInt32Pass())
         self.add_pass(InsertCastForOpsWithInt64InputPass())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
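
The new sequence at the top of transform_for_annotation_pipeline carries an ordering constraint: per the comment in the hunk above, RemoveGraphAssertsPass must run before ConvertInt64ConstOpsToInt32Pass so that assert nodes guarding the int64 constants do not survive into the rewrite. A hedged sketch of driving this stage directly (the ArmPassManager constructor and the tosa_spec argument are assumptions; the commit only shows the method bodies):

# Hypothetical direct invocation of the annotation pipeline; in practice the
# Arm quantizer is expected to call this entry point.
pass_manager = ArmPassManager(tosa_spec)  # constructor signature assumed
annotated_gm = pass_manager.transform_for_annotation_pipeline(exported_gm)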

backends/arm/_passes/convert_elu_params.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ConvertELUParamsPass(ExportPass):
+    """
+    Pass to convert the input_scale kwarg of the ELU operator from float to
+    int.
+
+    input_scale has been set to 2 because the outputs seem to stay the same
+    regardless of its value, as long as that value is not 1.
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified_graph = False
+        graph = graph_module.graph
+        node_list = graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.elu.default
+        )
+        for node in node_list:
+            with graph.inserting_after(node):
+                replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
+                old_args = list(node.args)
+
+                alpha = old_args[1] if len(old_args) > 1 else 1.0
+                scale = 1.0
+                input_scale = 2.0
+
+                replace_node.args = (old_args[0],)
+
+                updated_kwargs = dict(node.kwargs)
+                updated_kwargs["alpha"] = int(alpha)
+                updated_kwargs["scale"] = int(scale)
+                updated_kwargs["input_scale"] = int(input_scale)
+
+                replace_node.kwargs = updated_kwargs
+
+                node.replace_all_uses_with(replace_node)
+                graph.erase_node(node)
+
+                modified_graph = True
+        if modified_graph:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified_graph)
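
A minimal sketch of invoking this pass on its own (here edge_gm is an assumed edge-dialect GraphModule containing exir_ops.edge.aten.elu.default nodes; in the backend the pass runs inside the TOSA INT pipeline shown earlier):

from executorch.backends.arm._passes import ConvertELUParamsPass

# `edge_gm` is hypothetical; PassResult carries the rewritten module and a
# modified flag, exactly as returned by call() above.
result = ConvertELUParamsPass().call(edge_gm)
if result.modified:
    edge_gm = result.graph_module  # elu kwargs now hold int alpha/scale/input_scale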

backends/arm/_passes/convert_int64_const_ops_to_int32.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+import logging
+
+import torch
+from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+logger = logging.getLogger(__name__)
+INT32_MIN = torch.iinfo(torch.int32).min
+INT32_MAX = torch.iinfo(torch.int32).max
+
+
+class ConvertInt64ConstOpsToInt32Pass(ExportPass):
+    """
+    Rewrite constant ops that produce int64 to int32 where safe.
+
+    List of supported operators:
+    1. `torch.full`
+    2. `torch.arange`
+    3. `torch.eye`
+    4. `torch.linspace`
+    5. `torch.tensor`
+    """
+
+    torch_ops = [
+        torch.ops.aten.full.default,
+        torch.ops.aten.arange.default,
+        torch.ops.aten.arange.start,
+        torch.ops.aten.arange.start_step,
+        torch.ops.aten.eye.default,
+        torch.ops.aten.linspace.default,
+    ]
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target not in ComputeConstantOpsAOT.targeted_ops + self.torch_ops:
+                continue
+
+            data = node.target(*node.args, **node.kwargs)
+            if data.dtype is not torch.int64:
+                continue
+
+            min_val, max_val = torch.min(data), torch.max(data)
+            if INT32_MIN <= min_val and max_val <= INT32_MAX:
+                logger.warning(
+                    f"Casting {node.name} from torch.int64 to torch.int32"
+                    f" defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+                )
+                node.update_kwarg("dtype", torch.int32)
+                modified = True
+            else:
+                logger.warning(
+                    f"[{node.name}] has values: min={min_val}, max={max_val}, which exceed the int32 range "
+                    f"([{INT32_MIN}, {INT32_MAX}]); not converting dtype to int32."
+                )
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
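
To make the range check concrete, a hedged toy example (the module is hypothetical; only the dtype behaviour follows from the pass above):

import torch

class FullToy(torch.nn.Module):
    def forward(self, x):
        # aten.full.default defaults to int64 for an integer fill value.
        ones = torch.full((4,), 1)
        return x + ones

# For this graph the pass updates the full node's dtype kwarg to torch.int32,
# since 1 lies within [INT32_MIN, INT32_MAX]. A fill value such as 2**40 would
# take the warning branch instead, and the node would keep torch.int64.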

backends/arm/_passes/convert_int64_output_ops_to_int32.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+import logging
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+    set_node_arg,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+logger = logging.getLogger(__name__)
+
+
+class ConvertInt64OutputOpsToInt32Pass(ExportPass):
+    """
+    Rewrites or removes operations that produce int64 outputs, converting them
+    to int32 where possible.
+
+    Currently, this pass handles casting and argmax operators:
+    1. int32 -> int64:
+       removes the cast and redirects all uses to the original int32 value.
+    2. other types -> int64:
+       rewrites the cast to produce int32 instead of int64.
+    3. torch.argmax():
+       inserts an int64 -> int32 cast after the argmax node.
+
+    Future extensions may include operators that return int64 outputs by default
+    (e.g., `argmin`), rewriting them or inserting an int64 -> int32 cast to yield
+    int32 results.
+
+    Note: Overflow checks are applied selectively in this pass. For operators without
+    such checks, it is the user's responsibility to ensure that values fit within
+    the int32 range.
+    """
+
+    aten_cast_ops = (
+        torch.ops.aten.to.dtype,
+        torch.ops.aten.to.dtype_layout,
+    )
+    edge_cast_ops = (exir_ops.edge.dim_order_ops._to_dim_order_copy.default,)
+
+    aten_argmax_ops = (torch.ops.aten.argmax.default,)
+    edge_argmax_ops = (exir_ops.edge.aten.argmax.default,)
+
+    aten_ops = aten_cast_ops + aten_argmax_ops
+    edge_ops = edge_cast_ops + edge_argmax_ops
+
+    # dtype is specified in args
+    cast_ops_args = (
+        torch.ops.aten.to.dtype,  # to_2: node.args: (gt, torch.int64) node.kwargs: {}
+    )
+    # dtype is specified in kwargs
+    cast_ops_kwargs = (
+        torch.ops.aten.to.dtype_layout,  # to_1: node.args: (unsqueeze,) node.kwargs: {'dtype': torch.int64, 'layout': torch.strided, 'device': device(type='cpu')}
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,  # node.args: (aten_gt_scalar,) node.kwargs: {'dtype': torch.int64, 'dim_order': [0, 1]}
+    )
+
+    def _get_decomposition(self, op):
+        if op in self.edge_ops:
+            return exir_ops.edge.aten._to_copy.default
+
+        if op in self.aten_ops:
+            return torch.ops.aten._to_copy.default
+
+        raise RuntimeError(
+            f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
+        )
+
+    def _convert_casting_operators(self, node: torch.fx.Node):
+        input_node = node.all_input_nodes[0]
+        input_dtype = get_first_fake_tensor(input_node).dtype
+        # Case 1: int32 -> int64 - remove the op
+        if input_dtype == torch.int32:
+            users = [user for user in node.users if node != user]
+            for user in users:
+                logger.warning(
+                    f"Removing int32->int64 casting node {node.name} defined in"
+                    f" {node.meta.get('stack_trace','[no stack trace found]')}"
+                )
+                user.replace_input_with(node, input_node)
+        # Case 2: other types -> int64 - rewrite to cast to int32
+        else:
+            if node.target in self.cast_ops_kwargs:
+                set_node_arg(node, "dtype", torch.int32)
+            elif node.target in self.cast_ops_args:
+                set_node_arg(node, 1, torch.int32)
+            else:
+                raise RuntimeError(f"Unexpected target {node.target} in {node.name}")
+            output_dtype = get_first_fake_tensor(node).dtype
+            logger.warning(
+                f"Converting casting node {node.name} from {input_dtype}->{output_dtype} to"
+                f" {input_dtype}->torch.int32 defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+            )
+
+    def _convert_argmax_operators(self, node: torch.fx.Node, graph: torch.fx.Graph):
+        output_tensor = node
+        to_copy_op = self._get_decomposition(node.target)
+        with graph.inserting_after(node):
+            cast_after = create_node(
+                graph,
+                to_copy_op,
+                args=(output_tensor,),
+                kwargs={
+                    "dtype": torch.int32,
+                },
+            )
+            users = [user for user in node.users if user != cast_after]
+            for user in users:
+                user.replace_input_with(output_tensor, cast_after)
+        logger.warning(
+            f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 output"
+            f" to int32 for {node.name} defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+        )
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        graph = graph_module.graph
+        for node in list(graph.nodes):
+            if node.op != "call_function":
+                continue
+            if node.target not in self.aten_ops + self.edge_ops:
+                continue
+            output_dtype = get_first_fake_tensor(node).dtype
+            if output_dtype != torch.int64:
+                continue
+
+            if node.target in self.aten_cast_ops + self.edge_cast_ops:
+                self._convert_casting_operators(node)
+            elif node.target in self.aten_argmax_ops + self.edge_argmax_ops:
+                # TODO: Add range check based on the input tensor shape before casting the output
+                self._convert_argmax_operators(node, graph)
+            else:
+                raise RuntimeError(f"Unexpected target {node.target} in {node.name}")
+
+            modified = True
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
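
Finally, a sketch of the argmax path (hypothetical module; the before/after graph shapes follow from _convert_argmax_operators above):

import torch

class ArgmaxToy(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x, dim=-1)  # aten.argmax always returns int64

# Before: argmax -> output (torch.int64)
# After:  argmax -> _to_copy(dtype=torch.int32) -> output
# Per the TODO above, no range check is applied on this path yet, so reduced
# dimensions longer than INT32_MAX remain the user's responsibility.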
