Skip to content

Commit 72d6590

Browse files
authored
fix max_steps of msra_ner. (#1451)
1 parent 292ac62 commit 72d6590

File tree

1 file changed

+4
-2
lines changed
  • examples/information_extraction/msra_ner

1 file changed

+4
-2
lines changed

examples/information_extraction/msra_ner/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
parser = argparse.ArgumentParser()
4343

4444
# yapf: disable
45-
parser.add_argument("--model_type", default="bert", type=str, required=True, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), )
45+
parser.add_argument("--model_type", default="bert", type=str, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), )
4646
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join( sum([ list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values() ], [])), )
4747
parser.add_argument("--dataset", default="msra_ner", type=str, choices=["msra_ner", "peoples_daily_ner"] ,help="The named entity recognition datasets.")
4848
parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.")
@@ -218,7 +218,7 @@ def do_train(args):
218218
optimizer.step()
219219
lr_scheduler.step()
220220
optimizer.clear_grad()
221-
if global_step % args.save_steps == 0 or global_step == last_step:
221+
if global_step % args.save_steps == 0 or global_step == num_training_steps:
222222
if paddle.distributed.get_rank() == 0:
223223
if args.dataset == "peoples_daily_ner":
224224
evaluate(model, loss_fct, metric, dev_data_loader,
@@ -229,6 +229,8 @@ def do_train(args):
229229
paddle.save(model.state_dict(),
230230
os.path.join(args.output_dir,
231231
"model_%d.pdparams" % global_step))
232+
if global_step >= num_training_steps:
233+
return
232234

233235

234236
if __name__ == "__main__":

0 commit comments

Comments
 (0)