@@ -317,30 +317,42 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
317317
318318
319319def test_multi_models (args ):
320+ verified_samples = None
321+ if args .verified_samples_path is not None :
322+ assert os .path .isfile (args .verified_samples_path )
323+ graphnet_root = path_utils .get_graphnet_root ()
324+ print (f"graphnet_root: { graphnet_root } " )
325+ verified_samples = []
326+ with open (args .verified_samples_path , "r" ) as f :
327+ for line in f .readlines ():
328+ verified_samples .append (os .path .join (graphnet_root , line .strip ()))
329+
320330 sample_idx = 0
321331 failed_samples = []
322332 for model_path in path_utils .get_recursively_model_path (args .model_path ):
323- print ( f"[ { sample_idx } ] test_compiler, model_path: { model_path } " )
324- cmd = " " . join (
325- [
326- sys . executable ,
327- "-m graph_net.paddle.test_compiler" ,
328- f"--model-path { model_path } " ,
329- f"--compiler { args . compiler } " ,
330- f"--device { args .device } " ,
331- f"--warmup { args .warmup } " ,
332- f"--trials { args .trials } " ,
333- f"--log-prompt { args .log_prompt } " ,
334- f"--output-dir { args .output_dir } " ,
335- ]
336- )
337- cmd_ret = os .system (cmd )
338- # assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"
339- if cmd_ret != 0 :
340- failed_samples .append (model_path )
341- sample_idx += 1
333+ if verified_samples is None or os . path . abspath ( model_path ) in verified_samples :
334+ print ( f"[ { sample_idx } ] test_compiler, model_path: { model_path } " )
335+ cmd = " " . join (
336+ [
337+ sys . executable ,
338+ "-m graph_net.paddle.test_compiler " ,
339+ f"--model-path { model_path } " ,
340+ f"--compiler { args .compiler } " ,
341+ f"--device { args .device } " ,
342+ f"--warmup { args .warmup } " ,
343+ f"--trials { args .trials } " ,
344+ f"--log-prompt { args .log_prompt } " ,
345+ ]
346+ )
347+ cmd_ret = os .system (cmd )
348+ # assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"
349+ if cmd_ret != 0 :
350+ failed_samples .append (model_path )
351+ sample_idx += 1
342352
343- print (f"Totally { sample_idx } samples, failed { len (failed_samples )} samples." )
353+ print (
354+ f"Totally { sample_idx } verified samples, failed { len (failed_samples )} samples."
355+ )
344356 for model_path in failed_samples :
345357 print (f"- { model_path } " )
346358
@@ -393,5 +405,12 @@ def main(args):
393405 default = "graph-net-test-compiler-log" ,
394406 help = "Log prompt for performance log filtering." ,
395407 )
408+ parser .add_argument (
409+ "--verified-samples-path" ,
410+ type = str ,
411+ required = False ,
412+ default = None ,
413+ help = "Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model" ,
414+ )
396415 args = parser .parse_args ()
397416 main (args = args )
0 commit comments