#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+"""Declare operator support for ``_to_dim_order_copy`` in TOSA.
+
+Provide dtype-compatibility checks for casting when converting to a specific
+dimension order. Supported input/output dtype pairs depend on the active TOSA
+profile (integer and/or float).
+
+"""

# pyre-unsafe

import copy

@register_tosa_support_check
class ToCopySupported(SupportedTOSAOperatorCheck):
+    """Provide TOSA support check for ``_to_dim_order_copy``.
+
+    Attributes:
+        SUPPORTED_INT_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
+            Allowed output dtypes for each integer input dtype.
+        SUPPORTED_FP_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
+            Allowed output dtypes for each floating input dtype.
+
+    """
+
    targets = [
        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
    ]
@@ -40,21 +57,31 @@ def _merge_supported_types(
        dtypes1: SupportedTypeDict,
        dtypes2: SupportedTypeDict,
    ) -> SupportedTypeDict:
+        """Return a merged mapping of supported dtype transitions.
+
+        Args:
+            dtypes1 (dict[torch.dtype, list[torch.dtype]]): Base mapping.
+            dtypes2 (dict[torch.dtype, list[torch.dtype]]): Mapping to merge in.
+
+        Returns:
+            dict[torch.dtype, list[torch.dtype]]: Combined mapping.
+
+        """
        merged_dtypes = copy.deepcopy(
            dtypes1
-        )  # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES
+        )  # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_PROFILE_DTYPES
        for k, v in dtypes2.items():
            merged_dtypes[k] = merged_dtypes.get(k, []) + v
        return merged_dtypes

-    SUPPORTED_INT_TYPES: SupportedTypeDict = {
+    SUPPORTED_INT_PROFILE_DTYPES: SupportedTypeDict = {
        torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32],
        torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32],
        torch.int16: [torch.bool, torch.int8, torch.int16, torch.int32],
        torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32],
        torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
    }
-    SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
+    SUPPORTED_FP_PROFILE_DTYPES: SupportedTypeDict = {
        torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32],
        torch.int16: [torch.int16, torch.float16, torch.bfloat16, torch.float32],
        torch.int32: [torch.int32, torch.float16, torch.bfloat16, torch.float32],
@@ -92,22 +119,25 @@ def _merge_supported_types(
            torch.float32,
        ],
    }
-    ALL_SUPPORTED_TYPES = _merge_supported_types(
-        SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
-    )

    def is_node_tosa_supported(
        self, node: fx.Node, tosa_spec: TosaSpecification
    ) -> bool:
+        """Return True if the node is supported by TOSA.
+
+        Check FakeTensor metadata, validate input dtype is supported for the
+        active profile, and ensure the output dtype is allowed for the given
+        input dtype.
+
+        """
        supported_dtypes: SupportedTypeDict = {}
        if tosa_spec.support_integer():
            supported_dtypes = self._merge_supported_types(
-                self.SUPPORTED_INT_TYPES, supported_dtypes
+                self.SUPPORTED_INT_PROFILE_DTYPES, supported_dtypes
            )
        if tosa_spec.support_float():
            supported_dtypes = self._merge_supported_types(
-                self.SUPPORTED_FLOAT_TYPES, supported_dtypes
+                self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes
            )

        if len(node.all_input_nodes) != 1:
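
A minimal standalone sketch of the merge behavior shown in the diff above: it replays the `_merge_supported_types` logic outside the class on abbreviated stand-ins for `SUPPORTED_INT_PROFILE_DTYPES` and `SUPPORTED_FP_PROFILE_DTYPES` (the tables below are trimmed examples, not the full mappings from the file), illustrating how enabling both TOSA profiles yields the union of allowed output dtypes per input dtype.

# Sketch only; not the ExecuTorch module itself. Tables are abbreviated examples.
import copy
from typing import Dict, List

import torch

SupportedTypeDict = Dict[torch.dtype, List[torch.dtype]]

# Abbreviated stand-ins for the integer- and float-profile tables in the diff.
INT_PROFILE: SupportedTypeDict = {
    torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32],
}
FP_PROFILE: SupportedTypeDict = {
    torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32],
    torch.float32: [torch.float16, torch.bfloat16, torch.float32],
}


def merge_supported_types(
    dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict
) -> SupportedTypeDict:
    # Deepcopy so the base table is never mutated in place.
    merged = copy.deepcopy(dtypes1)
    for k, v in dtypes2.items():
        merged[k] = merged.get(k, []) + v
    return merged


# A spec supporting both profiles gets the union of allowed outputs per input dtype.
merged = merge_supported_types(INT_PROFILE, FP_PROFILE)
assert torch.int16 in merged[torch.int8] and torch.float32 in merged[torch.int8]
print(merged)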