 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+"""Declare operator support for ``_to_dim_order_copy`` in TOSA.
+
+Provide dtype-compatibility checks for casting when converting to a specific
+dimension order. Supported input/output dtype pairs depend on the active TOSA
+profile (integer and/or float).
+
+"""
 
 # pyre-unsafe
 import copy
 
 @register_tosa_support_check
 class ToCopySupported(SupportedTOSAOperatorCheck):
+    """Provide TOSA support check for ``_to_dim_order_copy``.
+
+    Attributes:
+        SUPPORTED_INT_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
+            Allowed output dtypes for each integer input dtype.
+        SUPPORTED_FP_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
+            Allowed output dtypes for each floating input dtype.
+
+    """
+
     targets = [
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
@@ -40,21 +57,31 @@ def _merge_supported_types(
         dtypes1: SupportedTypeDict,
         dtypes2: SupportedTypeDict,
     ) -> SupportedTypeDict:
+        """Return a merged mapping of supported dtype transitions.
+
+        Args:
+            dtypes1 (dict[torch.dtype, list[torch.dtype]]): Base mapping.
+            dtypes2 (dict[torch.dtype, list[torch.dtype]]): Mapping to merge in.
+
+        Returns:
+            dict[torch.dtype, list[torch.dtype]]: Combined mapping.
+
+        """
         merged_dtypes = copy.deepcopy(
             dtypes1
-        )  # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES
+        )  # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_PROFILE_DTYPES
         for k, v in dtypes2.items():
             merged_dtypes[k] = merged_dtypes.get(k, []) + v
         return merged_dtypes
 
-    SUPPORTED_INT_TYPES: SupportedTypeDict = {
+    SUPPORTED_INT_PROFILE_DTYPES: SupportedTypeDict = {
         torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32],
         torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32],
         torch.int16: [torch.bool, torch.int8, torch.int16, torch.int32],
         torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32],
         torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
     }
-    SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
+    SUPPORTED_FP_PROFILE_DTYPES: SupportedTypeDict = {
         torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32],
         torch.int16: [torch.int16, torch.float16, torch.bfloat16, torch.float32],
         torch.int32: [torch.int32, torch.float16, torch.bfloat16, torch.float32],
@@ -92,22 +119,25 @@ def _merge_supported_types(
             torch.float32,
         ],
     }
-    ALL_SUPPORTED_TYPES = _merge_supported_types(
-        SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
-    )
 
     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
     ) -> bool:
+        """Return True if the node is supported by TOSA.
+
+        Check FakeTensor metadata, validate input dtype is supported for the
+        active profile, and ensure the output dtype is allowed for the given
+        input dtype.
 
+        """
         supported_dtypes: SupportedTypeDict = {}
         if tosa_spec.support_integer():
             supported_dtypes = self._merge_supported_types(
-                self.SUPPORTED_INT_TYPES, supported_dtypes
+                self.SUPPORTED_INT_PROFILE_DTYPES, supported_dtypes
             )
         if tosa_spec.support_float():
             supported_dtypes = self._merge_supported_types(
-                self.SUPPORTED_FLOAT_TYPES, supported_dtypes
+                self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes
             )
 
         if len(node.all_input_nodes) != 1:
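The `_merge_supported_types` helper simply concatenates the per-input-dtype lists of the two tables, so when both the integer and float profiles contribute entries for the same input dtype it ends up with the combined list of allowed outputs (duplicates are possible but harmless for membership checks). A minimal standalone sketch of that behavior, using a hypothetical `merge_supported_types` function and two single-entry tables rather than the actual `ToCopySupported` class:

```python
import copy

import torch

def merge_supported_types(dtypes1, dtypes2):
    # Deep-copy the base table so the class-level constants are never mutated.
    merged = copy.deepcopy(dtypes1)
    for k, v in dtypes2.items():
        merged[k] = merged.get(k, []) + v
    return merged

# One-entry stand-ins for SUPPORTED_INT_PROFILE_DTYPES / SUPPORTED_FP_PROFILE_DTYPES.
int_profile = {torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32]}
fp_profile = {torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32]}

merged = merge_supported_types(int_profile, fp_profile)
# int8 now maps to the integer-profile targets plus the float-profile targets.
print(merged[torch.int8])
```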
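The docstring added to `is_node_tosa_supported` summarizes the overall flow: build the supported table from whichever profiles the TOSA specification enables, then require both the input dtype and the requested output dtype to appear in it. A rough sketch of that lookup, with a hypothetical `cast_is_supported` helper standing in for the node and spec plumbing:

```python
import torch

# Trimmed copies of the profile tables from the diff (full tables omitted).
SUPPORTED_INT_PROFILE_DTYPES = {
    torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32],
    torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32],
}
SUPPORTED_FP_PROFILE_DTYPES = {
    torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32],
}

def cast_is_supported(input_dtype, output_dtype, support_integer, support_float):
    # Merge the tables for the active profile(s), mirroring is_node_tosa_supported.
    supported = {}
    if support_integer:
        for k, v in SUPPORTED_INT_PROFILE_DTYPES.items():
            supported[k] = supported.get(k, []) + v
    if support_float:
        for k, v in SUPPORTED_FP_PROFILE_DTYPES.items():
            supported[k] = supported.get(k, []) + v
    # The input dtype must have an entry and the output dtype must be listed for it.
    return input_dtype in supported and output_dtype in supported[input_dtype]

# int8 -> float32 is only accepted when the float profile is active.
print(cast_is_supported(torch.int8, torch.float32, True, False))  # False
print(cast_is_supported(torch.int8, torch.float32, True, True))   # True
```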