Skip to content

Commit df5bfd5

Browse files
Arm backend: Add dtype check for aten.where in operator support check (#14506)
- Previously, aten.where with non-quantized FP inputs can pass the operator support check in INT profile, get partitioned, and then be removed from the partition due to FP inputs. - This will introduce dependency cycles, an invalid re-entry pattern `partition -> outside -> partition`, to the graph. - Workaround: when aten.where(cond, x, y) has FP x and y in INT profile, only partition if both come from dequantize ops (DQ_OPS). Note this may over-reject cases like `dq -> op1 -> aten.where -> q` that could be partitioned. - Don't partition aten.where with unsupported input dtype in FP profile. Signed-off-by: Yufeng Shi <[email protected]>
1 parent aec847d commit df5bfd5

File tree

5 files changed

+110
-11
lines changed

5 files changed

+110
-11
lines changed

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@
1919
slice_copy_support,
2020
to_dim_order_copy_support,
2121
tosa_supported_operators,
22+
where_support,
2223
)

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@
104104
exir_ops.edge.aten.squeeze_copy.dims,
105105
exir_ops.edge.aten.pow.Tensor_Scalar,
106106
exir_ops.edge.aten.pow.Tensor_Tensor,
107-
exir_ops.edge.aten.where.self,
108107
operator.getitem,
109108
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
110109
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
@@ -220,7 +219,6 @@
220219
exir_ops.edge.aten.squeeze_copy.dims,
221220
exir_ops.edge.aten.pow.Tensor_Scalar,
222221
exir_ops.edge.aten.pow.Tensor_Tensor,
223-
exir_ops.edge.aten.where.self,
224222
operator.getitem,
225223
exir_ops.edge.aten.constant_pad_nd.default,
226224
exir_ops.edge.aten.amax.default,
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import torch
8+
9+
import torch.fx as fx
10+
from executorch.backends.arm.constants import DQ_OPS
11+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
12+
register_tosa_support_check,
13+
SupportedTOSAOperatorCheck,
14+
)
15+
from executorch.backends.arm.tosa import TosaSpecification
16+
from executorch.exir.dialects._ops import ops as exir_ops
17+
18+
19+
@register_tosa_support_check
class WhereSupported(SupportedTOSAOperatorCheck):
    """Dtype-aware TOSA support check for ``aten.where.self``.

    The node is accepted only when its condition is a bool tensor and both
    value operands have a dtype allowed by the active TOSA profile. In the
    INT profile a float32 operand is tolerated only when it is produced
    directly by one of ``DQ_OPS``; this is a workaround that keeps
    non-quantized FP ``where`` nodes out of INT partitions.
    """

    targets = [exir_ops.edge.aten.where.self]

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
        TosaSpecification.create_from_string("TOSA-1.0+FP"),
    ]

    def is_node_tosa_supported(
        self, node: fx.Node, tosa_spec: TosaSpecification
    ) -> bool:  # type: ignore[override, misc]
        """Return True if ``node`` can be lowered for ``tosa_spec``.

        Args:
            node: The ``aten.where.self`` node under inspection.
            tosa_spec: The TOSA specification being targeted.

        Returns:
            True when the condition/operand dtypes are supported; otherwise
            False, after reporting the reason through ``self.reporter``.
        """
        # NOTE(review): fx's ``all_input_nodes`` appears to drop duplicate
        # inputs, so ``where(cond, x, x)`` would be rejected by this guard --
        # confirm that this is intended.
        input_nodes = node.all_input_nodes
        if len(input_nodes) != 3:
            self.reporter.report_reject(
                node,
                "Expected exactly three input nodes, "
                f"got {len(input_nodes)} for {node.target}.",
            )
            return False

        cond_node, x_node, y_node = input_nodes
        if condition_is_not_bool := cond_node.meta["val"].dtype != torch.bool:
            self.reporter.report_reject(
                node,
                f"Type of condition in {node.target} is not torch.bool",
            )
            return not condition_is_not_bool

        x_dtype = x_node.meta["val"].dtype
        y_dtype = y_node.meta["val"].dtype

        fp_profile_dtypes = (torch.bool, torch.float16, torch.float32)
        int_profile_dtypes = (torch.bool, torch.int8, torch.int16, torch.int32)

        def _int_profile_ok(operand: fx.Node, dtype: torch.dtype) -> bool:
            # Integer/bool operands are fine as-is; a float32 operand is only
            # acceptable when it comes straight out of a dequantize op.
            if dtype in int_profile_dtypes:
                return True
            return dtype == torch.float32 and operand.target in DQ_OPS

        if (
            tosa_spec.support_float()
            and x_dtype in fp_profile_dtypes
            and y_dtype in fp_profile_dtypes
        ):
            return True

        if (
            tosa_spec.support_integer()
            and _int_profile_ok(x_node, x_dtype)
            and _int_profile_ok(y_node, y_dtype)
        ):
            return True

        self.reporter.report_reject(
            node,
            f"Tensor x dtype {x_dtype} and/or tensor y dtype {y_dtype} is not supported in {node.target} "
            f"for tosa specification {tosa_spec}",
        )
        return False

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import unittest
88

