5050def parse_args ():
5151 parser = argparse .ArgumentParser (description = 'paddle-rec run' )
5252 parser .add_argument ("-m" , "--config_yaml" , type = str )
53+ parser .add_argument ("--device" , type = str )
5354 args = parser .parse_args ()
5455 args .abs_dir = os .path .dirname (os .path .abspath (args .config_yaml ))
5556 args .config_yaml = get_abs_model (args .config_yaml )
@@ -63,19 +64,26 @@ def main(args):
6364 dy_model_class = load_dy_model_class (args .abs_dir )
6465 config ["config_abs_dir" ] = args .abs_dir
6566 # tools.vars
66- use_gpu = config .get ("runner.use_gpu" , True )
67+ if args .device is None :
68+ use_gpu = config .get ("runner.use_gpu" , True )
69+ elif args .device == "gpu" :
70+ use_gpu = True
71+ else :
72+ use_gpu = False
73+
6774 use_visual = config .get ("runner.use_visual" , False )
6875 test_data_dir = config .get ("runner.test_data_dir" , None )
6976 print_interval = config .get ("runner.print_interval" , None )
77+ infer_batch_size = config .get ("runner.infer_batch_size" , None )
7078 model_load_path = config .get ("runner.infer_load_path" , "model_output" )
7179 start_epoch = config .get ("runner.infer_start_epoch" , 0 )
7280 end_epoch = config .get ("runner.infer_end_epoch" , 10 )
7381
7482 logger .info ("**************common.configs**********" )
7583 logger .info (
76- "use_gpu: {}, use_visual: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}" .
77- format (use_gpu , use_visual , test_data_dir , start_epoch , end_epoch ,
78- print_interval , model_load_path ))
84+ "use_gpu: {}, use_visual: {}, infer_batch_size: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}" .
85+ format (use_gpu , use_visual , infer_batch_size , test_data_dir ,
86+ start_epoch , end_epoch , print_interval , model_load_path ))
7987 logger .info ("**************common.configs**********" )
8088
8189 place = paddle .set_device ('gpu' if use_gpu else 'cpu' )
@@ -105,12 +113,20 @@ def main(args):
105113 model_path = os .path .join (model_load_path , str (epoch_id ))
106114 load_model (model_path , dy_model )
107115 dy_model .eval ()
116+ infer_reader_cost = 0.0
117+ infer_run_cost = 0.0
118+ reader_start = time .time ()
119+
108120 for batch_id , batch in enumerate (test_dataloader ()):
121+ infer_reader_cost += time .time () - reader_start
122+ infer_start = time .time ()
109123 batch_size = len (batch [0 ])
110124
111125 metric_list , tensor_print_dict = dy_model_class .infer_forward (
112126 dy_model , metric_list , batch , config )
113127
128+ infer_run_cost += time .time () - infer_start
129+
114130 if batch_id % print_interval == 0 :
115131 tensor_print_str = ""
116132 if tensor_print_dict is not None :
@@ -133,13 +149,19 @@ def main(args):
133149 tag = "infer/" + metric_list_name [metric_id ],
134150 step = step_num ,
135151 value = metric_list [metric_id ].accumulate ())
136- logger .info ("epoch: {}, batch_id: {}, " .format (
137- epoch_id , batch_id ) + metric_str + tensor_print_str +
138- " speed: {:.2f} ins/s" .format (
139- print_interval * batch_size / (time .time (
140- ) - interval_begin )))
152+ logger .info (
153+ "epoch: {}, batch_id: {}, " .format (
154+ epoch_id , batch_id ) + metric_str + tensor_print_str +
155+ " avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, speed: {:.2f} ins/s" .
156+ format (infer_reader_cost / print_interval , (
157+ infer_reader_cost + infer_run_cost ) / print_interval ,
158+ print_interval * batch_size / (time .time () -
159+ interval_begin )))
141160 interval_begin = time .time ()
161+ infer_reader_cost = 0.0
162+ infer_run_cost = 0.0
142163 step_num = step_num + 1
164+ reader_start = time .time ()
143165
144166 metric_str = ""
145167 for metric_id in range (len (metric_list_name )):
0 commit comments