@@ -206,7 +206,7 @@ def check_outputs(args, expected_out, compiled_out):
206206 args , eager_dtypes , compiled_dtypes
207207 )
208208
209- def regular_outputs (origin_outputs ):
209+ def transfer_to_float (origin_outputs ):
210210 outputs = []
211211 for item in origin_outputs :
212212 if (
@@ -219,14 +219,19 @@ def regular_outputs(origin_outputs):
219219 return outputs
220220
221221 if type_match :
222- expected_out = regular_outputs (expected_out )
223- compiled_out = regular_outputs (compiled_out )
224-
225- test_compiler_util .check_correctness (
222+ test_compiler_util .check_equal (
226223 args ,
227224 expected_out ,
228225 compiled_out ,
229226 cmp_equal_func = get_cmp_equal ,
227+ )
228+
229+ expected_out_fp32 = transfer_to_float (expected_out )
230+ compiled_out_fp32 = transfer_to_float (compiled_out )
231+ test_compiler_util .check_allclose (
232+ args ,
233+ expected_out_fp32 ,
234+ compiled_out_fp32 ,
230235 cmp_all_close_func = get_cmp_all_close ,
231236 cmp_max_diff_func = get_cmp_max_diff ,
232237 cmp_mean_diff_func = get_cmp_mean_diff ,
@@ -240,8 +245,6 @@ def test_single_model(args):
240245 model = get_model (args )
241246 model .eval ()
242247
243- # num_eager_ops = count_number_of_ops(args, model, eager_mode=True)
244-
245248 test_compiler_util .print_basic_config (
246249 args , get_hardward_name (args ), get_compile_framework_version (args )
247250 )
@@ -314,8 +317,11 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
314317
315318
316319def test_multi_models (args ):
320+ sample_idx = 0
321+ failed_samples = []
317322 for model_path in path_utils .get_recursively_model_path (args .model_path ):
318- cmd = "" .join (
323+ print (f"[{ sample_idx } ] test_compiler, model_path: { model_path } " )
324+ cmd = " " .join (
319325 [
320326 sys .executable ,
321327 "-m graph_net.paddle.test_compiler" ,
@@ -329,7 +335,14 @@ def test_multi_models(args):
329335 ]
330336 )
331337 cmd_ret = os .system (cmd )
332- assert cmd_ret == 0 , f"{ cmd_ret = } , { cmd = } "
338+ # assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"
339+ if cmd_ret != 0 :
340+ failed_samples .append (model_path )
341+ sample_idx += 1
342+
343+ print (f"Totally { sample_idx } samples, failed { len (failed_samples )} samples." )
344+ for model_path in failed_samples :
345+ print (f"- { model_path } " )
333346
334347
335348def main (args ):
@@ -380,12 +393,5 @@ def main(args):
380393 default = "graph-net-test-compiler-log" ,
381394 help = "Log prompt for performance log filtering." ,
382395 )
383- parser .add_argument (
384- "--output-dir" ,
385- type = str ,
386- required = False ,
387- default = None ,
388- help = "Directory to save the structured JSON result file." ,
389- )
390396 args = parser .parse_args ()
391397 main (args = args )
0 commit comments