Skip to content

Commit 39203cf

Browse files
Arm backend: Add support for bool->fp32 cast for INT+FP profile (pytorch#16363)
- Add RewriteBoolToFp32CastViaInt8Pass to rewrite unsupported bool->fp32 to_dim_order_copy into bool->int8->fp32 when both INT and FP TOSA profiles are enabled
- Extend to_dim_order_copy operator support to include bool->fp32 for the INT+FP profile

Signed-off-by: Yufeng Shi <[email protected]>
1 parent 479791d commit 39203cf

File tree

5 files changed

+125
-0
lines changed

5 files changed

+125
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@
113113
from .replace_scalar_with_tensor_pass import ( # noqa
114114
ReplaceScalarWithTensorByProfilePass,
115115
)
116+
from .rewrite_bool_to_fp32_cast_via_int8_pass import ( # noqa
117+
RewriteBoolToFp32CastViaInt8Pass,
118+
)
116119
from .rewrite_conv_pass import RewriteConvPass # noqa
117120
from .rewrite_matmul import RewriteMatmulPass # noqa
118121
from .rewrite_upsample import RewriteUpsamplePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
RemoveNoopPass,
104104
ReplaceInfAndLimitValuesPass,
105105
ReplaceScalarWithTensorByProfilePass,
106+
RewriteBoolToFp32CastViaInt8Pass,
106107
RewriteConvPass,
107108
RewriteMatmulPass,
108109
RewriteUpsamplePass,
@@ -221,6 +222,7 @@ def _tosa_pipeline(
221222
self.add_passes(
222223
[
223224
FuseQuantizedActivationPass(),
225+
RewriteBoolToFp32CastViaInt8Pass(),
224226
ConvertToClampPass(),
225227
DecomposeTOSAUnsupportedClampPass(),
226228
DecomposeGroupNormPass(),
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
10+
from executorch.backends.arm._passes.arm_pass import ArmPass
11+
from executorch.backends.arm._passes.arm_pass_utils import (
12+
create_node,
13+
get_first_fake_tensor,
14+
set_node_arg,
15+
)
16+
from executorch.backends.arm.tosa.specification import get_context_spec
17+
from executorch.exir.dialects._ops import ops as exir_ops
18+
from executorch.exir.pass_base import ExportPass, PassResult
19+
20+
21+
class RewriteBoolToFp32CastViaInt8Pass(ArmPass):
    """
    Legalizes unsupported bool->fp32 to_dim_order_copy casts for the Arm TOSA
    backend when both integer and float TOSA profiles are enabled.

    For the combined INT+FP profile, this pass rewrites a single bool->fp32 cast
    into a bool->int8 cast followed by an int8->fp32 cast, so that each cast
    is individually supported by the TOSA INT and FP profiles. For other
    profiles (INT-only or FP-only) the pass is a no-op.
    """

    # No other pass is required to run after this one.
    _passes_required_after: Set[Type[ExportPass]] = set()

    # The only op this pass rewrites: the edge-dialect dim-order copy (cast).
    targeted_ops = {exir_ops.edge.dim_order_ops._to_dim_order_copy.default}

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite each bool->fp32 targeted cast into bool->int8->fp32.

        Args:
            graph_module: The FX graph module to (potentially) rewrite in place.

        Returns:
            PassResult wrapping the (possibly retraced) graph module and a
            flag indicating whether any rewrite was performed.
        """
        modified = False

        # Only act when BOTH profiles are enabled; otherwise the split cast
        # would not help, so leave the graph untouched (no-op).
        tosa_spec = get_context_spec()
        if not (tosa_spec.support_integer() and tosa_spec.support_float()):
            return PassResult(graph_module, modified)

        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue

            # Only rewrite casts whose input is bool ...
            input_node = node.all_input_nodes[0]
            input_dtype = get_first_fake_tensor(input_node).dtype
            if input_dtype != torch.bool:
                continue

            # ... and whose output is fp32.
            output_dtype = get_first_fake_tensor(node).dtype
            if output_dtype != torch.float32:
                continue

            # Retarget the existing cast to produce int8 instead of fp32.
            set_node_arg(node, "dtype", torch.int8)

            # Snapshot current users BEFORE inserting the new cast, so the
            # int8->fp32 cast we create below is not rewired onto itself.
            users = list(node.users)
            with graph.inserting_after(node):
                # Second stage: int8 -> fp32, using the same copy op target.
                cast_after = create_node(
                    graph,
                    node.target,
                    args=(node,),
                    kwargs={
                        "dtype": torch.float32,
                    },
                )
            # Point all original consumers at the new fp32 result.
            for user in users:
                user.replace_input_with(node, cast_after)
            modified = True

        if modified:
            graph_module.recompile()
            # Re-run the parent pass machinery; presumably this retraces the
            # graph so node fake-tensor metadata reflects the new int8 stage
            # — TODO(review): confirm ArmPass/ExportPass retrace semantics.
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

backends/arm/operator_support/to_dim_order_copy_support.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ def _merge_supported_types(
117117
torch.float32,
118118
],
119119
}
120+
SUPPORTED_INT_FP_PROFILE_DTYPES: SupportedTypeDict = {
121+
torch.bool: [torch.float32],
122+
}
120123

121124
def is_node_tosa_supported(
122125
self, node: fx.Node, tosa_spec: TosaSpecification
@@ -137,6 +140,10 @@ def is_node_tosa_supported(
137140
supported_dtypes = self._merge_supported_types(
138141
self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes
139142
)
143+
if tosa_spec.support_integer() and tosa_spec.support_float():
144+
supported_dtypes = self._merge_supported_types(
145+
self.SUPPORTED_INT_FP_PROFILE_DTYPES, supported_dtypes
146+
)
140147

141148
if len(node.all_input_nodes) != 1:
142149
self.reporter.report_reject(

backends/arm/test/ops/test_to_copy.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,39 @@ def test_to_u55_INT(test_data: Tuple):
258258
non_delegated_ops={}, # These are removed outside of the Arm backend so the graph is empty
259259
)
260260
pipeline.run()
261+
262+
263+
_TO_COPY_TEST_DATA_INT_FP = {
264+
"bool_fp32": lambda: (
265+
torch.tensor([True, False], dtype=torch.bool),
266+
torch.float32,
267+
),
268+
}
269+
270+
271+
@common.parametrize("test_data", _TO_COPY_TEST_DATA_INT_FP)
272+
@common.SkipIfNoModelConverter
273+
def test_to_vgf_no_quant_bool_fp32(test_data: Tuple):
274+
test_tensor, new_dtype = test_data()
275+
pipeline = VgfPipeline[input_t1](
276+
Cast(new_dtype),
277+
(test_tensor,),
278+
aten_op=[],
279+
exir_op=[],
280+
quantize=False,
281+
)
282+
pipeline.run()
283+
284+
285+
@common.parametrize("test_data", _TO_COPY_TEST_DATA_INT_FP)
286+
@common.SkipIfNoModelConverter
287+
def test_to_vgf_quant_bool_fp32(test_data: Tuple):
288+
test_tensor, new_dtype = test_data()
289+
pipeline = VgfPipeline[input_t1](
290+
Cast(new_dtype),
291+
(test_tensor,),
292+
aten_op=[],
293+
exir_op=[],
294+
quantize=True,
295+
)
296+
pipeline.run()

0 commit comments

Comments
 (0)