9-
import pytest
109
import torch
1110
from executorch.backends.arm._passes import (
1211
ConvertInt64ConstOpsToInt32Pass,
@@ -28,16 +27,25 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
2827
CLIPTextModelWithProjection is one of the text_encoder used by Stable Diffusion 3.5 Medium
2928
"""
3029

31-
# Adjust nbr below as we increase op support. Note: most of the delegates
32-
# calls are directly consecutive to each other in the .pte. The reason
33-
# for that is some assert ops are removed by passes in the
34-
# .to_executorch step, i.e. after Arm partitioner.
35-
ops_after_partitioner = {
30+
# Adjust nbr below as we increase op support.
31+
ops_after_partitioner_FP = {
3632
"executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
3733
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
3834
"torch.ops.higher_order.executorch_call_delegate": 2,
3935
}
4036

37+
ops_after_partitioner_INT = {
38+
"executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
39+
"executorch_exir_dialects_edge__ops_aten_full_default": 1,
40+
"executorch_exir_dialects_edge__ops_aten_index_select_default": 1,
41+
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 1,
42+
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
43+
"executorch_exir_dialects_edge__ops_aten_where_self": 1,
44+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
45+
"torch.ops.aten.scalar_tensor.default": 1,
46+
"torch.ops.higher_order.executorch_call_delegate": 2,
47+
}
48+
4149
def _prepare_inputs(
4250
self,
4351
batch_size=12,
@@ -78,14 +86,13 @@ def test_CLIPTextModelWithProjection_tosa_FP(self):
7886
.export()
7987
.to_edge_transform_and_lower()
8088
.dump_operator_distribution()
81-
.check_count(self.ops_after_partitioner)
89+
.check_count(self.ops_after_partitioner_FP)
8290
.to_executorch()
8391
.run_method_and_compare_outputs(
8492
inputs=text_encoder_model_inputs,
8593
)
8694
)
8795

88-
@pytest.mark.xfail(raises=AssertionError, reason="Output difference.")
8996
def test_CLIPTextModelWithProjection_tosa_INT(self):
9097
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
9198
with torch.no_grad():
@@ -99,8 +106,10 @@ def test_CLIPTextModelWithProjection_tosa_INT(self):
99106
.export()
100107
.to_edge_transform_and_lower()
101108
.dump_operator_distribution()
109+
.check_count(self.ops_after_partitioner_INT)
102110
.to_executorch()
103111
.run_method_and_compare_outputs(
104112
inputs=text_encoder_model_inputs,
113+
atol=0.8,
105114
)
106115
)

backends/arm/test/ops/test_where.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,11 @@ def scalar_condition(input: torch.Tensor):
139139

140140
test_modules_FP = {
141141
**test_modules_common,
142-
"float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype,
143142
"float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool,
143+
}
144+
145+
test_modules_FP_unsupported_dtype = {
146+
"float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype,
144147
"int32_scalar_cond": lambda: int32_scalar_cond,
145148
}
146149

@@ -162,6 +165,17 @@ def test_where_self_tosa_FP(test_module):
162165
pipeline.run()
163166

164167

168+
@common.parametrize("test_module", test_modules_FP_unsupported_dtype)
def test_where_self_tosa_FP_unsupported_dtype(test_module):
    """Run the not-supported pipeline on FP cases with unsupported dtypes.

    Each entry of ``test_modules_FP_unsupported_dtype`` is pushed through
    ``OpNotSupportedPipeline`` with ``{exir_op: 1}`` (presumably the expected
    count of non-delegated ``where`` ops -- confirm against the pipeline).
    """
    # One delegate is still expected: the condition can be delegated.
    expected_delegates = 1
    OpNotSupportedPipeline[input_t](
        test_module(),
        test_module().get_inputs(),
        {exir_op: 1},
        n_expected_delegates=expected_delegates,
    ).run()
177+
178+
165179
@common.parametrize("test_module", test_modules_INT)
166180
def test_where_self_tosa_INT(test_module):
167181
pipeline = TosaPipelineINT[input_t](

0 commit comments

Comments
 (0)