Skip to content

Commit 042e087

Browse files
Arm backend: Add docstrings for to_dim_order_copy_support.py (#14537)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent c18abc8 commit 042e087

File tree

1 file changed

+38
-8
lines changed

1 file changed

+38
-8
lines changed

backends/arm/operator_support/to_dim_order_copy_support.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for ``_to_dim_order_copy`` in TOSA.
6+
7+
Provide dtype-compatibility checks for casting when converting to a specific
8+
dimension order. Supported input/output dtype pairs depend on the active TOSA
9+
profile (integer and/or float).
10+
11+
"""
512

613
# pyre-unsafe
714
import copy
@@ -25,6 +32,16 @@
2532

2633
@register_tosa_support_check
2734
class ToCopySupported(SupportedTOSAOperatorCheck):
35+
"""Provide TOSA support check for ``_to_dim_order_copy``.
36+
37+
Attributes:
38+
SUPPORTED_INT_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
39+
Allowed output dtypes for each integer input dtype.
40+
SUPPORTED_FP_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
41+
Allowed output dtypes for each floating input dtype.
42+
43+
"""
44+
2845
targets = [
2946
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
3047
]
@@ -40,21 +57,31 @@ def _merge_supported_types(
4057
dtypes1: SupportedTypeDict,
4158
dtypes2: SupportedTypeDict,
4259
) -> SupportedTypeDict:
60+
"""Return a merged mapping of supported dtype transitions.
61+
62+
Args:
63+
dtypes1 (dict[torch.dtype, list[torch.dtype]]): Base mapping.
64+
dtypes2 (dict[torch.dtype, list[torch.dtype]]): Mapping to merge in.
65+
66+
Returns:
67+
dict[torch.dtype, list[torch.dtype]]: Combined mapping.
68+
69+
"""
4370
merged_dtypes = copy.deepcopy(
4471
dtypes1
45-
) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES
72+
) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_PROFILE_DTYPES
4673
for k, v in dtypes2.items():
4774
merged_dtypes[k] = merged_dtypes.get(k, []) + v
4875
return merged_dtypes
4976

50-
SUPPORTED_INT_TYPES: SupportedTypeDict = {
77+
SUPPORTED_INT_PROFILE_DTYPES: SupportedTypeDict = {
5178
torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32],
5279
torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32],
5380
torch.int16: [torch.bool, torch.int8, torch.int16, torch.int32],
5481
torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32],
5582
torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
5683
}
57-
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
84+
SUPPORTED_FP_PROFILE_DTYPES: SupportedTypeDict = {
5885
torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32],
5986
torch.int16: [torch.int16, torch.float16, torch.bfloat16, torch.float32],
6087
torch.int32: [torch.int32, torch.float16, torch.bfloat16, torch.float32],
@@ -92,22 +119,25 @@ def _merge_supported_types(
92119
torch.float32,
93120
],
94121
}
95-
ALL_SUPPORTED_TYPES = _merge_supported_types(
96-
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
97-
)
98122

99123
def is_node_tosa_supported(
100124
self, node: fx.Node, tosa_spec: TosaSpecification
101125
) -> bool:
126+
"""Return True if the node is supported by TOSA.
127+
128+
Check FakeTensor metadata, validate input dtype is supported for the
129+
active profile, and ensure the output dtype is allowed for the given
130+
input dtype.
102131
132+
"""
103133
supported_dtypes: SupportedTypeDict = {}
104134
if tosa_spec.support_integer():
105135
supported_dtypes = self._merge_supported_types(
106-
self.SUPPORTED_INT_TYPES, supported_dtypes
136+
self.SUPPORTED_INT_PROFILE_DTYPES, supported_dtypes
107137
)
108138
if tosa_spec.support_float():
109139
supported_dtypes = self._merge_supported_types(
110-
self.SUPPORTED_FLOAT_TYPES, supported_dtypes
140+
self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes
111141
)
112142

113143
if len(node.all_input_nodes) != 1:

0 commit comments

Comments
 (0)