Commit 075dbbc

hutaiHang, taihang, and qgallouedec authored
fix(trainer): Correct loss scaling for incomplete gradient accumulation steps (#39659)
* Fix issue #38837: wrong loss scaling in the last step of an epoch

* chore: trigger CI

* Update src/transformers/trainer.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update src/transformers/modeling_flash_attention_utils.py

Co-authored-by: Quentin Gallouédec <[email protected]>

---------

Co-authored-by: taihang <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 1d06153 commit 075dbbc
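For context on the fix: when the number of batches in the training loop is not a multiple of gradient_accumulation_steps, the final accumulation cycle contains only the remainder of the batches, yet each loss was still divided by the full gradient_accumulation_steps, under-scaling the loss for that cycle. A minimal standalone sketch of the arithmetic (hypothetical loss values, not Trainer code):

# Illustrative only: hypothetical micro-batch losses for the last, shorter cycle.
gradient_accumulation_steps = 4
last_cycle_losses = [2.0, 2.0, 2.0]  # the remainder cycle has only 3 micro-batches

# Previous behaviour: always divide by the configured number of accumulation steps.
old_scaled = sum(loss / gradient_accumulation_steps for loss in last_cycle_losses)
print(old_scaled)  # 1.5 -- underestimates the true mean loss of 2.0

# Behaviour after this commit: divide by the number of batches actually in the cycle.
current_gradient_accumulation_steps = len(last_cycle_losses)
new_scaled = sum(loss / current_gradient_accumulation_steps for loss in last_cycle_losses)
print(new_scaled)  # 2.0 -- matches the true mean loss for the cycle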

File tree

1 file changed: +5 / -1 lines changed

src/transformers/trainer.py

Lines changed: 5 additions & 1 deletion
@@ -2530,6 +2530,9 @@ def _inner_training_loop(
                 update_step += 1
                 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
                 batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
+                # Store the number of batches for current gradient accumulation
+                # This is used to correctly scale the loss when the last accumulation step has fewer batches
+                self.current_gradient_accumulation_steps = len(batch_samples)
                 for i, inputs in enumerate(batch_samples):
                     step += 1
                     do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
@@ -3830,7 +3833,8 @@ def training_step(
         else:
             # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss
             if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
-                loss = loss / self.args.gradient_accumulation_steps
+                # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
+                loss = loss / self.current_gradient_accumulation_steps
 
         # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
         # https://github.com/huggingface/transformers/pull/35808
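For readers who want the same pattern outside of Trainer internals, here is a minimal, self-contained gradient-accumulation loop in plain PyTorch (the model, data, and hyperparameters are made up for illustration) that applies the idea of the patch: scale each micro-batch loss by the size of the current accumulation cycle rather than by the configured gradient_accumulation_steps.

import torch
from torch import nn

# Hypothetical toy setup; only the loss-scaling logic mirrors the patch above.
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

gradient_accumulation_steps = 4
batches = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(10)]  # 10 is not a multiple of 4

for start in range(0, len(batches), gradient_accumulation_steps):
    cycle = batches[start : start + gradient_accumulation_steps]
    # Counterpart of self.current_gradient_accumulation_steps: the real cycle size,
    # which is smaller than gradient_accumulation_steps on the last cycle (2 here).
    current_gradient_accumulation_steps = len(cycle)

    optimizer.zero_grad()
    for inputs, targets in cycle:
        loss = loss_fn(model(inputs), targets)
        # Scale by the actual number of micro-batches in this cycle before accumulating gradients.
        (loss / current_gradient_accumulation_steps).backward()
    optimizer.step()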
