@@ -76,8 +76,6 @@ def _merge_supported_types(
     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
     ) -> bool:
-        assert node.target in self.targets
-
         supported_dtypes = (
             self.ALL_SUPPORTED_TYPES
             if tosa_spec.support_float()
@@ -90,10 +88,27 @@ def is_node_tosa_supported(
             if v in supported_dtypes
         )

-        # Check input type
-        assert len(node.all_input_nodes) == 1
+        if len(node.all_input_nodes) != 1:
+            self.reporter.report_reject(
+                node,
+                (
+                    "Expected exactly one input node, "
+                    f"got {len(node.all_input_nodes)} for {node.target}."
+                ),
+            )
+            return False
+
         input_val = node.all_input_nodes[0].meta["val"]
-        assert isinstance(input_val, torch._subclasses.FakeTensor)
+        if not isinstance(input_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(
+                node,
+                (
+                    "Invalid or missing meta: expected FakeTensor input, got "
+                    f"{type(input_val).__name__} for {node.target}."
+                ),
+            )
+            return False
+
+        # Check input type
         input_dtype = input_val.dtype
         if input_dtype not in supported_dtypes:
             self.reporter.report_reject(
@@ -104,14 +119,24 @@ def is_node_tosa_supported(

         # Check output type
         output_val = node.meta["val"]
-        assert isinstance(output_val, torch._subclasses.FakeTensor)
+        if not isinstance(output_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(
+                node,
+                (
+                    "Invalid or missing meta: expected FakeTensor output, got "
+                    f"{type(output_val).__name__} for {node.target}."
+                ),
+            )
+            return False
         if output_val.dtype not in supported_dtypes[input_dtype]:
             self.reporter.report_reject(
                 node,
-                f"Output dtype {output_val.dtype} is not supported in "
-                f"{node.target} for input dtype {input_dtype}. "
-                f"Supported output types: "
-                f"{''.join(str(t) for t in supported_dtypes[input_dtype])}",
+                (
+                    f"Output dtype {output_val.dtype} is not supported in "
+                    f"{node.target} for input dtype {input_dtype}. "
+                    f"Supported output types: "
+                    f"{', '.join(str(t) for t in supported_dtypes[input_dtype])}"
+                ),
             )
             return False

@@ -120,8 +145,10 @@ def is_node_tosa_supported(
             if node.kwargs["memory_format"] in (torch.preserve_format,):
                 self.reporter.report_reject(
                     node,
-                    f"Argument 'memory_format' is not supported for "
-                    f"{node.target} right now.",
+                    (
+                        "Argument 'memory_format' is not supported for "
+                        f"{node.target} right now."
+                    ),
                 )
                 return False

@@ -132,8 +159,10 @@ def is_node_tosa_supported(
             if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
                 self.reporter.report_reject(
                     node,
-                    f"Argument {dim_order=} is not supported for "
-                    f"{node.target} right now.",
+                    (
+                        f"Argument {dim_order=} is not supported for "
+                        f"{node.target} right now."
+                    ),
                 )
                 return False
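
The diff replaces hard asserts in is_node_tosa_supported with reporter-backed rejections, so an unexpected node is reported and skipped instead of aborting partitioning with an AssertionError. Below is a minimal, self-contained sketch of that pattern; Reporter and check_node are hypothetical stand-ins for illustration only, not the ExecuTorch API (the real code calls self.reporter.report_reject inside the support check).

# Sketch under assumed names: report a rejection reason and return False
# instead of asserting, so one bad node does not crash the whole export.
import logging

logger = logging.getLogger(__name__)


class Reporter:
    """Hypothetical reporter that records why a node was not partitioned."""

    def report_reject(self, node, reason: str) -> None:
        logger.info("Rejecting %s: %s", node, reason)


def check_node(node, reporter: Reporter) -> bool:
    """Return False (rather than asserting) when the node has an unexpected shape."""
    if len(node.all_input_nodes) != 1:
        reporter.report_reject(
            node,
            f"Expected exactly one input node, got {len(node.all_input_nodes)}.",
        )
        return False
    return True

Returning False lets the partitioner leave such nodes out of the delegated subgraph while the rejection reason is still surfaced to the user.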