We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9c81a4a commit ac9dbebCopy full SHA for ac9dbeb
examples/few_shot/efl/train.py
@@ -232,6 +232,14 @@ def do_train():
232
10 / (time.time() - tic_train)))
233
tic_train = time.time()
234
235
+ if global_step % args.save_steps == 0 and rank == 0:
236
+ save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
237
+ if not os.path.exists(save_dir):
238
+ os.makedirs(save_dir)
239
+ save_param_path = os.path.join(save_dir, 'model_state.pdparams')
240
+ paddle.save(model.state_dict(), save_param_path)
241
+ tokenizer.save_pretrained(save_dir)
242
+
243
loss.backward()
244
optimizer.step()
245
lr_scheduler.step()
@@ -268,4 +276,4 @@ def do_train():
268
276
269
277
if __name__ == "__main__":
270
278
args = parse_args()
271
- do_train()
279
+ do_train()
0 commit comments