Skip to content

Commit fda3ff7

Browse files
author
tianxin
authored
add save_step for ernie_matching (#798)
* add FewCLUE 9 datasets
* fix a bug for tnews
* Add CI for Ernie text matching
* Add CI for Ernie text matching
* Add CI for Ernie text matching
* fix encoding problem for windows
* add save_step for ernie_matching
1 parent f1dc506 commit fda3ff7

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

examples/text_matching/ernie_matching/train_pairwise.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,19 @@
3333

3434
# yapf: disable
3535
parser = argparse.ArgumentParser()
36-
parser.add_argument("--margin", default=0.2, type=float, help="Margin for pos_score and neg_score")
37-
parser.add_argument("--eval_step", default=100, type=int, help="Steps interval for evaluation")
36+
parser.add_argument("--margin", default=0.2, type=float, help="Margin for pos_score and neg_score.")
3837
parser.add_argument("--save_dir", default='./checkpoint', type=str, help="The output directory where the model checkpoints will be written.")
3938
parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. "
4039
"Sequences longer than this will be truncated, sequences shorter will be padded.")
4140
parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
4241
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
4342
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
4443
parser.add_argument("--epochs", default=3, type=int, help="Total number of training epochs to perform.")
44+
parser.add_argument("--eval_step", default=100, type=int, help="Step interval for evaluation.")
45+
parser.add_argument('--save_step', default=10000, type=int, help="Step interval for saving checkpoint.")
4546
parser.add_argument("--warmup_proportion", default=0.0, type=float, help="Linear warmup proption over the training process.")
4647
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
47-
parser.add_argument("--seed", type=int, default=1000, help="random seed for initialization")
48+
parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization.")
4849
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
4950
args = parser.parse_args()
5051
# yapf: enable
@@ -196,12 +197,12 @@ def do_train():
196197
optimizer.clear_grad()
197198

198199
if global_step % args.eval_step == 0 and rank == 0:
200+
evaluate(model, metric, dev_data_loader, "dev")
201+
202+
if global_step % args.save_step == 0 and rank == 0:
199203
save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
200204
if not os.path.exists(save_dir):
201205
os.makedirs(save_dir)
202-
203-
evaluate(model, metric, dev_data_loader, "dev")
204-
205206
save_param_path = os.path.join(save_dir, 'model_state.pdparams')
206207
paddle.save(model.state_dict(), save_param_path)
207208
tokenizer.save_pretrained(save_dir)

examples/text_matching/ernie_matching/train_pointwise.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@
4040
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
4141
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
4242
parser.add_argument("--epochs", default=3, type=int, help="Total number of training epochs to perform.")
43+
parser.add_argument("--eval_step", default=100, type=int, help="Step interval for evaluation.")
44+
parser.add_argument('--save_step', default=10000, type=int, help="Step interval for saving checkpoint.")
4345
parser.add_argument("--warmup_proportion", default=0.0, type=float, help="Linear warmup proption over the training process.")
4446
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
45-
parser.add_argument("--seed", type=int, default=1000, help="random seed for initialization")
47+
parser.add_argument("--seed", type=int, default=1000, help="Random seed for initialization.")
4648
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
4749
args = parser.parse_args()
4850
# yapf: enable
@@ -177,12 +179,14 @@ def do_train():
177179
optimizer.step()
178180
lr_scheduler.step()
179181
optimizer.clear_grad()
180-
if global_step % 100 == 0 and rank == 0:
182+
183+
if global_step % args.eval_step == 0 and rank == 0:
184+
evaluate(model, criterion, metric, dev_data_loader)
185+
186+
if global_step % args.save_step == 0 and rank == 0:
181187
save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
182188
if not os.path.exists(save_dir):
183189
os.makedirs(save_dir)
184-
evaluate(model, criterion, metric, dev_data_loader)
185-
186190
save_param_path = os.path.join(save_dir, 'model_state.pdparams')
187191
paddle.save(model.state_dict(), save_param_path)
188192
tokenizer.save_pretrained(save_dir)

0 commit comments

Comments
 (0)