Skip to content

Commit cec1400

Browse files
Arm backend: Add INT64 support to to_copy (#13993)
Partition to_copy ops with INT64 inputs, as these will be computed AOT and cast to INT32. Also disables partitioning of int-to-int casts for the FP profile. Signed-off-by: Oscar Andersson <[email protected]>
1 parent c6a6caa commit cec1400

File tree

4 files changed

+46
-26
lines changed

4 files changed

+46
-26
lines changed

backends/arm/operator_support/to_copy_support.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
logger = logging.getLogger(__name__)
2222

23+
SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
24+
2325

2426
@register_tosa_support_check
2527
class ToCopySupported(SupportedTOSAOperatorCheck):
@@ -33,8 +35,6 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
3335
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3436
]
3537

36-
SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
37-
3838
@staticmethod
3939
def _merge_supported_types(
4040
# pyre-ignore[11]
@@ -53,11 +53,22 @@ def _merge_supported_types(
5353
torch.int8: [torch.bool, torch.int16, torch.int32],
5454
torch.int16: [torch.bool, torch.int8, torch.int32],
5555
torch.int32: [torch.bool, torch.int8, torch.int16],
56+
torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
5657
}
5758
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
5859
torch.int8: [torch.float16, torch.bfloat16, torch.float32],
5960
torch.int16: [torch.float16, torch.bfloat16, torch.float32],
6061
torch.int32: [torch.float16, torch.bfloat16, torch.float32],
62+
# INT64 inputs to casts *should* be ok, since they should be rejected by
63+
# CheckInt64InputsAndOutputs if the cast can't be done AOT.
64+
torch.int64: [
65+
torch.int8,
66+
torch.int16,
67+
torch.int32,
68+
torch.float16,
69+
torch.bfloat16,
70+
torch.float32,
71+
],
6172
torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
6273
torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
6374
torch.float32: [
@@ -71,22 +82,20 @@ def _merge_supported_types(
7182
ALL_SUPPORTED_TYPES = _merge_supported_types(
7283
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
7384
)
74-
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}
7585

7686
def is_node_tosa_supported(
7787
self, node: fx.Node, tosa_spec: TosaSpecification
7888
) -> bool:
79-
supported_dtypes = (
80-
self.ALL_SUPPORTED_TYPES
81-
if tosa_spec.support_float()
82-
else self.SUPPORTED_INT_TYPES
83-
)
84-
# Take into account possible type conversions
85-
supported_dtypes.update(
86-
(k, supported_dtypes[v])
87-
for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items()
88-
if v in supported_dtypes
89-
)
89+
90+
supported_dtypes: SupportedTypeDict = {}
91+
if tosa_spec.support_integer():
92+
supported_dtypes = self._merge_supported_types(
93+
self.SUPPORTED_INT_TYPES, supported_dtypes
94+
)
95+
if tosa_spec.support_float():
96+
supported_dtypes = self._merge_supported_types(
97+
self.SUPPORTED_FLOAT_TYPES, supported_dtypes
98+
)
9099

91100
if len(node.all_input_nodes) != 1:
92101
self.reporter.report_reject(
@@ -156,7 +165,7 @@ def is_node_tosa_supported(
156165
if "dim_order" in node.kwargs:
157166
dim_order = node.kwargs["dim_order"]
158167
# pyre-ignore[6]
159-
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
168+
if dim_order is not None and dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
160169
self.reporter.report_reject(
161170
node,
162171
(

backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ class TestT5EncoderModel(unittest.TestCase):
3333
# .to_executorch step, i.e. after Arm partitioner.
3434
ops_after_partitioner = {
3535
"executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
36+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
3637
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
37-
"torch.ops.higher_order.executorch_call_delegate": 2,
38+
"torch.ops.higher_order.executorch_call_delegate": 3,
3839
}
3940

4041
def _prepare_inputs(

backends/arm/test/ops/test_to_copy.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ def test_copy_tosa_FP(test_data: Tuple):
7070
aten_op=[],
7171
exir_op=[],
7272
)
73+
# int to int cast is not supported in TOSA+FP profile
74+
if not new_dtype.is_floating_point and not torch.is_floating_point(test_tensor):
75+
pipeline.change_args(
76+
"check_count.exir",
77+
{
78+
"torch.ops.higher_order.executorch_call_delegate": 0,
79+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
80+
},
81+
)
82+
pipeline.pop_stage("run_method_and_compare_outputs")
7383
pipeline.run()
7484

7585

@@ -84,6 +94,15 @@ def test_copy_vgf_FP(test_data: Tuple):
8494
exir_op=[],
8595
tosa_version="TOSA-1.0+FP",
8696
)
97+
# int to int cast is not supported in TOSA+FP profile
98+
if not new_dtype.is_floating_point and not torch.is_floating_point(test_tensor):
99+
pipeline.change_args(
100+
"check_count.exir",
101+
{
102+
"torch.ops.higher_order.executorch_call_delegate": 0,
103+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
104+
},
105+
)
87106
pipeline.run()
88107

89108

backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,8 @@ def forward(self, x: torch.Tensor):
3232
test_data_suite_convert = {
3333
"fp32_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.int64),
3434
"fp16_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.int64),
35-
"int16_input": lambda: (
36-
torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int16),
37-
torch.int64,
38-
),
39-
"int8_input": lambda: (
40-
torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8),
41-
torch.int64,
42-
),
4335
}
4436

45-
4637
test_data_suite_remove = {
4738
"int32_input": lambda: (
4839
torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32),
@@ -52,7 +43,7 @@ def forward(self, x: torch.Tensor):
5243

5344

5445
@common.parametrize("test_data", test_data_suite_convert)
55-
def test_convert_or_remove_casting_to_int64_covnert_tosa_FP(test_data: Tuple):
46+
def test_convert_or_remove_casting_to_int64_convert_tosa_FP(test_data: Tuple):
5647
test_tensor, target_dtype = test_data()
5748
module = CastingToInt64Model(target_dtype)
5849

0 commit comments

Comments
 (0)