@@ -221,47 +221,56 @@ def measure_performance(model_call, args, compiler, profile=False):
221221
222222
223223def check_outputs (args , expected_out , compiled_out ):
224- if isinstance (expected_out , paddle .Tensor ):
225- expected_out = [expected_out ]
226- if isinstance (compiled_out , paddle .Tensor ):
227- compiled_out = [compiled_out ]
228-
229- eager_dtypes = [None ] * len (expected_out )
230- for i , tensor in enumerate (expected_out ):
231- eager_dtypes [i ] = (
232- str (tensor .dtype ).replace ("paddle." , "" ) if tensor is not None else "None"
233- )
234-
235- compiled_dtypes = [None ] * len (compiled_out )
236- for i , tensor in enumerate (compiled_out ):
237- compiled_dtypes [i ] = (
238- str (tensor .dtype ).replace ("paddle." , "" ) if tensor is not None else "None"
239- )
240-
224+ def _flatten_outputs_to_list (outs ):
225+ flattened_outs = outs
226+ if isinstance (outs , paddle .Tensor ):
227+ flattened_outs = [outs ]
228+ else :
229+ flattened_outs = [
230+ x
231+ for out in outs
232+ for x in (out if isinstance (out , (tuple , list )) else (out ,))
233+ ]
234+ return flattened_outs
235+
236+ expected_out = _flatten_outputs_to_list (expected_out )
237+ compiled_out = _flatten_outputs_to_list (compiled_out )
238+
239+ def _get_output_dtypes (outs ):
240+ dtypes = [
241+ str (tensor .dtype ).replace ("paddle." , "" )
242+ if isinstance (tensor , paddle .Tensor )
243+ else None
244+ for i , tensor in enumerate (outs )
245+ ]
246+ return dtypes
247+
248+ eager_dtypes = _get_output_dtypes (expected_out )
249+ compiled_dtypes = _get_output_dtypes (compiled_out )
241250 type_match = test_compiler_util .check_output_datatype (
242251 args , eager_dtypes , compiled_dtypes
243252 )
244253
245- eager_shapes = [None ] * len (expected_out )
246- for i , tensor in enumerate (expected_out ):
247- eager_shapes [i ] = tensor .shape if tensor is not None else None
248-
249- compiled_shapes = [None ] * len (compiled_out )
250- for i , tensor in enumerate (compiled_out ):
251- compiled_shapes [i ] = tensor .shape if tensor is not None else None
254+ def _get_output_shapes (outs ):
255+ shapes = [
256+ tensor .shape if isinstance (tensor , paddle .Tensor ) else None
257+ for i , tensor in enumerate (outs )
258+ ]
259+ return shapes
252260
261+ eager_shapes = _get_output_shapes (expected_out )
262+ compiled_shapes = _get_output_shapes (compiled_out )
253263 shape_match = test_compiler_util .check_output_shape (
254264 args , eager_shapes , compiled_shapes
255265 )
256266
257- def transfer_to_float (origin_outputs ):
267+ def _transfer_to_float (origin_outputs ):
258268 outputs = []
259269 for item in origin_outputs :
260- if (
261- item is not None
262- and isinstance (item , paddle .Tensor )
263- and item .dtype not in [paddle .float32 , paddle .float64 ]
264- ):
270+ if isinstance (item , paddle .Tensor ) and item .dtype not in [
271+ paddle .float32 ,
272+ paddle .float64 ,
273+ ]:
265274 item = item .astype ("float32" )
266275 outputs .append (item )
267276 return outputs
@@ -274,8 +283,8 @@ def transfer_to_float(origin_outputs):
274283 cmp_equal_func = get_cmp_equal ,
275284 )
276285
277- expected_out_fp32 = transfer_to_float (expected_out )
278- compiled_out_fp32 = transfer_to_float (compiled_out )
286+ expected_out_fp32 = _transfer_to_float (expected_out )
287+ compiled_out_fp32 = _transfer_to_float (compiled_out )
279288 test_compiler_util .check_allclose (
280289 args ,
281290 expected_out_fp32 ,
0 commit comments