@@ -345,17 +345,17 @@ def do_train(args):
                 evaluate(model, loss_fct, metric, test_data_loader,
                          language)
                 print("eval done total : %s s" % (time.time() - tic_eval))
-            if paddle.distributed.get_rank() == 0:
-                output_dir = os.path.join(
-                    args.output_dir,
-                    "ernie_m_ft_model_%d.pdparams" % (global_step))
-                if not os.path.exists(output_dir):
-                    os.makedirs(output_dir)
-                # Need better way to get inner model of DataParallel
-                model_to_save = model._layers if isinstance(
-                    model, paddle.DataParallel) else model
-                model_to_save.save_pretrained(output_dir)
-                tokenizer.save_pretrained(output_dir)
+            if paddle.distributed.get_rank() == 0:
+                output_dir = os.path.join(args.output_dir,
+                                          "ernie_m_ft_model_%d.pdparams" %
+                                          (global_step))
+                if not os.path.exists(output_dir):
+                    os.makedirs(output_dir)
+                # Need better way to get inner model of DataParallel
+                model_to_save = model._layers if isinstance(
+                    model, paddle.DataParallel) else model
+                model_to_save.save_pretrained(output_dir)
+                tokenizer.save_pretrained(output_dir)
         if global_step >= num_training_steps:
             break
     if global_step >= num_training_steps:
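For context, the checkpoint-saving pattern this hunk reformats can be sketched in isolation: only rank 0 writes the checkpoint, and a model wrapped in paddle.DataParallel is unwrapped via its ._layers attribute before saving. This is a minimal sketch, not the PR's code; the save_checkpoint helper name is hypothetical, and it assumes model and tokenizer are PaddleNLP pretrained objects that expose save_pretrained (as the diff itself shows).

import os

import paddle


def save_checkpoint(model, tokenizer, output_dir, global_step):
    # Hypothetical helper: in multi-process training, only rank 0 should
    # touch the filesystem so processes don't clobber each other's writes.
    if paddle.distributed.get_rank() != 0:
        return
    ckpt_dir = os.path.join(output_dir,
                            "ernie_m_ft_model_%d.pdparams" % global_step)
    # exist_ok=True makes this idempotent, replacing the explicit
    # os.path.exists() check used in the diff.
    os.makedirs(ckpt_dir, exist_ok=True)
    # paddle.DataParallel wraps the real model; its ._layers attribute
    # holds the inner module that actually implements save_pretrained.
    inner_model = model._layers if isinstance(model,
                                              paddle.DataParallel) else model
    inner_model.save_pretrained(ckpt_dir)
    tokenizer.save_pretrained(ckpt_dir)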