20
20
21
21
logger = logging .getLogger (__name__ )
22
22
23
+ SupportedTypeDict = dict [torch .dtype , list [torch .dtype ]]
24
+
23
25
24
26
@register_tosa_support_check
25
27
class ToCopySupported (SupportedTOSAOperatorCheck ):
@@ -33,8 +35,6 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
33
35
TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
34
36
]
35
37
36
- SupportedTypeDict = dict [torch .dtype , list [torch .dtype ]]
37
-
38
38
@staticmethod
39
39
def _merge_supported_types (
40
40
# pyre-ignore[11]
@@ -53,11 +53,22 @@ def _merge_supported_types(
53
53
torch .int8 : [torch .bool , torch .int16 , torch .int32 ],
54
54
torch .int16 : [torch .bool , torch .int8 , torch .int32 ],
55
55
torch .int32 : [torch .bool , torch .int8 , torch .int16 ],
56
+ torch .int64 : [torch .bool , torch .int8 , torch .int16 , torch .int32 ],
56
57
}
57
58
SUPPORTED_FLOAT_TYPES : SupportedTypeDict = {
58
59
torch .int8 : [torch .float16 , torch .bfloat16 , torch .float32 ],
59
60
torch .int16 : [torch .float16 , torch .bfloat16 , torch .float32 ],
60
61
torch .int32 : [torch .float16 , torch .bfloat16 , torch .float32 ],
62
+ # INT64 inputs to casts *should* be ok, since they should be rejected by
63
+ # CheckInt64InputsAndOutputs if the cast can't be done AOT.
64
+ torch .int64 : [
65
+ torch .int8 ,
66
+ torch .int16 ,
67
+ torch .int32 ,
68
+ torch .float16 ,
69
+ torch .bfloat16 ,
70
+ torch .float32 ,
71
+ ],
61
72
torch .bfloat16 : [torch .int8 , torch .int16 , torch .int32 , torch .float32 ],
62
73
torch .float16 : [torch .int8 , torch .int16 , torch .int32 , torch .float32 ],
63
74
torch .float32 : [
@@ -71,22 +82,20 @@ def _merge_supported_types(
71
82
ALL_SUPPORTED_TYPES = _merge_supported_types (
72
83
SUPPORTED_INT_TYPES , SUPPORTED_FLOAT_TYPES
73
84
)
74
- POSSIBLE_TYPE_CONVERSIONS = {torch .int64 : torch .int32 }
75
85
76
86
def is_node_tosa_supported (
77
87
self , node : fx .Node , tosa_spec : TosaSpecification
78
88
) -> bool :
79
- supported_dtypes = (
80
- self .ALL_SUPPORTED_TYPES
81
- if tosa_spec .support_float ()
82
- else self .SUPPORTED_INT_TYPES
83
- )
84
- # Take into account possible type conversions
85
- supported_dtypes .update (
86
- (k , supported_dtypes [v ])
87
- for k , v in self .POSSIBLE_TYPE_CONVERSIONS .items ()
88
- if v in supported_dtypes
89
- )
89
+
90
+ supported_dtypes : SupportedTypeDict = {}
91
+ if tosa_spec .support_integer ():
92
+ supported_dtypes = self ._merge_supported_types (
93
+ self .SUPPORTED_INT_TYPES , supported_dtypes
94
+ )
95
+ if tosa_spec .support_float ():
96
+ supported_dtypes = self ._merge_supported_types (
97
+ self .SUPPORTED_FLOAT_TYPES , supported_dtypes
98
+ )
90
99
91
100
if len (node .all_input_nodes ) != 1 :
92
101
self .reporter .report_reject (
@@ -156,7 +165,7 @@ def is_node_tosa_supported(
156
165
if "dim_order" in node .kwargs :
157
166
dim_order = node .kwargs ["dim_order" ]
158
167
# pyre-ignore[6]
159
- if dim_order != list (range (len (dim_order ))): # type: ignore[arg-type]
168
+ if dim_order is not None and dim_order != list (range (len (dim_order ))): # type: ignore[arg-type]
160
169
self .reporter .report_reject (
161
170
node ,
162
171
(
0 commit comments