Skip to content

Commit 1228126

Browse files
benkli01freddan80
authored andcommitted
Add support for torch.ops.aten._to_copy.default
Lower torch.ops.aten._to_copy.default to TOSA CAST op. This resolves issues around arithmetic operators when using int scalars in unquantized networks (see new test cases in test_scalars.py). Note: Parameter 'memory_format' is not supported. Change-Id: I7a921ca510c5b46f15b5399218f9230ba0f93d88
1 parent fc50da1 commit 1228126

File tree

6 files changed

+249
-2
lines changed

6 files changed

+249
-2
lines changed

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from . import ( # noqa
99
mean_dim_support,
1010
right_shift_support,
11+
to_copy_support,
1112
tosa_supported_operators,
1213
var_correction_support,
1314
)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
import logging
8+
9+
import torch
10+
11+
import torch.fx as fx
12+
13+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
14+
register_tosa_support_check,
15+
SupportedTOSAOperatorCheck,
16+
)
17+
from executorch.backends.arm.tosa_specification import TosaSpecification
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
@register_tosa_support_check
24+
class ToCopySupported(SupportedTOSAOperatorCheck):
25+
targets = [exir_ops.edge.aten._to_copy.default]
26+
27+
tosa_specs = [
28+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
29+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
30+
]
31+
32+
SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
33+
34+
@staticmethod
35+
def _merge_supported_types(
36+
dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict
37+
) -> SupportedTypeDict:
38+
merged_dtypes = dtypes1
39+
for k, v in dtypes2.items():
40+
merged_dtypes[k] = merged_dtypes.get(k, []) + v
41+
return merged_dtypes
42+
43+
SUPPORTED_INT_TYPES: SupportedTypeDict = {
44+
torch.bool: [torch.int8, torch.int16, torch.int32],
45+
torch.int8: [torch.bool, torch.int16, torch.int32],
46+
torch.int16: [torch.bool, torch.int8, torch.int32],
47+
torch.int32: [torch.bool, torch.int8, torch.int16],
48+
}
49+
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
50+
torch.int8: [torch.float16, torch.bfloat16, torch.float32],
51+
torch.int16: [torch.float16, torch.bfloat16, torch.float32],
52+
torch.int32: [torch.float16, torch.bfloat16, torch.float32],
53+
torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
54+
torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
55+
torch.float32: [
56+
torch.int8,
57+
torch.int16,
58+
torch.int32,
59+
torch.bfloat16,
60+
torch.float16,
61+
],
62+
}
63+
ALL_SUPPORTED_TYPES = _merge_supported_types(
64+
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
65+
)
66+
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}
67+
68+
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
69+
assert node.target in self.targets
70+
71+
if tosa_spec not in self.tosa_specs:
72+
return False
73+
74+
assert tosa_spec.support_integer()
75+
supported_dtypes = (
76+
self.ALL_SUPPORTED_TYPES
77+
if tosa_spec.support_float()
78+
else self.SUPPORTED_INT_TYPES
79+
)
80+
# Take into account possible type conversions
81+
supported_dtypes.update(
82+
(k, supported_dtypes[v])
83+
for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items()
84+
if v in supported_dtypes
85+
)
86+
87+
# Check input type
88+
assert len(node.all_input_nodes) == 1
89+
input_val = node.all_input_nodes[0].meta["val"]
90+
assert isinstance(input_val, torch._subclasses.FakeTensor)
91+
input_dtype = input_val.dtype
92+
if input_dtype not in supported_dtypes:
93+
logger.info(
94+
f"Input dtype {input_val.dtype} is not supported in "
95+
f"{node.target.name()}."
96+
)
97+
return False
98+
99+
# Check output type
100+
output_val = node.meta["val"]
101+
assert isinstance(output_val, torch._subclasses.FakeTensor)
102+
if output_val.dtype not in supported_dtypes[input_dtype]:
103+
logger.info(
104+
f"Output dtype {output_val.dtype} is not supported in "
105+
f"{node.target.name()} for input dtype {input_dtype}. "
106+
f"Supported output types: "
107+
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
108+
)
109+
return False
110+
111+
# Check memory format
112+
if "memory_format" in node.kwargs:
113+
if node.kwargs["memory_format"] in (torch.preserve_format,):
114+
logger.info(
115+
f"Argument 'memory_format' is not supported for "
116+
f"{node.target.name()} right now."
117+
)
118+
return False
119+
120+
return True

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
op_sub,
3737
op_sum,
3838
op_tanh,
39+
op_to_copy,
3940
op_transpose,
4041
op_unsqueeze,
4142
op_upsample_nearest2d,
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
import tosa.Op as TosaOp
12+
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
19+
20+
@register_node_visitor
21+
class ToCopyVisitor(NodeVisitor):
22+
"""
23+
Implement the type cast functionality of _to_copy.
24+
25+
Other features like setting of the memory_format or moving a tensor to a
26+
different device are not supported.
27+
28+
Also note that the node should not be quantized.
29+
"""
30+
31+
target = "aten._to_copy.default"
32+
33+
def define_node(
34+
self,
35+
node: torch.fx.Node,
36+
tosa_graph: ts.TosaSerializer,
37+
inputs: List[TosaArg],
38+
output: TosaArg,
39+
is_quant_node: bool,
40+
) -> None:
41+
assert not is_quant_node, "Casting of quantized values is not supported."
42+
assert inputs
43+
tosa_graph.addOperator(TosaOp.Op().CAST, [inputs[0].name], [output.name])

