
Commit ac9dbeb

Author: tianxin
EFL: Fix save_steps not working (#577)
* "add mlm params to dygraph ernie1.0" * finish p-tuning v1.0 * mend * delete unused coment * add label_normalized * P-tuning: support Chid task of FewCLUE * 1. decouple evaluate and train * 1.add FewCLUE datasets(9/9) 2.implement p-tuning strategy by transform_function 3.unify train_script beteween `chid` task and other 8 tasks of FewCLUE * add README.md * update FewCLUE data * add predict.py for FewCLUE * update README * update README.md * add FewCLUE 9 datasets * update dataset Name * Tiny fix * finish EFL for iflytek task * finish EFL for FewCLUE 9 tasks * add EFL README.md * fix save_steps have no effect bug
1 parent: 9c81a4a · commit: ac9dbeb

File tree: 1 file changed (+9, -1 lines)


examples/few_shot/efl/train.py

Lines changed: 9 additions & 1 deletion
@@ -232,6 +232,14 @@ def do_train():
                        10 / (time.time() - tic_train)))
                 tic_train = time.time()
 
+            if global_step % args.save_steps == 0 and rank == 0:
+                save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
+                if not os.path.exists(save_dir):
+                    os.makedirs(save_dir)
+                save_param_path = os.path.join(save_dir, 'model_state.pdparams')
+                paddle.save(model.state_dict(), save_param_path)
+                tokenizer.save_pretrained(save_dir)
+
             loss.backward()
             optimizer.step()
             lr_scheduler.step()

@@ -268,4 +276,4 @@ def do_train():
 
 if __name__ == "__main__":
     args = parse_args()
-    do_train()
+    do_train()
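
For context: before this commit the training loop never saved checkpoints, so the parsed --save_steps argument had no effect; the added block above writes a checkpoint every save_steps steps. Below is a minimal sketch of how that pattern sits inside a Paddle training loop. It is illustrative only: the function name, args fields (epochs, save_steps, save_dir), and the assumption that the model returns logits for cross_entropy are stand-ins, not the actual train.py.

import os
import time

import paddle


def training_loop(model, tokenizer, optimizer, lr_scheduler, train_data_loader, args):
    # Hypothetical loop mirroring the example script; names are assumptions.
    rank = paddle.distributed.get_rank()  # only rank 0 writes checkpoints
    global_step = 0
    tic_train = time.time()

    for epoch in range(args.epochs):
        for step, batch in enumerate(train_data_loader, start=1):
            input_ids, token_type_ids, labels = batch
            logits = model(input_ids, token_type_ids)
            loss = paddle.nn.functional.cross_entropy(logits, labels)

            global_step += 1
            if global_step % 10 == 0 and rank == 0:
                print("global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"
                      % (global_step, epoch, step, float(loss),
                         10 / (time.time() - tic_train)))
                tic_train = time.time()

            # The fix: honor --save_steps by checkpointing every N steps.
            if global_step % args.save_steps == 0 and rank == 0:
                save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                # Save model weights and the tokenizer files side by side.
                save_param_path = os.path.join(save_dir, "model_state.pdparams")
                paddle.save(model.state_dict(), save_param_path)
                tokenizer.save_pretrained(save_dir)

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()

Guarding the save with rank == 0 keeps multiple processes in distributed training from writing the same files; a saved checkpoint can later be restored with paddle.load(save_param_path) followed by model.set_state_dict(...).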
