44from pathlib import Path
55import sys
66import os
7- from dataclasses import dataclass
8- from contextlib import contextmanager
9- import time
10- import math
117import numpy as np
128import random
139import platform
@@ -62,7 +58,7 @@ def get_hardward_name(args):
6258 )
6359 )
6460 )
65- except Exception as e :
61+ except Exception :
6662 pass
6763 elif args .device == "cpu" :
6864 hardware = platform .processor ()
@@ -128,7 +124,7 @@ def get_static_model(args, model):
128124 backend = None ,
129125 )
130126 static_model .eval ()
131- program = static_model .forward .concrete_program .main_program
127+ program = static_model .forward .concrete_program .main_program # noqa
132128 return static_model
133129
134130
@@ -225,47 +221,56 @@ def measure_performance(model_call, args, compiler, profile=False):
225221
226222
227223def check_outputs (args , expected_out , compiled_out ):
228- if isinstance (expected_out , paddle .Tensor ):
229- expected_out = [expected_out ]
230- if isinstance (compiled_out , paddle .Tensor ):
231- compiled_out = [compiled_out ]
232-
233- eager_dtypes = [None ] * len (expected_out )
234- for i , tensor in enumerate (expected_out ):
235- eager_dtypes [i ] = (
236- str (tensor .dtype ).replace ("paddle." , "" ) if tensor is not None else "None"
237- )
238-
239- compiled_dtypes = [None ] * len (compiled_out )
240- for i , tensor in enumerate (compiled_out ):
241- compiled_dtypes [i ] = (
242- str (tensor .dtype ).replace ("paddle." , "" ) if tensor is not None else "None"
243- )
244-
224+ def _flatten_outputs_to_list (outs ):
225+ flattened_outs = outs
226+ if isinstance (outs , paddle .Tensor ):
227+ flattened_outs = [outs ]
228+ else :
229+ flattened_outs = [
230+ x
231+ for out in outs
232+ for x in (out if isinstance (out , (tuple , list )) else (out ,))
233+ ]
234+ return flattened_outs
235+
236+ expected_out = _flatten_outputs_to_list (expected_out )
237+ compiled_out = _flatten_outputs_to_list (compiled_out )
238+
239+ def _get_output_dtypes (outs ):
240+ dtypes = [
241+ str (tensor .dtype ).replace ("paddle." , "" )
242+ if isinstance (tensor , paddle .Tensor )
243+ else None
244+ for i , tensor in enumerate (outs )
245+ ]
246+ return dtypes
247+
248+ eager_dtypes = _get_output_dtypes (expected_out )
249+ compiled_dtypes = _get_output_dtypes (compiled_out )
245250 type_match = test_compiler_util .check_output_datatype (
246251 args , eager_dtypes , compiled_dtypes
247252 )
248253
249- eager_shapes = [None ] * len (expected_out )
250- for i , tensor in enumerate (expected_out ):
251- eager_shapes [i ] = tensor .shape if tensor is not None else None
252-
253- compiled_shapes = [None ] * len (compiled_out )
254- for i , tensor in enumerate (compiled_out ):
255- compiled_shapes [i ] = tensor .shape if tensor is not None else None
254+ def _get_output_shapes (outs ):
255+ shapes = [
256+ tensor .shape if isinstance (tensor , paddle .Tensor ) else None
257+ for i , tensor in enumerate (outs )
258+ ]
259+ return shapes
256260
261+ eager_shapes = _get_output_shapes (expected_out )
262+ compiled_shapes = _get_output_shapes (compiled_out )
257263 shape_match = test_compiler_util .check_output_shape (
258264 args , eager_shapes , compiled_shapes
259265 )
260266
261- def transfer_to_float (origin_outputs ):
267+ def _transfer_to_float (origin_outputs ):
262268 outputs = []
263269 for item in origin_outputs :
264- if (
265- item is not None
266- and isinstance (item , paddle .Tensor )
267- and item .dtype not in [paddle .float32 , paddle .float64 ]
268- ):
270+ if isinstance (item , paddle .Tensor ) and item .dtype not in [
271+ paddle .float32 ,
272+ paddle .float64 ,
273+ ]:
269274 item = item .astype ("float32" )
270275 outputs .append (item )
271276 return outputs
@@ -278,8 +283,8 @@ def transfer_to_float(origin_outputs):
278283 cmp_equal_func = get_cmp_equal ,
279284 )
280285
281- expected_out_fp32 = transfer_to_float (expected_out )
282- compiled_out_fp32 = transfer_to_float (compiled_out )
286+ expected_out_fp32 = _transfer_to_float (expected_out )
287+ compiled_out_fp32 = _transfer_to_float (compiled_out )
283288 test_compiler_util .check_allclose (
284289 args ,
285290 expected_out_fp32 ,
@@ -308,11 +313,16 @@ def check_and_print_gpu_utilization(compiler):
308313
309314
310315def test_single_model (args ):
316+ model_path = os .path .normpath (args .model_path )
317+ test_compiler_util .print_with_log_prompt (
318+ "[Processing]" , model_path , args .log_prompt
319+ )
320+
311321 compiler = get_compiler_backend (args )
312322 check_and_print_gpu_utilization (compiler )
313323
314- input_dict = get_input_dict (args . model_path )
315- model = get_model (args . model_path )
324+ input_dict = get_input_dict (model_path )
325+ model = get_model (model_path )
316326 model .eval ()
317327
318328 test_compiler_util .print_basic_config (
@@ -341,7 +351,7 @@ def test_single_model(args):
341351 compiled_time_stats = {}
342352 try :
343353 print ("Run model in compiled mode." , file = sys .stderr , flush = True )
344- input_spec = get_input_spec (args . model_path )
354+ input_spec = get_input_spec (model_path )
345355 compiled_model = compiler (model , input_spec )
346356 compiled_out , compiled_time_stats = measure_performance (
347357 lambda : compiled_model (** input_dict ), args , compiler , profile = False
0 commit comments