
Commit a19c42b

YufengShi-dudu, Erik-Lundell, and zingo authored
Arm backend: Add passes to handle int64 const and int64 output ops (#13803)
- Add ConvertInt64ConstOpsToInt32Pass to convert constant-producing ops that output int64 to instead output int32, when values are within int32 bounds. Supported Ops: `torch.full`, `torch.arange`, `torch.eye`, `torch.linspace`, `torch.tensor`
- Add ConvertInt64OutputOpsToInt32Pass to:
  1. convert or remove unnecessary casts to int64
  2. insert an int64->int32 cast after the argmax nodes that produce int64 outputs

Change-Id: I04e5fa9a7170c5b5dc785ae8619189545de0ec2c

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

Signed-off-by: Yufeng Shi <[email protected]>
Co-authored-by: Erik Lundell <[email protected]>
Co-authored-by: Zingo Andersen <[email protected]>
1 parent a3a8691 commit a19c42b

11 files changed: +940 / -27 lines

backends/arm/README.md

Lines changed: 33 additions & 1 deletion
```diff
@@ -209,6 +209,38 @@ List of model specific and optional passes:
 - InsertCastForOpsWithInt64InputPass
   - Functionality:
     - For LLMs such as Llama, some operators like aten.embedding have int64 inputs. In order to lower these operators to TOSA, this pass inserts a casting node that converts the input from int64 to int32.
-  - Example usage: backends/arm/test/models/test_llama.py
   - Supported Ops:
     - aten.embedding.default, aten.slice_copy.Tensor
+  - Example usage:
+    - backends/arm/test/models/test_llama.py
+
+- ConvertInt64ConstOpsToInt32Pass
+  - Functionalities:
+    - Rewrites constant-producing ops that output int64 to instead output int32, when the values fit within int32 bounds.
+  - Supported Ops:
+    - `torch.full`, `torch.arange`, `torch.eye`, `torch.linspace`, `torch.tensor`
+  - Example usage:
+    - backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+    - backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+
+- ConvertInt64OutputOpsToInt32Pass
+  - Overview:
+    - Rewrites or removes operations that produce int64 outputs, converting them to int32 where possible.
+    - Overflow checks are applied selectively; for ops without such checks, users must ensure values fit within the int32 range.
+  - Functionalities:
+    1. Handling casts to int64:
+       - (1) int32 -> int64:
+         - Removes the cast and redirects all uses of the int64 result to the original int32 value.
+       - (2) other types -> int64:
+         - Rewrites the cast to produce int32 instead of int64.
+       - Supported Ops:
+         - torch.ops.aten.to.\[dtype|dtype_layout\]
+         - exir_ops.edge.dim_order_ops._to_dim_order_copy.default
+    2. Post-processing argmax outputs:
+       - Inserts an int64->int32 cast after argmax operations that produce int64 outputs.
+       - Supported Ops:
+         - torch.ops.aten.argmax.default
+         - exir_ops.edge.aten.argmax.default
+  - Example usage:
+    - (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+    - (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
```
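Besides the pass-manager wiring shown further down, the new passes can also be enabled per-test through `ArmTester`'s `transform_passes` argument, as the stable-diffusion tests in this commit do. Below is a minimal sketch of that usage; the toy module and the `ArmTester` import path are assumptions made for illustration, while the `transform_passes` and `compile_spec` arguments mirror the test_CLIPTextModelWithProjection.py diff in this commit.

```python
# Minimal sketch: wiring the new int64-handling passes into an ArmTester run.
# Assumptions: the ToyArgmaxModel and the ArmTester import path; the transform_passes
# and compile_spec usage mirror test_CLIPTextModelWithProjection.py in this commit.
import torch

from executorch.backends.arm._passes import (
    ConvertInt64ConstOpsToInt32Pass,
    ConvertInt64OutputOpsToInt32Pass,
    InsertCastForOpsWithInt64InputPass,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester  # assumed path


class ToyArgmaxModel(torch.nn.Module):
    """Hypothetical model whose argmax output is int64 before the passes run."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.argmax(x, dim=-1)


model = ToyArgmaxModel()
example_inputs = (torch.randn(1, 16),)

(
    ArmTester(
        model,
        example_inputs=example_inputs,
        compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
        transform_passes=[
            InsertCastForOpsWithInt64InputPass(),
            ConvertInt64ConstOpsToInt32Pass(),
            ConvertInt64OutputOpsToInt32Pass(),
        ],
    )
    .export()
    .to_edge_transform_and_lower()
)
```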

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,8 @@
 from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass  # noqa
 from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass  # noqa
 from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass  # noqa
+from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass  # noqa
+from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass  # noqa
 from .convert_int_pow_to_mul import ConvertIntPowToMuls  # noqa
 from .convert_minmax_pass import ConvertMinMaxPass  # noqa
 from .convert_split_to_slice import ConvertSplitToSlicePass  # noqa
```

backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -20,6 +20,8 @@
     ConvertAnyDefaultDimDimsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
+    ConvertInt64ConstOpsToInt32Pass,
+    ConvertInt64OutputOpsToInt32Pass,
     ConvertIntPowToMuls,
     ConvertMinMaxPass,
     ConvertMmToBmmPass,
@@ -98,6 +100,7 @@
 from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
+from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
 from torch.fx import GraphModule


@@ -258,6 +261,11 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )

     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(
+            RemoveGraphAssertsPass()
+        )  # ConvertInt64ConstOpsToInt32Pass requires this pass to remove assert ops from the graph
+        self.add_pass(ConvertInt64ConstOpsToInt32Pass())
+        self.add_pass(ConvertInt64OutputOpsToInt32Pass())
         self.add_pass(InsertCastForOpsWithInt64InputPass())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
```
backends/arm/_passes/convert_int64_const_ops_to_int32.py (new file)

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging

import torch
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.exir.pass_base import ExportPass, PassResult


logger = logging.getLogger(__name__)
INT32_MIN = torch.iinfo(torch.int32).min
INT32_MAX = torch.iinfo(torch.int32).max


class ConvertInt64ConstOpsToInt32Pass(ExportPass):
    """
    Rewrite constant ops that produce int64 to int32 where safe.

    List of supported operators:
    1. `torch.full`
    2. `torch.arange`
    3. `torch.eye`
    4. `torch.linspace`
    5. `torch.tensor`
    """

    torch_ops = [
        torch.ops.aten.full.default,
        torch.ops.aten.arange.default,
        torch.ops.aten.arange.start,
        torch.ops.aten.arange.start_step,
        torch.ops.aten.eye.default,
        torch.ops.aten.linspace.default,
    ]

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in ComputeConstantOpsAOT.targeted_ops + self.torch_ops:
                continue

            data = node.target(*node.args, **node.kwargs)
            if data.dtype is not torch.int64:
                continue

            min_val, max_val = torch.min(data), torch.max(data)
            if INT32_MIN <= min_val and max_val <= INT32_MAX:
                logger.warning(
                    f"Casting {node.name} from torch.int64 to torch.int32"
                    f" defined in {node.meta.get('stack_trace','[no stack trace found]')}"
                )
                node.update_kwarg("dtype", torch.int32)
                modified = True
            else:
                logger.warning(
                    f"[{node.name}] has values: min={min_val}, max={max_val}, which exceed the int32 range "
                    f"([{INT32_MIN}, {INT32_MAX}]); not converting dtype to int32."
                )

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)
```
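For intuition, here is a small, self-contained sketch (not part of the commit; the concrete values are illustrative assumptions) of the bounds check the pass applies before downcasting: the constant the node would produce is materialized, its min/max are compared against `torch.iinfo(torch.int32)`, and only then is the dtype rewritten.

```python
# Standalone sketch of the int32-bounds check in ConvertInt64ConstOpsToInt32Pass.
# The tensor below is an illustrative stand-in for the value a constant op would produce.
import torch

INT32_MIN = torch.iinfo(torch.int32).min
INT32_MAX = torch.iinfo(torch.int32).max

data = torch.arange(0, 128)  # torch.arange defaults to int64 for integer arguments
assert data.dtype == torch.int64

min_val, max_val = torch.min(data), torch.max(data)
if INT32_MIN <= min_val and max_val <= INT32_MAX:
    # In the pass this is node.update_kwarg("dtype", torch.int32), so the constant
    # op is re-emitted producing int32; eagerly, that amounts to a downcast.
    data = data.to(torch.int32)

assert data.dtype == torch.int32
```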
backends/arm/_passes/convert_int64_output_ops_to_int32.py (new file)

Lines changed: 153 additions & 0 deletions

@@ -0,0 +1,153 @@
```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
    set_node_arg,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


logger = logging.getLogger(__name__)


class ConvertInt64OutputOpsToInt32Pass(ExportPass):
    """
    Rewrites or removes operations that produce int64 outputs, converting them
    to int32 where possible.

    Currently, this pass handles casting and argmax operators:
    1. int32 -> int64:
       removes the cast and redirects all uses to the original int32 value.
    2. other types -> int64:
       rewrites the cast to produce int32 instead of int64.
    3. torch.argmax():
       inserts an int64 -> int32 cast after the argmax node.

    Future extensions may include operators that return int64 outputs by default
    (e.g., `argmin`), rewriting them or inserting an int64 -> int32 cast to yield
    int32 results.

    Note: Overflow checks are applied selectively in this pass. For operators without
    such checks, it is the user's responsibility to ensure that values fit within
    the int32 range.
    """

    aten_cast_ops = (
        torch.ops.aten.to.dtype,
        torch.ops.aten.to.dtype_layout,
    )
    edge_cast_ops = (exir_ops.edge.dim_order_ops._to_dim_order_copy.default,)

    aten_argmax_ops = (torch.ops.aten.argmax.default,)
    edge_argmax_ops = (exir_ops.edge.aten.argmax.default,)

    aten_ops = aten_cast_ops + aten_argmax_ops
    edge_ops = edge_cast_ops + edge_argmax_ops

    # dtype is specified in args
    cast_ops_args = (
        torch.ops.aten.to.dtype,  # to_2: node.args: (gt, torch.int64) node.kwargs: {}
    )
    # dtype is specified in kwargs
    cast_ops_kwargs = (
        torch.ops.aten.to.dtype_layout,  # to_1: node.args: (unsqueeze,) node.kwargs: {'dtype': torch.int64, 'layout': torch.strided, 'device': device(type='cpu')}
        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,  # node.args: (aten_gt_scalar,) node.kwargs: {'dtype': torch.int64, 'dim_order': [0, 1]}
    )

    def _get_decomposition(self, op):
        if op in self.edge_ops:
            return exir_ops.edge.aten._to_copy.default

        if op in self.aten_ops:
            return torch.ops.aten._to_copy.default

        raise RuntimeError(
            f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
        )

    def _convert_casting_operators(self, node: torch.fx.Node):
        input_node = node.all_input_nodes[0]
        input_dtype = get_first_fake_tensor(input_node).dtype
        # Case 1: int32 -> int64 - remove the cast
        if input_dtype == torch.int32:
            users = [user for user in node.users if node != user]
            for user in users:
                logger.warning(
                    f"Removing int32->int64 casting node {node.name} defined in"
                    f" {node.meta.get('stack_trace','[no stack trace found]')}"
                )
                user.replace_input_with(node, input_node)
        # Case 2: other types -> int64 - rewrite the cast to target int32
        else:
            if node.target in self.cast_ops_kwargs:
                set_node_arg(node, "dtype", torch.int32)
            elif node.target in self.cast_ops_args:
                set_node_arg(node, 1, torch.int32)
            else:
                raise RuntimeError(f"Unexpected target {node.target} in {node.name}")
            output_dtype = get_first_fake_tensor(node).dtype
            logger.warning(
                f"Converting casting node {node.name} from {input_dtype}->{output_dtype} to"
                f" {input_dtype}->torch.int32 defined in {node.meta.get('stack_trace','[no stack trace found]')}"
            )

    def _convert_argmax_operators(self, node: torch.fx.Node, graph: torch.fx.Graph):
        output_tensor = node
        to_copy_op = self._get_decomposition(node.target)
        with graph.inserting_after(node):
            cast_after = create_node(
                graph,
                to_copy_op,
                args=(output_tensor,),
                kwargs={
                    "dtype": torch.int32,
                },
            )
            users = [user for user in node.users if user != cast_after]
            for user in users:
                user.replace_input_with(output_tensor, cast_after)
        logger.warning(
            f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 output"
            f" to int32 for {node.name} defined in {node.meta.get('stack_trace','[no stack trace found]')}"
        )

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        graph = graph_module.graph
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target not in self.aten_ops + self.edge_ops:
                continue
            output_dtype = get_first_fake_tensor(node).dtype
            if output_dtype != torch.int64:
                continue

            if node.target in self.aten_cast_ops + self.edge_cast_ops:
                self._convert_casting_operators(node)
            elif node.target in self.aten_argmax_ops + self.edge_argmax_ops:
                # TODO: Add range check based on the input tensor shape before casting the output
                self._convert_argmax_operators(node, graph)
            else:
                raise RuntimeError(f"Unexpected target {node.target} in {node.name}")

            modified = True

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)
```
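The eager-mode equivalent of the rewrites this pass performs can be sketched as follows (illustrative only, not part of the commit): an int32 -> int64 cast is simply dropped, and an argmax result is downcast to int32, which is safe whenever the reduced dimension's size fits in int32.

```python
# Illustrative eager-mode sketch of ConvertInt64OutputOpsToInt32Pass's effect.
import torch

# Cast handling: a cast from int32 to int64 is removed; the int32 value is reused directly.
x_int32 = torch.ones(4, dtype=torch.int32)
y = x_int32.to(torch.int64)  # the pass removes this cast and redirects users to x_int32
assert torch.equal(y.to(torch.int32), x_int32)

# Argmax post-processing: aten.argmax always returns int64 indices.
scores = torch.randn(2, 10)
idx = torch.argmax(scores, dim=-1)
assert idx.dtype == torch.int64
# The pass inserts a _to_copy(dtype=torch.int32) node after argmax; eagerly that is:
idx_int32 = idx.to(torch.int32)
assert idx_int32.dtype == torch.int32  # indices are bounded by the dim size (10), so no overflow
```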

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 14 additions & 8 deletions
```diff
@@ -7,7 +7,11 @@
 import unittest

 import torch
-from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass
+from executorch.backends.arm._passes import (
+    ConvertInt64ConstOpsToInt32Pass,
+    ConvertInt64OutputOpsToInt32Pass,
+    InsertCastForOpsWithInt64InputPass,
+)

 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
@@ -28,13 +32,11 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
     # for that is some assert ops are removed by passes in the
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
-        "executorch_exir_dialects_edge__ops_aten__to_copy_default": 3,
+        "executorch_exir_dialects_edge__ops_aten__to_copy_default": 4,
         "executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_index_Tensor": 1,
-        "executorch_exir_dialects_edge__ops_aten_lt_Tensor": 1,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
+        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
-        "torch.ops.higher_order.executorch_call_delegate": 3,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
     }

     def _prepare_inputs(
@@ -60,15 +62,19 @@ def prepare_model_and_inputs(self):

         return text_encoder_model, text_encoder_model_inputs

-    def test_CLIPTextModelWithProjection_tosa_MI(self):
+    def test_CLIPTextModelWithProjection_tosa_FP(self):
         text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
         with torch.no_grad():
             (
                 ArmTester(
                     text_encoder_model,
                     example_inputs=text_encoder_model_inputs,
                     compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
-                    transform_passes=[InsertCastForOpsWithInt64InputPass()],
+                    transform_passes=[
+                        InsertCastForOpsWithInt64InputPass(),
+                        ConvertInt64ConstOpsToInt32Pass(),
+                        ConvertInt64OutputOpsToInt32Pass(),
+                    ],
                 )
                 .export()
                 .to_edge_transform_and_lower()
```
