@@ -35,21 +35,11 @@ def is_node_tosa_supported(
3535 ) -> bool :
3636 assert node .target in self .targets
3737
38- supported_dtypes = {torch .bool , torch .int8 , torch .int16 , torch .int32 }
39- if tosa_spec .support_float ():
40- supported_dtypes |= {torch .bfloat16 , torch .float16 , torch .float32 }
41-
4238 # Check input type
4339 assert len (node .all_input_nodes ) == 1
4440 input_val = node .all_input_nodes [0 ].meta ["val" ]
4541 assert isinstance (input_val , torch ._subclasses .FakeTensor )
4642 input_dtype = input_val .dtype
47- if input_dtype not in supported_dtypes :
48- self .reporter .report_reject (
49- node ,
50- f"Input dtype { input_val .dtype } is not supported in { node .target } ." ,
51- )
52- return False
5343
5444 # Check output type
5545 output_val = node .meta ["val" ]
@@ -61,6 +51,16 @@ def is_node_tosa_supported(
6151 )
6252 return False
6353
54+ # Check memory format
55+ if "memory_format" in node .kwargs :
56+ if node .kwargs ["memory_format" ] in (torch .preserve_format ,):
57+ self .reporter .report_reject (
58+ node ,
59+ f"Argument 'memory_format' is not supported for "
60+ f"{ node .target } right now." ,
61+ )
62+ return False
63+
6464 # Check dim_order
6565 if "dim_order" in node .kwargs :
6666 dim_order = node .kwargs ["dim_order" ]
0 commit comments