@@ -90,11 +90,8 @@ def get_compiled_model(args, model):
9090
9191
9292def regular_item (item ):
93- if isinstance (item , paddle .Tensor ) and (item .dtype == paddle .bfloat16 ):
94- item = np .array (item .astype ("float32" ))
95- else :
96- item = np .array (item )
97- if item .dtype == np .bool_ :
93+ assert isinstance (item , paddle .Tensor )
94+ if item .dtype not in [paddle .float32 , paddle .float64 ]:
9895 item = item .astype ("float32" )
9996 return item
10097
@@ -306,32 +303,34 @@ def print_cmp(key, func, **kwargs):
306303
307304def get_cmp_equal (expected_out , compiled_out ):
308305 return " " .join (
309- str (int (np . sum ( np . equal ( a , b ) ))) for a , b in zip (expected_out , compiled_out )
306+ str (int (paddle . equal_all ( a , b ))) for a , b in zip (expected_out , compiled_out )
310307 )
311308
312309
313310def get_cmp_all_close (expected_out , compiled_out , atol , rtol ):
314311 return " " .join (
315- str (int (np .allclose (a , b , atol = atol , rtol = rtol )))
312+ str (int (paddle .allclose (a , b , atol = atol , rtol = rtol )))
316313 for a , b in zip (expected_out , compiled_out )
317314 )
318315
319316
320317def get_cmp_max_diff (expected_out , compiled_out ):
321318 return " " .join (
322- str (np .max (np .abs (a - b )).item ()) for a , b in zip (expected_out , compiled_out )
319+ str (paddle .max (paddle .abs (a - b )).item ())
320+ for a , b in zip (expected_out , compiled_out )
323321 )
324322
325323
326324def get_cmp_mean_diff (expected_out , compiled_out ):
327325 return " " .join (
328- str (np .mean (np .abs (a - b )).item ()) for a , b in zip (expected_out , compiled_out )
326+ str (paddle .mean (paddle .abs (a - b )).item ())
327+ for a , b in zip (expected_out , compiled_out )
329328 )
330329
331330
332331def get_cmp_diff_count (expected_out , compiled_out , atol , rtol ):
333332 return " " .join (
334- str (np .sum (~ np .isclose (a , b , atol = atol , rtol = rtol )).item ())
333+ str (paddle .sum (~ paddle .isclose (a , b , atol = atol , rtol = rtol )).item ())
335334 for a , b in zip (expected_out , compiled_out )
336335 )
337336
@@ -344,9 +343,11 @@ def test_multi_models(args):
344343 "-m graph_net.paddle.test_compiler" ,
345344 f"--model-path { model_path } " ,
346345 f"--compiler { args .compiler } " ,
346+ f"--device { args .device } " ,
347347 f"--warmup { args .warmup } " ,
348348 f"--trials { args .trials } " ,
349349 f"--log-prompt { args .log_prompt } " ,
350+ f"--output-dir { args .output_dir } " ,
350351 ]
351352 )
352353 cmd_ret = os .system (cmd )
0 commit comments