1 file changed, 5 insertions(+), 1 deletion(-)

@@ -2530,6 +2530,9 @@ def _inner_training_loop(
     update_step += 1
     num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
     batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
+    # Store the number of batches in the current gradient accumulation cycle;
+    # this is used to scale the loss correctly when the last accumulation step has fewer batches
+    self.current_gradient_accumulation_steps = len(batch_samples)
     for i, inputs in enumerate(batch_samples):
         step += 1
         do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
@@ -3830,7 +3833,8 @@ def training_step(
     else:
         # Finally we need to normalize the loss for reporting if the GA loss bug is not fixed during compute loss
         if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
-            loss = loss / self.args.gradient_accumulation_steps
+            # If the model does not accept loss kwargs, normalize the loss by the number of batches in the current accumulation cycle
+            loss = loss / self.current_gradient_accumulation_steps

         # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
         # https://github.com/huggingface/transformers/pull/35808
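
Why the divisor changes: the last accumulation cycle of an epoch can contain fewer micro-batches than `gradient_accumulation_steps`, so dividing by the configured value mis-scales the reported loss for that cycle. Below is a minimal sketch of the arithmetic only, using made-up batch counts and a constant per-batch loss; none of these values or variable names come from the Trainer itself.

```python
# Toy illustration of the loss-scaling arithmetic (not the Trainer code).
gradient_accumulation_steps = 8
total_batches = 20  # not a multiple of 8, so the last cycle has only 4 batches
per_batch_loss = 2.0

# Split batch indices into accumulation cycles, mimicking how batches are grouped.
cycles = [
    list(range(start, min(start + gradient_accumulation_steps, total_batches)))
    for start in range(0, total_batches, gradient_accumulation_steps)
]

for batch_samples in cycles:
    current_steps = len(batch_samples)  # what the fix stores on the trainer

    # Old behaviour: always divide by the configured accumulation steps.
    old_reported = sum(per_batch_loss / gradient_accumulation_steps for _ in batch_samples)

    # New behaviour: divide by the actual number of batches in this cycle.
    new_reported = sum(per_batch_loss / current_steps for _ in batch_samples)

    print(f"batches={current_steps:2d}  old={old_reported:.3f}  new={new_reported:.3f}")

# For the final 4-batch cycle: old = 4 * 2.0 / 8 = 1.0 (under-reported),
# new = 4 * 2.0 / 4 = 2.0 (the true mean per-batch loss).
```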