@@ -134,7 +134,7 @@ def run_worker(self):
134134 if save_model_path and (not os .path .exists (save_model_path )):
135135 os .makedirs (save_model_path )
136136
137- reader_type = self .config .get ("runner.reader_type" , None )
137+ reader_type = self .config .get ("runner.reader_type" , "QueueDataset" )
138138 epochs = int (self .config .get ("runner.epochs" ))
139139 sync_mode = self .config .get ("runner.sync_mode" )
140140
@@ -150,10 +150,6 @@ def run_worker(self):
150150 self .dataset_train_loop (epoch )
151151 elif reader_type == "InmemoryDataset" :
152152 self .dataset_train_loop (epoch )
153- elif reader_type == "DataLoader" :
154- self .dataloader_train_loop (epoch )
155- elif reader_type == None or reader_type == "RecDataset" :
156- self .recdataset_train_loop (epoch )
157153
158154 epoch_time = time .time () - epoch_start_time
159155 epoch_speed = self .example_nums / epoch_time
@@ -182,6 +178,8 @@ def run_worker(self):
182178 def init_reader (self ):
183179 if fleet .is_server ():
184180 return
181+ self .config ["runner.reader_type" ] = self .config .get (
182+ "runner.reader_type" , "QueueDataset" )
185183 self .reader , self .file_list = get_reader (self .input_data , config )
186184 self .example_nums = 0
187185 self .count_method = self .config .get ("runner.example_count_method" ,
@@ -222,91 +220,6 @@ def dataset_train_loop(self, epoch):
222220 print_period = print_step ,
223221 debug = debug )
224222
225- def dataloader_train_loop (self , epoch ):
226- logger .info ("Epoch: {}, Running DataLoader Begin." .format (epoch ))
227- batch_id = 0
228- train_run_cost = 0.0
229- total_examples = 0
230- self .reader .start ()
231- while True :
232- try :
233- train_start = time .time ()
234- # --------------------------------------------------- #
235- fetch_var = self .exe .run (
236- program = paddle .static .default_main_program (),
237- fetch_list = [var for _ , var in self .metrics .items ()])
238- # --------------------------------------------------- #
239- train_run_cost += time .time () - train_start
240- total_examples += (self .config .get ("runner.train_batch_size" ))
241- batch_id += 1
242- print_step = int (config .get ("runner.print_interval" ))
243- if batch_id % print_step == 0 :
244- metrics_string = ""
245- for var_idx , var_name in enumerate (self .metrics ):
246- metrics_string += "{}: {}, " .format (
247- var_name , fetch_var [var_idx ]
248- if var_name != "LOSS" or not config ['pure_bf16' ]
249- else bf16_to_fp32 (fetch_var [var_idx ][0 ]))
250- profiler_string = ""
251- profiler_string += "avg_batch_cost: {} sec, " .format (
252- format ((train_run_cost ) / print_step , '.5f' ))
253- profiler_string += "avg_samples: {}, " .format (
254- format (total_examples / print_step , '.5f' ))
255- profiler_string += "ips: {} {}/sec " .format (
256- format (total_examples / (train_run_cost ), '.5f' ),
257- self .count_method )
258- logger .info ("Epoch: {}, Batch: {}, {} {}" .format (
259- epoch , batch_id , metrics_string , profiler_string ))
260- train_run_cost = 0.0
261- total_examples = 0
262- except paddle .fluid .core .EOFException :
263- self .reader .reset ()
264- break
265-
266- def recdataset_train_loop (self , epoch ):
267- logger .info ("Epoch: {}, Running RecDatast Begin." .format (epoch ))
268-
269- input_data_names = [var .name for var in self .input_data ]
270- batch_size = config .get ("runner.train_batch_size" , None )
271- print_interval = config .get ("runner.print_interval" , None )
272-
273- batch_id = 0
274- train_run_cost = 0.0
275- train_reader_cost = 0.0
276- total_samples = 0
277- reader_start = time .time ()
278- for batch_id , batch_data in enumerate (self .reader ()):
279- train_reader_cost += time .time () - reader_start
280- train_start = time .time ()
281- # --------------------------------------------------- #
282- fetch_batch_var = self .exe .run (
283- program = paddle .static .default_main_program (),
284- feed = dict (zip (input_data_names , batch_data )),
285- fetch_list = [var for _ , var in self .metrics .items ()])
286- # --------------------------------------------------- #
287- train_run_cost += time .time () - train_start
288- total_samples += batch_size
289- if batch_id % print_interval == 0 :
290- metric_str = ""
291- for var_idx , var_name in enumerate (self .metrics ):
292- metric_str += "{}: {}, " .format (
293- var_name , fetch_batch_var [var_idx ]
294- if var_name != "LOSS" or config ['pure_bf16' ] is False
295- else bf16_to_fp32 (fetch_batch_var [var_idx ][0 ]))
296- logger .info (
297- "Epoch: {}, Batch_id: {}, " .format (epoch ,
298- batch_id ) + metric_str +
299- " avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} {}/sec"
300- .format (train_reader_cost / print_interval , (
301- train_reader_cost + train_run_cost ) / print_interval ,
302- total_samples / print_interval , total_samples / (
303- train_reader_cost + train_run_cost ),
304- self .count_method ))
305- train_reader_cost = 0.0
306- train_run_cost = 0.0
307- total_samples = 0
308- reader_start = time .time ()
309-
310223 def heter_train_loop (self , epoch ):
311224 logger .info (
312225 "Epoch: {}, Running Begin. Check running metrics at heter_log" .