
Commit d3783ec

kibitzing authored and muellerzr committed
Fix step shifting when accumulate gradient (huggingface#33673)
* replace total_batched_samples with step while counting grad accum step
* remove unused variable
* simplify condition for update step
* fix format by ruff
* simplify update step condition using accelerator.sync_gradients
* simplify update condition using do_sync_step
* remove print for test

---------

Co-authored-by: Zach Mueller <[email protected]>
1 parent 1d7cc16 commit d3783ec
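
For context, here is a minimal standalone sketch (not part of the commit; the counts below are invented) of how the removed cumulative counter total_batched_samples can shift the gradient-accumulation sync points from one epoch to the next when the epoch length is not a multiple of gradient_accumulation_steps:

# Toy illustration, not from the commit: the numbers are made up.
gradient_accumulation_steps = 4
steps_in_epoch = 10  # not a multiple of gradient_accumulation_steps

total_batched_samples = 0  # old counter: never reset between epochs
for epoch in range(2):
    for step in range(steps_in_epoch):
        total_batched_samples += 1
        # old condition: the epoch-boundary special case only covered
        # epochs shorter than one accumulation window
        is_last_step_and_steps_less_than_grad_acc = (
            steps_in_epoch <= gradient_accumulation_steps and (step + 1) == steps_in_epoch
        )
        do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
            total_batched_samples % gradient_accumulation_steps == 0
        )
        if do_sync_step:
            print(f"epoch {epoch}, step {step}: optimizer step")
# epoch 0 syncs at steps 3 and 7; epoch 1 syncs at steps 1, 5 and 9,
# so the sync points drift across epochs.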


src/transformers/trainer.py

Lines changed: 1 addition & 8 deletions
@@ -2404,7 +2404,6 @@ def _inner_training_loop(
         if args.eval_on_start:
             self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
 
-        total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_dataloader = train_dataloader
             if hasattr(epoch_dataloader, "set_epoch"):
@@ -2447,13 +2446,7 @@ def _inner_training_loop(
                 batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
                 for inputs in batch_samples:
                     step += 1
-                    total_batched_samples += 1
-                    is_last_step_and_steps_less_than_grad_acc = (
-                        steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
-                    )
-                    do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
-                        total_batched_samples % args.gradient_accumulation_steps == 0
-                    )
+                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
                     # Since we perform prefetching, we need to manually set sync_gradients
                     if not do_sync_step:
                         self.accelerator.gradient_state._set_sync_gradients(False)
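
For illustration only, a minimal sketch of the simplified condition added above. The surrounding loop and values are invented; only the do_sync_step expression mirrors the new line, and step is assumed to restart each epoch, which the (step + 1) == steps_in_epoch check relies on:

# Toy illustration, not from the commit: same invented numbers as above.
gradient_accumulation_steps = 4
steps_in_epoch = 10

for epoch in range(2):
    step = -1  # assumed to restart at the beginning of each epoch
    for _ in range(steps_in_epoch):
        step += 1
        # mirrors the line added in the diff above
        do_sync_step = (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
        if do_sync_step:
            print(f"epoch {epoch}, step {step}: optimizer step")
# every epoch now syncs at steps 3, 7 and 9, including a final partial
# accumulation window at the end of the epoch, so nothing carries over.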
