Skip to content

Commit 2c82f03

Browse files
authored
[trainer] Fix bug when resume_from_checkpoint (#3201)
1 parent caaa102 commit 2c82f03

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

paddlenlp/prompt/template.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ def get_default_shortenable_ids(self):
128128
idx = []
129129
for p in self.template:
130130
if 'shortenable' in p:
131-
idx.append(1 if d['shortenable'] else 0)
131+
idx.append(1 if p['shortenable'] else 0)
132132
else:
133133
idx.append(1 if 'text' in p else 0)
134134
return idx

paddlenlp/trainer/trainer_base.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -335,6 +335,7 @@ def init_num_steps(args, num_samples_per_epoch):
335335
num_samples_per_epoch % args.train_batch_size > 0)
336336
num_update_steps_per_epoch //= args.gradient_accumulation_steps
337337
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
338+
args.num_update_steps_per_epoch = num_update_steps_per_epoch
338339

339340
if args.max_steps > 0:
340341
args.num_training_steps = args.max_steps
@@ -447,10 +448,10 @@ def train(
447448
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):
448449
self.state = TrainerState.load_from_json(
449450
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
450-
epochs_trained = self.state.global_step // num_update_steps_per_epoch
451+
epochs_trained = self.state.global_step // args.num_update_steps_per_epoch
451452
if not args.ignore_data_skip:
452453
steps_trained_in_current_epoch = self.state.global_step % (
453-
num_update_steps_per_epoch)
454+
args.num_update_steps_per_epoch)
454455
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
455456
else:
456457
steps_trained_in_current_epoch = 0

0 commit comments

Comments
 (0)