@@ -76,8 +76,6 @@ def _merge_supported_types(
     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
     ) -> bool:
-        assert node.target in self.targets
-
         supported_dtypes = (
             self.ALL_SUPPORTED_TYPES
             if tosa_spec.support_float()
@@ -90,10 +88,27 @@ def is_node_tosa_supported(
             if v in supported_dtypes
         )

-        # Check input type
-        assert len(node.all_input_nodes) == 1
+        if len(node.all_input_nodes) != 1:
+            self.reporter.report_reject(
+                node,
+                (
+                    "Expected exactly one input node, "
+                    f"got {len(node.all_input_nodes)} for {node.target}."
+                ),
+            )
+            return False
         input_val = node.all_input_nodes[0].meta["val"]
-        assert isinstance(input_val, torch._subclasses.FakeTensor)
+        if not isinstance(input_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(
+                node,
+                (
+                    "Invalid or missing meta: expected FakeTensor input, got "
+                    f"{type(input_val).__name__} for {node.target}."
+                ),
+            )
+            return False
+
+        # Check input type
         input_dtype = input_val.dtype
         if input_dtype not in supported_dtypes:
             self.reporter.report_reject(
@@ -104,14 +119,24 @@ def is_node_tosa_supported(

         # Check output type
         output_val = node.meta["val"]
-        assert isinstance(output_val, torch._subclasses.FakeTensor)
+        if not isinstance(output_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(
+                node,
+                (
+                    "Invalid or missing meta: expected FakeTensor output, got "
+                    f"{type(output_val).__name__} for {node.target}."
+                ),
+            )
+            return False
         if output_val.dtype not in supported_dtypes[input_dtype]:
             self.reporter.report_reject(
                 node,
-                f"Output dtype {output_val.dtype} is not supported in "
-                f"{node.target} for input dtype {input_dtype}. "
-                f"Supported output types: "
-                f"{''.join(str(t) for t in supported_dtypes[input_dtype])}",
+                (
+                    f"Output dtype {output_val.dtype} is not supported in "
+                    f"{node.target} for input dtype {input_dtype}. "
+                    f"Supported output types: "
+                    f"{', '.join(str(t) for t in supported_dtypes[input_dtype])}"
+                ),
             )
             return False

@@ -120,8 +145,10 @@ def is_node_tosa_supported(
             if node.kwargs["memory_format"] in (torch.preserve_format,):
                 self.reporter.report_reject(
                     node,
-                    f"Argument 'memory_format' is not supported for "
-                    f"{node.target} right now.",
+                    (
+                        "Argument 'memory_format' is not supported for "
+                        f"{node.target} right now."
+                    ),
                 )
                 return False

@@ -132,8 +159,10 @@ def is_node_tosa_supported(
             if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
                 self.reporter.report_reject(
                     node,
-                    f"Argument {dim_order=} is not supported for "
-                    f"{node.target} right now.",
+                    (
+                        f"Argument {dim_order=} is not supported for "
+                        f"{node.target} right now."
+                    ),
                 )
                 return False

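Taken together, these hunks swap hard asserts for reporter-backed rejections that return False, so an unsupported node is reported and left out of the TOSA partition instead of crashing export. Below is a minimal, self-contained sketch of that reject-instead-of-assert pattern; the _Reporter and _Node classes are simplified stand-ins for illustration only, not the backend's actual reporter or torch.fx.Node types.

# Minimal sketch of the "reject instead of assert" pattern from the diff.
# _Reporter and _Node are assumed stand-ins, not the real ExecuTorch classes.
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class _Reporter:
    rejections: List[str] = field(default_factory=list)

    def report_reject(self, node: Any, reason: str) -> None:
        # Record why the node cannot be partitioned instead of raising.
        self.rejections.append(f"{node.target}: {reason}")


@dataclass
class _Node:
    target: str
    all_input_nodes: List[Any]


def is_node_supported(node: _Node, reporter: _Reporter) -> bool:
    # Formerly `assert len(node.all_input_nodes) == 1`; now a graceful reject.
    if len(node.all_input_nodes) != 1:
        reporter.report_reject(
            node,
            f"Expected exactly one input node, got {len(node.all_input_nodes)}.",
        )
        return False
    return True


reporter = _Reporter()
print(is_node_supported(_Node("aten._to_copy.default", []), reporter))  # False
print(reporter.rejections)  # one recorded reason, no exception raised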