
Commit 1d35df2

robell, zingo, and oscarandersson8218 authored
Arm Backend: improve non-persistent placeholder and bool handling (#15992)
Rework the bool handling pass to make it more generic:

* Improve the pass to handle type promotion in cases where one operand is not bool.
* Rename the bool promotion pass.
* Add tests for the promote-bool-operands pass.

Also handle non-persistent buffer placeholders.

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

---------

Signed-off-by: Rob Elliott <[email protected]>
Co-authored-by: Zingo Andersen <[email protected]>
Co-authored-by: Oscar Andersson <[email protected]>
1 parent 3f34709 commit 1d35df2

File tree

6 files changed: +218 -76 lines changed


backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,6 @@
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa
 from .annotate_output_dim_order_pass import AnnotateOutputDimOrderPass # noqa
 from .broadcast_args_pass import BroadcastArgsPass # noqa
-from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa
 from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa
 from .cast_to_int32_pass import CastToInt32Pass # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa
@@ -101,6 +100,7 @@
 from .match_arg_dtype_pass import MatchArgDtypePass # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
+from .promote_bool_operands_pass import PromoteBoolOperandsPass # noqa
 from .remove_getitem_pass import RemoveGetItemPass # noqa
 from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa
 from .remove_noop_pass import RemoveNoopPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,6 @@
     AnnotateDecomposedMatmulPass,
     AnnotateOutputDimOrderPass,
     BroadcastArgsPass,
-    CastBoolToInt8Pass,
     CastInt64BuffersToInt32Pass,
     CastToInt32Pass,
     ComputeConstantOpsAOTPass,
@@ -93,6 +92,7 @@
     InsertTableOpsPass,
     MatchArgDtypePass,
     MatchArgRanksPass,
+    PromoteBoolOperandsPass,
     QuantizeClampArgumentsPass,
     RemoveGetItemPass,
     RemoveGraphAssertsPass,
@@ -218,7 +218,7 @@ def _tosa_pipeline(
         DecomposeEluPass(),
         DecomposeExpm1Pass(),
         DecomposeIntPowPass(),
-        CastBoolToInt8Pass(),
+        PromoteBoolOperandsPass(),
         DecomposeSinhPass(),
         DecomposeSignPass(),
         DecomposeFloorDividePass(),
@@ -330,7 +330,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         DecomposeScaledDotProductAttentionPass(),
         DecomposeRoundPass(),
         DecomposeLogitPass(),
-        CastBoolToInt8Pass(),
+        PromoteBoolOperandsPass(),
         DecomposeSignPass(),
         DecomposeAddmmPass(),
         DecomposeRemainderPass(),

backends/arm/_passes/cast_bool_to_int8_pass.py

Lines changed: 0 additions & 63 deletions
This file was deleted.
backends/arm/_passes/promote_bool_operands_pass.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool inputs.
+# When a targeted op receives boolean tensors, we promote them to an integer type before
+# invocation and cast the result back to the expected dtype afterwards.
+
+from typing import Set, Type
+
+import torch
+
+from executorch.backends.arm._passes.arm_pass import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+class PromoteBoolOperandsPass(ArmPass):
+    """Promote boolean operands to the appropriate integer dtype for unsupported ops."""
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    targeted_ops = {
+        exir_ops.edge.aten.bitwise_and.Tensor,
+        exir_ops.edge.aten.bitwise_or.Tensor,
+        exir_ops.edge.aten.bitwise_xor.Tensor,
+        exir_ops.edge.aten.mul.Tensor,
+    }
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in self.targeted_ops:
+            return super().call_operator(op, args, kwargs, meta)
+
+        original_dtypes = [arg.data.dtype for arg in args]
+        if torch.bool not in original_dtypes:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # select the first non-bool dtype, or None if all bool
+        promoted_dtype = next((dt for dt in original_dtypes if dt != torch.bool), None)
+
+        # if we don't have a dtype specified by the op, promote to default choice for the op
+        if promoted_dtype is None:
+            if op == exir_ops.edge.aten.mul.Tensor:
+                # mul as int32
+                promoted_dtype = torch.int32
+            else:
+                # bitwise ops can be int8
+                promoted_dtype = torch.int8
+
+        target_dtypes = []
+        for dt in original_dtypes:
+            if dt == torch.bool:
+                target_dtypes.append(promoted_dtype)
+            else:
+                target_dtypes.append(dt)
+
+        new_args = []
+        for arg, original_dtype, target_dtype in zip(
+            args, original_dtypes, target_dtypes
+        ):
+            if original_dtype == target_dtype:
+                new_args.append(arg)
+            else:
+                new_args.append(
+                    super().call_operator(
+                        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                        (arg,),
+                        {"dtype": target_dtype},
+                        meta,
+                    )
+                )
+
+        output = super().call_operator(
+            op,
+            tuple(new_args),
+            kwargs,
+            meta,
+        )
+
+        if all(dtype == torch.bool for dtype in original_dtypes):
+            output = super().call_operator(
+                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                (output,),
+                {"dtype": torch.bool},
+                meta,
+            )
+        return output
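For reference, here is a minimal eager-mode sketch, not part of the commit, of the rewrite this pass performs on its targeted ops. The helper name is made up for illustration, and plain .to() casts stand in for the _to_dim_order_copy nodes the pass actually inserts in the graph:

import torch

def promote_bool_operands_eager(op, lhs, rhs):
    # Rough eager-mode mirror of PromoteBoolOperandsPass's dtype logic.
    dtypes = [lhs.dtype, rhs.dtype]
    if torch.bool not in dtypes:
        return op(lhs, rhs)

    # Borrow the first non-bool operand dtype; otherwise fall back to the
    # per-op default (int32 for mul, int8 for the bitwise ops).
    promoted = next((dt for dt in dtypes if dt != torch.bool), None)
    if promoted is None:
        promoted = torch.int32 if op is torch.mul else torch.int8

    lhs = lhs.to(promoted) if lhs.dtype == torch.bool else lhs
    rhs = rhs.to(promoted) if rhs.dtype == torch.bool else rhs
    out = op(lhs, rhs)

    # Cast back to bool only when every original operand was bool.
    if all(dt == torch.bool for dt in dtypes):
        out = out.to(torch.bool)
    return out

a = torch.tensor([True, False, True])
b = torch.tensor([True, True, False])
print(promote_bool_operands_eager(torch.bitwise_and, a, b))  # bool output
print(promote_bool_operands_eager(torch.mul, a, torch.tensor([1, 2, 3], dtype=torch.int32)))  # int32 output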

backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py

Lines changed: 23 additions & 9 deletions
@@ -11,6 +11,7 @@
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import is_buffer, is_param
+from torch.export.graph_signature import InputKind
 
 
 class UnsqueezeScalarPlaceholdersPass(ArmPass):
@@ -42,17 +43,30 @@ def call(self, graph_module: torch.fx.GraphModule):
             else:
                 continue
 
-            tensor = self.exported_program.state_dict[name]
+            tensor = self.exported_program.state_dict.get(name)
 
+            # If we have a persistent=False buffer with no entry in state_dict
+            spec = next(
+                s
+                for s in self.exported_program.graph_signature.input_specs
+                if getattr(s.arg, "name", None) == node.name
+            )
+            is_non_persistent_buffer = (
+                spec.kind is InputKind.BUFFER and spec.persistent is False
+            )
+            if tensor is None and is_non_persistent_buffer:
+                fake = node.meta["val"]
+                tensor = torch.ones_like(fake)
+
+            # If we have a scalar, unsqueeze it
             if tensor.dim() == 0:
-                self.exported_program.state_dict[name] = tensor.unsqueeze(0)
-                node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
-                    tensor.unsqueeze(0), static_shapes=True
-                )
-            else:
-                node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
-                    tensor, static_shapes=True
-                )
+                tensor = tensor.unsqueeze(0)
+
+            # update or create entry in state_dict, recreate fake
+            self.exported_program.state_dict[name] = tensor
+            node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
+                tensor, static_shapes=True
+            )
 
         graph_module.recompile()
         graph_module = super().call(graph_module).graph_module
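For context, a small standalone sketch, not part of the commit, of why the pass now uses state_dict.get(): a buffer registered with persistent=False still appears as a BUFFER placeholder in the exported program's graph signature, but it has no state_dict entry. The module and buffer names below are hypothetical:

import torch
from torch.export.graph_signature import InputKind

class Offset(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        # A scalar buffer registered with persistent=False is kept out of the
        # exported program's state_dict.
        self.register_buffer("shift", torch.tensor(1.0), persistent=False)

    def forward(self, x):
        return x + self.shift

ep = torch.export.export(Offset(), (torch.randn(3),))
print("shift" in ep.state_dict)  # False: no state_dict entry for the buffer
buffer_specs = [
    s for s in ep.graph_signature.input_specs if s.kind is InputKind.BUFFER
]
print([(s.target, s.persistent) for s in buffer_specs])  # expected: [('shift', False)]
# The pass detects this combination (missing state_dict entry plus a
# non-persistent BUFFER spec) and materializes a tensor from the
# placeholder's fake value before unsqueezing scalars.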
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import ClassVar, Dict, Tuple
+
+import torch
+from executorch.backends.arm._passes import PromoteBoolOperandsPass
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+from executorch.backends.test.harness.stages import StageType
+from executorch.exir.dialects._ops import ops as exir_ops
+
+tensor_pair_t = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _collect_cast_dtypes(pipeline: PassPipeline[tensor_pair_t]) -> list[torch.dtype]:
+    exported_program = pipeline.tester.get_artifact(
+        StageType.RUN_PASSES
+    ).exported_program()
+    graph_module = exported_program.graph_module
+    cast_dtypes: list[torch.dtype] = []
+    for node in graph_module.graph.nodes:
+        if (
+            node.op == "call_function"
+            and node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default
+            and "dtype" in node.kwargs
+        ):
+            cast_dtypes.append(node.kwargs["dtype"])
+    return cast_dtypes
+
+
+class BoolBitwiseAndModule(torch.nn.Module):
+    test_data: ClassVar[Dict[str, tensor_pair_t]] = {
+        "bool_tensors": (
+            torch.tensor([[True, False], [False, True]], dtype=torch.bool),
+            torch.tensor([[False, True], [True, False]], dtype=torch.bool),
+        )
+    }
+
+    def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
+        return torch.bitwise_and(lhs, rhs)
+
+
+class MixedMulModule(torch.nn.Module):
+    test_data: ClassVar[Dict[str, tensor_pair_t]] = {
+        "mixed_tensors": (
+            torch.tensor([True, False, True, False], dtype=torch.bool),
+            torch.tensor([1, 2, 3, 4], dtype=torch.int32),
+        )
+    }
+
+    def forward(self, lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
+        return torch.mul(lhs, rhs)
+
+
+@common.parametrize("test_data", BoolBitwiseAndModule.test_data)
+def test_promote_bool_operands_all_bool(test_data: tensor_pair_t) -> None:
+    module = BoolBitwiseAndModule()
+    ops_before_pass = {
+        "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor": 1,
+    }
+    ops_after_pass = {
+        "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
+    }
+    pipeline = PassPipeline[tensor_pair_t](
+        module,
+        test_data,
+        quantize=False,
+        ops_before_pass=ops_before_pass,
+        ops_after_pass=ops_after_pass,
+        pass_list=[PromoteBoolOperandsPass],
+    )
+    pipeline.run()
+    cast_dtypes = _collect_cast_dtypes(pipeline)
+    assert cast_dtypes.count(torch.int8) == 2
+    assert cast_dtypes.count(torch.bool) == 1
+
+
+@common.parametrize("test_data", MixedMulModule.test_data)
+def test_promote_bool_operands_mixed_types(test_data: tensor_pair_t) -> None:
+    module = MixedMulModule()
+    ops_before_pass = {
+        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
+    }
+    ops_after_pass = {
+        "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
+    }
+    pipeline = PassPipeline[tensor_pair_t](
+        module,
+        test_data,
+        quantize=False,
+        ops_before_pass=ops_before_pass,
+        ops_after_pass=ops_after_pass,
+        pass_list=[PromoteBoolOperandsPass],
+    )
+    pipeline.run()
+    cast_dtypes = _collect_cast_dtypes(pipeline)
+    assert cast_dtypes.count(torch.int32) == 1
