55
66import logging
77
8+ import torch
89import torch .fx as fx
910from executorch .backends .arm .operator_support .tosa_supported_operators import (
1011 register_tosa_support_check ,
1112 SupportedTOSAOperatorCheck ,
1213)
13- from executorch .backends .arm .tosa_specification import TosaSpecification
14+ from executorch .backends .arm .tosa import TosaSpecification
1415from executorch .exir .dialects ._ops import ops as exir_ops
1516
1617logger = logging .getLogger (__name__ )
1718
1819
1920@register_tosa_support_check
2021class CloneSupported (SupportedTOSAOperatorCheck ):
21- targets = [exir_ops .edge .aten . clone .default ]
22+ targets = [exir_ops .edge .dim_order_ops . _clone_dim_order .default ]
2223
2324 tosa_specs = [
2425 TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
@@ -28,10 +29,62 @@ class CloneSupported(SupportedTOSAOperatorCheck):
2829 def is_node_tosa_supported (
2930 self , node : fx .Node , tosa_spec : TosaSpecification
3031 ) -> bool :
32+ if node .target not in self .targets :
33+ self .reporter .report_reject (node , f"Target { node .target } is not supported." )
34+ return False
3135
3236 input_node = node .args [0 ]
3337 if not isinstance (input_node , fx .Node ):
3438 self .reporter .report_reject (node , "Non tensor clones are not supported" )
3539 return False
3640
41+ # Check input node
42+ if len (node .all_input_nodes ) != 1 :
43+ self .reporter .report_reject (
44+ node , f"Expected 1 input node, got { len (node .all_input_nodes )} "
45+ )
46+ return False
47+
48+ input_val = node .all_input_nodes [0 ].meta ["val" ]
49+ if not isinstance (input_val , torch ._subclasses .FakeTensor ):
50+ self .reporter .report_reject (node , "Expected input to be a FakeTensor." )
51+ return False
52+
53+ input_dtype = input_val .dtype
54+
55+ # Check output node
56+ output_val = node .meta ["val" ]
57+ if not isinstance (output_val , torch ._subclasses .FakeTensor ):
58+ self .reporter .report_reject (node , "Expected output to be a FakeTensor." )
59+ return False
60+
61+ if output_val .dtype != input_dtype :
62+ self .reporter .report_reject (
63+ node ,
64+ f"Input dtype { input_val .dtype } does not match { output_val .dtype } ." ,
65+ )
66+ return False
67+
68+ # Check memory format
69+ if "memory_format" in node .kwargs :
70+ if node .kwargs ["memory_format" ] in (torch .preserve_format ,):
71+ self .reporter .report_reject (
72+ node ,
73+ f"Argument 'memory_format' is not supported for "
74+ f"{ node .target } right now." ,
75+ )
76+ return False
77+
78+ # Check dim_order
79+ if "dim_order" in node .kwargs :
80+ dim_order = node .kwargs ["dim_order" ]
81+ # pyre-ignore[6]
82+ if dim_order != list (range (len (dim_order ))): # type: ignore[arg-type]
83+ self .reporter .report_reject (
84+ node ,
85+ f"Argument { dim_order = } is not supported for "
86+ f"{ node .target } right now." ,
87+ )
88+ return False
89+
3790 return True
0 commit comments