@@ -68,6 +68,7 @@ def main(args):
6868 logger .info ("cpu_num: {}" .format (os .getenv ("CPU_NUM" )))
6969
7070 use_gpu = config .get ("runner.use_gpu" , True )
71+ use_xpu = config .get ("runner.use_xpu" , False )
7172 use_auc = config .get ("runner.use_auc" , False )
7273 use_visual = config .get ("runner.use_visual" , False )
7374 auc_num = config .get ("runner.auc_num" , 1 )
@@ -80,12 +81,16 @@ def main(args):
8081 os .environ ["CPU_NUM" ] = str (config .get ("runner.thread_num" , 1 ))
8182 logger .info ("**************common.configs**********" )
8283 logger .info (
83- "use_gpu: {}, use_visual: {}, infer_batch_size: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}" .
84- format (use_gpu , use_visual , batch_size , test_data_dir , start_epoch ,
85- end_epoch , print_interval , model_load_path ))
84+ "use_gpu: {}, use_xpu: {}, use_visual: {}, infer_batch_size: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}" .
85+ format (use_gpu , use_xpu , use_visual , batch_size , test_data_dir ,
86+ start_epoch , end_epoch , print_interval , model_load_path ))
8687 logger .info ("**************common.configs**********" )
8788
88- place = paddle .set_device ('gpu' if use_gpu else 'cpu' )
89+ if use_xpu :
90+ xpu_device = 'xpu:{0}' .format (os .getenv ('FLAGS_selected_xpus' , 0 ))
91+ place = paddle .set_device (xpu_device )
92+ else :
93+ place = paddle .set_device ('gpu' if use_gpu else 'cpu' )
8994 exe = paddle .static .Executor (place )
9095 # initialize
9196 exe .run (paddle .static .default_startup_program ())
0 commit comments