backends/arm/test/ops/test_scalars.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,21 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
153153
.run_method_and_compare_outputs(inputs=test_data)
154154
)
155155

156-
# Most MI tests fail, just show one working for now.
157-
@parameterized.expand((tensor_scalar_tests[6],))
156+
@parameterized.expand(tensor_scalar_tests)
158157
def test_MI(self, test_name: str, op: torch.nn.Module, x, y):
158+
expected_exception = None
159+
if any(token in test_name for token in ("Sub_int", "Sub__int")):
160+
expected_exception = RuntimeError
161+
elif test_name.endswith("_st"):
162+
expected_exception = AttributeError
163+
164+
if expected_exception:
165+
with self.assertRaises(
166+
expected_exception, msg=f"Test {test_name} is expected to fail."
167+
):
168+
self._test_add_tosa_MI_pipeline(op, (x, y))
169+
return
170+
159171
self._test_add_tosa_MI_pipeline(op, (x, y))
160172

161173
# op(Scalar float, tensor) works if the scalar is constant.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
#
8+
# Tests the _to_copy op which is interpreted as a cast for our purposes.
9+
#
10+
11+
import unittest
12+
13+
import torch
14+
15+
from executorch.backends.arm.test import common
16+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17+
18+
from parameterized import parameterized
19+
20+
21+
class Cast(torch.nn.Module):
22+
def __init__(self, target_dtype):
23+
super().__init__()
24+
self.target_dtype = target_dtype
25+
26+
def forward(self, x: torch.Tensor):
27+
return x.to(dtype=self.target_dtype)
28+
29+
30+
class TestToCopy(unittest.TestCase):
31+
"""
32+
Tests the _to_copy operation.
33+
34+
Only test unquantized graphs as explicit casting of dtypes messes with the
35+
quantization.
36+
37+
Note: This is also covered by test_scalars.py.
38+
"""
39+
40+
_TO_COPY_TEST_DATA = (
41+
(torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.float32),
42+
(torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.float16),
43+
(torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.float32),
44+
(torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), torch.int32),
45+
(torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32), torch.int8),
46+
)
47+
48+
def _test_to_copy_tosa_MI_pipeline(
49+
self, module: torch.nn.Module, test_data: torch.Tensor
50+
):
51+
(
52+
ArmTester(
53+
module,
54+
example_inputs=test_data,
55+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
56+
)
57+
.export()
58+
.dump_artifact()
59+
.check_count({"torch.ops.aten._to_copy.default": 1})
60+
.to_edge()
61+
.dump_artifact()
62+
.partition()
63+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
64+
.to_executorch()
65+
.run_method_and_compare_outputs(inputs=test_data)
66+
)
67+
68+
@parameterized.expand(_TO_COPY_TEST_DATA)
69+
def test_view_tosa_MI(self, test_tensor: torch.Tensor, new_dtype):
70+
self._test_to_copy_tosa_MI_pipeline(Cast(new_dtype), (test_tensor,))

0 commit comments

Comments
 (0)