
Commit 3eb7947
Arm backend: Refactor partitioning of int64 (#13726)
The earlier logic partitioned some edge cases that the backend could not handle, and it produced small single-constant partitions. The new logic is more conservative, so some cases that could be partitioned before might not be anymore; the idea is to trade a little partitioning coverage for more robustness.

Update tests affected by the change. Also add dtype info to a debug message to aid future int64 bug squashing.

Signed-off-by: Erik Lundell <[email protected]>
1 parent dfc387b commit 3eb7947
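
In short, the new check only partitions an int64-touching node in a few safe cases. A minimal sketch of the decision rule, distilled from the patch below (the Node type here is a simplified stand-in for illustration, not the torch.fx API):

# Sketch only: a simplified stand-in for the real torch.fx-based check below.
from dataclasses import dataclass, field

@dataclass
class Node:
    is_constant_op: bool            # target in ComputeConstantOpsAOT.targeted_ops
    outputs_int64: bool             # some output tensor has dtype int64
    within_int32: bool              # constant data fits in the int32 range
    users_drop_int64: bool          # no user's output keeps dtype int64
    users_are_getitem: bool         # all users are operator.getitem
    int64_inputs: list["Node"] = field(default_factory=list)

def supported(node: Node) -> bool:
    if node.outputs_int64:
        # Constant op whose users all cast away int64: safe to cast to int32 AOT.
        if node.is_constant_op and node.users_drop_int64:
            return node.within_int32
        # Multi-output op where only the non-int64 outputs are consumed.
        if not (node.users_are_getitem and node.users_drop_int64):
            return False
    # int64 inputs must be constants that will themselves be partitioned
    # (constant placeholders are handled separately in the real check).
    return all(n.is_constant_op and supported(n) for n in node.int64_inputs)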

File tree: 8 files changed, +220 −32 lines

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 83 additions & 14 deletions

@@ -116,7 +116,7 @@ def tosa_support_factory(
 
     # Negative checks: Remove nodes from partitioning
     negative_checks: list[OperatorSupportBase] = [
-        CheckInt64Inputs(exported_program, reporter),
+        CheckInt64InputsAndOutputs(exported_program, reporter),
         CheckFloat64Inputs(exported_program, reporter),
         RankCheck(reporter, max_rank=5),
         *[
@@ -454,7 +454,18 @@ def is_node_supported(
         return True
 
 
-class CheckInt64Inputs(OperatorSupportBase):
+class CheckInt64InputsAndOutputs(OperatorSupportBase):
+    """TOSA does not support int64 tensors so in general, ops with int64 inputs or outputs should not be partitioned.
+    There are however some exceptions:
+    - Nodes with int64 output can be partitioned if they are constant, within int32,
+      and all users cast to something else. In this case, the int64 tensor can safely be cast to int32 AOT.
+    - Nodes with int64 output can be partitioned if all users are getitem with non-int64 output.
+      In this case, there are multiple outputs and the int64 ones are not used.
+    - Nodes with int64 inputs can be partitioned if the inputs are constant placeholders, or constant
+      ops fulfilling the criteria above.
+    Note that we don't check placeholders here, they are partitioned based on whether their users are partitioned
+    or not.
+    """
 
     def __init__(
         self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
@@ -465,27 +476,85 @@ def __init__(
             if spec.kind == InputKind.USER_INPUT
         ]
         self.reporter = reporter
+        self.int32_min = torch.iinfo(torch.int32).min
+        self.int32_max = torch.iinfo(torch.int32).max
         super().__init__()
 
+    def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
+        """Node is assumed to be call_function with int64 output."""
+        if isinstance(node.target, str):
+            return False
+        data = node.target(*node.args, **node.kwargs)
+        min_val, max_val = int(torch.min(data)), int(torch.max(data))
+        return min_val >= self.int32_min and max_val <= self.int32_max
+
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
 
+        vals = node.meta["val"]
+        tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]
+
+        any_int64 = any(tensor.dtype == torch.int64 for tensor in tensor_list)
+        # Don't partition nodes with int64 output...
+        if any_int64:
+            # ... Except for constant ops that are directly cast to something non-int64.
+            # This could be an explicit cast, or something like a less than that outputs a different dtype than the input.
+            users_output_non_int64 = all(
+                get_first_fake_tensor(output_node).dtype != torch.int64
+                for output_node in node.users
+            )
+            if (
+                node.target in ComputeConstantOpsAOT.targeted_ops
+                and users_output_non_int64
+            ):
+                if not self.inside_int32_bounds(node):
+                    self.reporter.report_reject(
+                        node, "Constant node outside int32 range."
+                    )
+                    return False
+                # Will never have input nodes, safe to return True
+                return True
+
+            # ... Or ops with multiple outputs where only non-int64 are used.
+            users_are_getitem = all(
+                user.target == operator.getitem for user in node.users
+            )
+            if users_are_getitem and users_output_non_int64:
+                # Passed output check, go to input check.
+                pass
+            else:
+                self.reporter.report_reject(
+                    node, "Non-constant node with int64 output."
+                )
+                return False
+
+        # Ops with int64 inputs are only partitioned if input nodes are constant and will be partitioned.
+        # If it is not partitioned, the partition will get an int64 input and fail.
         for input_node in node.all_input_nodes:
-            # We can cast constant placeholders and constant ops AOT, such int64 are ok.
-            # Otherwise, don't partition if one or more inputs are int64.
+            tensor_in = get_first_fake_tensor(input_node)
+            if tensor_in.dtype != torch.int64:
+                continue
+            # Constant placeholder
             if (
-                input_node.name in self.input_names
-                or not input_node.op == "placeholder"
+                input_node.op != "call_function"
+                and input_node.name not in self.input_names
             ):
-                tensor = get_first_fake_tensor(input_node)
-                if tensor.dtype == torch.int64:
-                    if input_node.target not in ComputeConstantOpsAOT.targeted_ops:
-                        self.reporter.report_reject(
-                            node,
-                            f"Had int64 input {input_node.name} that couldn't be handled.",
-                        )
-                        return False
+                continue
+            # Constant operator
+            if input_node.op == "call_function":
+                if input_node.target in ComputeConstantOpsAOT.targeted_ops:
+                    # This is not perfect since the input_node can still be rejected by other checks but
+                    # this should cover the majority of cases.
+                    if self.is_node_supported(
+                        None, input_node  # type: ignore[arg-type] #(we don't use 'submodules')
+                    ):
+                        continue
+            self.reporter.report_reject(
+                node, f"Non-constant int64 input {input_node.name}"
+            )
+            return False
+
         return True
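inside_int32_bounds materializes the constant by calling the node's target and range-checking the result. The same idea in standalone form (a sketch with plain tensors, no fx plumbing; the 2**40 case mirrors the overflow tests added below):

import torch

INT32_MIN = torch.iinfo(torch.int32).min
INT32_MAX = torch.iinfo(torch.int32).max

def fits_in_int32(data: torch.Tensor) -> bool:
    # Mirrors the min/max test in inside_int32_bounds.
    return int(torch.min(data)) >= INT32_MIN and int(torch.max(data)) <= INT32_MAX

print(fits_in_int32(torch.arange(0, 10, dtype=torch.int64)))              # True: castable AOT
print(fits_in_int32(torch.arange(2**40, 2**40 + 10, dtype=torch.int64)))  # False: rejected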
backends/arm/test/ops/test_int64.py

Lines changed: 116 additions & 0 deletions

@@ -0,0 +1,116 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+
+
+class ConstAdd(torch.nn.Module):
+    def __init__(self, dtype: torch.dtype, bias=0):
+        super().__init__()
+        self.dtype = dtype
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor):
+        c = torch.arange(self.bias, self.bias + 10, 1, dtype=self.dtype)
+        # Add explicit float cast to make quantization work, will be inserted by type promotion otherwise.
+        return x + c.to(torch.float32)
+
+
+class BufferAdd(torch.nn.Module):
+    def __init__(self, dtype: torch.dtype, bias=0):
+        super().__init__()
+        self.dtype = dtype
+        self.buffer = torch.arange(0, 10, 1, dtype=self.dtype) + bias
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor):
+        c = self.buffer
+        # Add explicit float cast to make quantization work, will be inserted by type promotion otherwise.
+        return x + c.to(torch.float32)
+
+
+class ConstChainAdd(torch.nn.Module):
+    def __init__(self, dtype: torch.dtype):
+        super().__init__()
+        self.dtype = dtype
+
+    def forward(self, x: torch.Tensor):
+        c = torch.arange(0, 10, 1, dtype=self.dtype).reshape((2, 5)).unsqueeze(-1)
+        # Add explicit float cast to make quantization work, will be inserted by type promotion otherwise.
+        return x + c.to(torch.float32)
+
+
+class BufferChainAdd(torch.nn.Module):
+    def __init__(self, dtype: torch.dtype):
+        super().__init__()
+        self.dtype = dtype
+        self.buffer = torch.arange(0, 10, 1, dtype=self.dtype)
+
+    def forward(self, x: torch.Tensor):
+        c = self.buffer.reshape((2, 5)).unsqueeze(-1)
+        # Add explicit float cast to make quantization work, will be inserted by type promotion otherwise.
+        return x + c.to(torch.float32)
+
+
+test_data_suite = {
+    "fp32_in+int64_buffer": (BufferAdd(torch.int64), (torch.rand(10) - 0.5,)),
+    "fp32_in+int64_buffer_overflow": (
+        BufferAdd(torch.int64, 2**40),
+        (torch.rand(10) - 0.5,),
+    ),
+    "fp32_in+int64_const": (ConstAdd(torch.int64), (torch.rand(10) - 0.5,)),
+    "fp32_in+int64_const_overflow": (
+        ConstAdd(torch.int64, 2**40),
+        (torch.rand(10) - 0.5,),
+    ),
+    "int64_in+float_const": (
+        ConstAdd(torch.float32),
+        (torch.randint(0, 10, (10,)),),
+    ),
+    "fp32_in+int64_buffer_chain": (
+        BufferChainAdd(torch.int64),
+        (torch.rand(2, 5, 3) - 0.5,),
+    ),
+    "fp32_in+int64_const_chain": (
+        ConstChainAdd(torch.int64),
+        (torch.rand(2, 5, 3) - 0.5,),
+    ),
+    "int64_in+float_const_chain": (
+        ConstChainAdd(torch.float32),
+        (torch.randint(0, 10, (2, 5, 3)),),
+    ),
+}
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_int64_tosa_FP(test_data: Tuple):
+    model, inputs = test_data
+    (
+        ArmTester(
+            model,
+            inputs,
+            common.get_tosa_compile_spec("TOSA-1.0+FP", custom_path="tosa/int64"),
+        )
+        .export()
+        .to_edge_transform_and_lower()
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs)
+    )
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_int64_tosa_INT(test_data: Tuple):
+    model, inputs = test_data
+    (
+        ArmTester(model, inputs, common.get_tosa_compile_spec("TOSA-1.0+INT"))
+        .quantize()
+        .export()
+        .to_edge_transform_and_lower()
+        .to_executorch()
+        .run_method_and_compare_outputs(inputs)
+    )
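
The explicit .to(torch.float32) casts in these modules exist because, as the comments note, type promotion would otherwise insert the cast implicitly. A quick illustration of that promotion behavior (plain torch, outside the test harness):

import torch

x = torch.rand(10)                          # float32 user input
c = torch.arange(0, 10, dtype=torch.int64)  # int64 constant

print((x + c).dtype)                    # torch.float32: promotion casts c implicitly
print((x + c.to(torch.float32)).dtype)  # torch.float32: same math, explicit cast in the graph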

backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py

Lines changed: 3 additions & 1 deletion

@@ -31,16 +31,18 @@ class TestT5EncoderModel(unittest.TestCase):
         "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
         "executorch_exir_dialects_edge__ops_aten_abs_default": 1,
         "executorch_exir_dialects_edge__ops_aten_add_Tensor": 3,
+        "executorch_exir_dialects_edge__ops_aten_arange_start_step": 2,
         "executorch_exir_dialects_edge__ops_aten_full_like_default": 1,
         "executorch_exir_dialects_edge__ops_aten_gt_Scalar": 1,
         "executorch_exir_dialects_edge__ops_aten_lt_Scalar": 1,
         "executorch_exir_dialects_edge__ops_aten_minimum_default": 1,
         "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
         "executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2,
         "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_where_self": 1,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
-        "torch.ops.higher_order.executorch_call_delegate": 3,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
     }
 
     def _prepare_inputs(

backends/arm/test/models/test_nn_functional.py

Lines changed: 5 additions & 14 deletions

@@ -81,19 +81,13 @@ def forward(self, *args):
 @parametrize(
     "test_data",
     module_tests,
-    xfails={
-        "affine_grid": "Int64 input. Partition handling fails since arange int64 output is split between 2 partitions.",
-        "unfold": "ValueError: Invalid TOSA graph",
-        "fold": "ValueError: Invalid TOSA graph",
-    },
 )
 def test_nn_functional_FP(test_data):
     module, inputs = test_data
     pipeline = TosaPipelineFP[input_t](
         module, inputs, "", use_to_edge_transform_and_lower=False
     )
     pipeline.pop_stage("check.aten")
-    pipeline.dump_artifact("to_edge")
     pipeline.pop_stage("check_count.exir")
     try:
         pipeline.run()
@@ -105,14 +99,11 @@ def test_nn_functional_FP(test_data):
         raise e
 
 
-x_fails = {
-    "normalize": "MLETORCH-852: Support aten.index_put.default",
-    "unfold": "Int64 input && MLETORCH-827: Support aten.index.Tensor",
-    "fold": "Int64 input && MLETORCH-827: Support aten.index_put.default",
-}
-
-
-@parametrize("test_data", module_tests, x_fails, strict=False)
+@parametrize(
+    "test_data",
+    module_tests,
+    {"normalize": "MLETORCH-1255: Unsupported dtype in InsertTableOpsPass"},
+)
 def test_nn_functional_INT(test_data):
     module, inputs = test_data
     pipeline = TosaPipelineINT[input_t](

backends/arm/test/ops/test_arange.py

Lines changed: 12 additions & 0 deletions

@@ -12,6 +12,7 @@
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineINT,
     EthosU85PipelineINT,
+    OpNotSupportedPipeline,
     TosaPipelineFP,
     TosaPipelineINT,
     VgfPipeline,
@@ -46,6 +47,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
             (0.0, 10.0, 1.0, torch.int32),
         ),
+    }
+    test_reject: dict[str, test_data_t] = {
         "int32_int64": (
             lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
             (0.0, 10.0, 1.0, torch.int64),
@@ -77,6 +80,15 @@ def test_arange_start_step_tosa_FP_dtypes(test_data: test_data_t):
     pipeline.run()
 
 
+@common.parametrize("test_data", ArangeAdd.test_reject)
+def test_arange_start_step_tosa_FP_not_delegated(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = OpNotSupportedPipeline[input_t](
+        ArangeAdd(*init_data), input_data(), non_delegated_ops={ArangeAdd.exir_op: 1}
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_data", ArangeAdd.test_data)
 def test_arange_start_step_tosa_INT(test_data: test_data_t):
     input_data, init_data = test_data

backends/arm/test/ops/test_ones.py

Lines changed: 0 additions & 1 deletion

@@ -106,7 +106,6 @@ def test_ones_u85_INT(test_data: test_data_t):
     xfails={
         "fp32_int32": "MLETORCG-716: Do not delegate empty networks to vela",
         "fp32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
-        "int32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
     },
 )
 def test_ones_tosa_INT_not_delegated(test_data: test_data_t):

backends/arm/test/ops/test_zeros.py

Lines changed: 0 additions & 1 deletion

@@ -106,7 +106,6 @@ def test_zeros_u85_INT(test_data: test_data_t):
     xfails={
         "fp32_int32": "MLETORCG-716: Do not delegate empty networks to vela",
         "fp32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
-        "int32_int64": "MLETORCG-716: Do not delegate empty networks to vela",
    },
 )
 def test_zeros_tosa_INT_not_delegated(test_data: test_data_t):

backends/arm/tosa/dialect/ops/table.py

Lines changed: 1 addition & 1 deletion

@@ -48,6 +48,6 @@ def TABLE(a, table):
             raise TosaValueError(f"Table dtype {table.dtype} is not int32", op="TABLE")
         return_dtype = torch.int32
     else:
-        raise TosaValueError(f"Unsupported dtype for {tosa_spec}", op="TABLE")
+        raise TosaValueError(f"Unsupported dtype {a.dtype} for {tosa_spec}", op="TABLE")
 
     return torch.empty_like(a, dtype=return_dtype)
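
A small illustration of what the added dtype buys when debugging (the spec string and tensor here are assumed for illustration, not taken from the patch):

import torch

tosa_spec = "TOSA-1.0+INT"  # assumed spec string, for illustration only
a = torch.zeros(4, dtype=torch.int64)

print(f"Unsupported dtype for {tosa_spec}")            # before: offending dtype unknown
print(f"Unsupported dtype {a.dtype} for {tosa_spec}")  # after: names torch.int64 directly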
