diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py
index e19b5761c4d4b..cb942a99d0988 100644
--- a/src/lightning/pytorch/loops/optimization/automatic.py
+++ b/src/lightning/pytorch/loops/optimization/automatic.py
@@ -59,7 +59,9 @@ def _clone_loss(self) -> None:
             self.loss = self.closure_loss.detach().clone()
 
     @classmethod
-    def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize: int = 1) -> "ClosureResult":
+    def from_training_step_output(
+        cls, training_step_output: STEP_OUTPUT, normalize: int = 1, num_global_valid_tokens: Optional[int] = None
+    ) -> "ClosureResult":
         closure_loss, extra = None, {}
 
         if isinstance(training_step_output, Mapping):
@@ -80,7 +82,10 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize:
         if closure_loss is not None:
             # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
             # note: avoid in-place operation `x /= y` here on purpose
-            closure_loss = closure_loss / normalize
+            if num_global_valid_tokens is not None:
+                closure_loss = closure_loss / num_global_valid_tokens
+            elif normalize > 1:
+                closure_loss = closure_loss / normalize
 
         return cls(closure_loss, extra=extra)
 
@@ -315,6 +320,7 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
 
         """
         trainer = self.trainer
+        num_global_valid_tokens = kwargs.pop("num_global_valid_tokens", None)
 
         training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
         self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
@@ -326,4 +332,6 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
                 " place."
             )
 
-        return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
+        return self.output_result_cls.from_training_step_output(
+            training_step_output, trainer.accumulate_grad_batches, num_global_valid_tokens=num_global_valid_tokens
+        )
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index 6212bfe264e6e..710dde1442cb4 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+import itertools
 import math
 import time
 from collections import OrderedDict
 from dataclasses import dataclass
+from itertools import islice
 from typing import Any, Optional, Union
 
 import torch
@@ -94,6 +96,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
         self._batches_that_stepped: int = 0
         self._restart_stage = RestartStage.NONE
         self._skip_next_val = False
+        self._num_global_valid_tokens: Optional[int] = None
 
     @property
     def total_batch_idx(self) -> int:
@@ -278,6 +281,12 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
             StopIteration: When the epoch is canceled by the user returning -1
 
         """
+        # create a peekable iterator to look ahead without consuming the original data_fetcher
+        iterator = data_fetcher.iterator
+        assert iterator is not None
+        it1, self._peekable_iter = itertools.tee(iterator)
+        data_fetcher.iterator = it1
+
         if self.restarting and self._should_check_val_fx(data_fetcher):
             if self.val_loop.restarted_mid_evaluation:
                 # Go back and finish running validation
@@ -346,6 +355,38 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
                 if not using_dataloader_iter
                 else OrderedDict(any=dataloader_iter)
             )
+
+            # Count valid tokens across the global batch when using grad accumulation with a cross-entropy loss
+            # Only calculate at the first batch of the accumulation window and then reuse
+            if (
+                trainer.lightning_module.automatic_optimization
+                and trainer.accumulate_grad_batches > 1
+                and batch_idx % trainer.accumulate_grad_batches == 0
+            ):
+                # require all batches in the accumulation window to be properly formatted
+                total_valid_tokens = 0
+                all_formatted_batches = True
+                # Take the next N batches without consuming the original data_fetcher
+                peek_batches = list(islice(self._peekable_iter, trainer.accumulate_grad_batches))
+                for batch in peek_batches:
+                    # unwrap Lightning's list/tuple wrapper
+                    if isinstance(batch, (list, tuple)):
+                        batch = batch[0]
+                    # require the batch to be a dict with labels, otherwise break
+                    if not isinstance(batch, dict):
+                        all_formatted_batches = False
+                        break
+                    labels = batch.get("labels")
+                    # break if labels are missing or None
+                    if labels is None:
+                        all_formatted_batches = False
+                        break
+                    # safe to process
+                    labels = torch.as_tensor(labels)
+                    total_valid_tokens += int((labels != -100).sum().item())
+                self._num_global_valid_tokens = total_valid_tokens if all_formatted_batches else None
+
+            kwargs["num_global_valid_tokens"] = self._num_global_valid_tokens
             with trainer.profiler.profile("run_training_batch"):
                 if trainer.lightning_module.automatic_optimization:
                     # in automatic optimization, there can only be one optimizer
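
Below is a minimal sketch (not part of the patch) of a `training_step` that pairs with the new normalization path: the loss is returned *sum*-reduced over valid tokens, so dividing by `num_global_valid_tokens` in `ClosureResult.from_training_step_output` produces a true per-token mean across the whole accumulation window; with a mean-reduced loss the division would normalize twice. The module name, batch layout (a dict with `input_ids` and `labels`, padding marked `-100`), and optimizer are illustrative assumptions, not something the patch prescribes.

```python
import torch
import torch.nn.functional as F
import lightning.pytorch as pl


class TokenSumLossModule(pl.LightningModule):
    """Hypothetical module returning a sum-reduced token-level cross-entropy loss."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model  # any model mapping input_ids (B, T) -> logits (B, T, V)

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        logits = self.model(batch["input_ids"])
        # sum over valid tokens only; positions labelled -100 are ignored,
        # matching the (labels != -100) count in the training loop above
        loss = F.cross_entropy(
            logits.flatten(0, 1),
            batch["labels"].flatten(),
            ignore_index=-100,
            reduction="sum",
        )
        # the loop divides this sum by the total number of valid tokens
        # peeked across the accumulation window
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)
```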