File tree Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Expand file tree Collapse file tree 2 files changed +4
-3
def get_default_shortenable_ids(self):
    """Return a 0/1 mask over ``self.template`` marking which parts may be shortened.

    For each part ``p`` (a dict) of the template:
      * if the part carries an explicit ``'shortenable'`` key, its truthiness
        decides the flag;
      * otherwise the part is considered shortenable iff it has a ``'text'`` key
        (raw-text parts default to truncatable, special tokens do not).

    Returns:
        list[int]: one entry per template part; 1 means the part may be
        truncated when the input exceeds the length budget, 0 means keep as-is.
    """
    idx = []
    for p in self.template:
        if 'shortenable' in p:
            # Honor the explicit per-part flag.
            # Fix: original read d['shortenable'] — `d` is undefined here
            # (NameError); the loop variable is `p`.
            idx.append(1 if p['shortenable'] else 0)
        else:
            # Default rule: only parts carrying raw text are shortenable.
            idx.append(1 if 'text' in p else 0)
    return idx
Original file line number Diff line number Diff line change @@ -335,6 +335,7 @@ def init_num_steps(args, num_samples_per_epoch):
335
335
num_samples_per_epoch % args .train_batch_size > 0 )
336
336
num_update_steps_per_epoch //= args .gradient_accumulation_steps
337
337
num_update_steps_per_epoch = max (num_update_steps_per_epoch , 1 )
338
+ args .num_update_steps_per_epoch = num_update_steps_per_epoch
338
339
339
340
if args .max_steps > 0 :
340
341
args .num_training_steps = args .max_steps
@@ -447,10 +448,10 @@ def train(
447
448
os .path .join (resume_from_checkpoint , TRAINER_STATE_NAME )):
448
449
self .state = TrainerState .load_from_json (
449
450
os .path .join (resume_from_checkpoint , TRAINER_STATE_NAME ))
450
- epochs_trained = self .state .global_step // num_update_steps_per_epoch
451
+ epochs_trained = self .state .global_step // args . num_update_steps_per_epoch
451
452
if not args .ignore_data_skip :
452
453
steps_trained_in_current_epoch = self .state .global_step % (
453
- num_update_steps_per_epoch )
454
+ args . num_update_steps_per_epoch )
454
455
steps_trained_in_current_epoch *= args .gradient_accumulation_steps
455
456
else :
456
457
steps_trained_in_current_epoch = 0
You can’t perform that action at this time.
0 commit comments