Commit 385fd56

Introduce a peekable iterator to count the number of valid tokens in the global batch at the start of the accumulation window, i.e. in the first micro-batch.
1 parent 8f702b3 commit 385fd56
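
For illustration only, a minimal self-contained sketch of the peek-ahead pattern this commit relies on. The in-memory micro_batches list, the accumulate_grad_batches value, and the -100 ignore index below are assumptions standing in for the real data fetcher and label convention:

    import itertools
    from itertools import islice

    import torch

    # Hypothetical stand-in for the batches the data fetcher would yield.
    micro_batches = [
        {"input_ids": torch.tensor([[1, 2, 3]]), "labels": torch.tensor([[1, 2, -100]])},
        {"input_ids": torch.tensor([[4, 5, 6]]), "labels": torch.tensor([[-100, 5, 6]])},
    ]
    accumulate_grad_batches = 2

    # tee() yields two independent iterators over the same stream: one replaces the
    # original, the other is consumed only by the look-ahead.
    it1, peekable = itertools.tee(iter(micro_batches))
    data_iter = it1  # training keeps reading from this one

    # Count labels that are not the ignore index (-100) across the accumulation window.
    num_global_valid_tokens = sum(
        int((batch["labels"] != -100).sum())
        for batch in islice(peekable, accumulate_grad_batches)
    )
    print(num_global_valid_tokens)        # 4
    print(next(data_iter)["labels"])      # the original stream is untouched

Note that itertools.tee buffers items consumed by one iterator but not yet by the other, so the look-ahead holds at most accumulate_grad_batches micro-batches in memory until the training iterator catches up.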

File tree

1 file changed: +39 −0 lines changed


src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 39 additions & 0 deletions
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+import itertools
 import math
 import time
 from collections import OrderedDict
 from dataclasses import dataclass
+from itertools import islice
 from typing import Any, Optional, Union
 
 import torch
@@ -94,6 +96,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
         self._batches_that_stepped: int = 0
         self._restart_stage = RestartStage.NONE
         self._skip_next_val = False
+        self._num_global_valid_tokens: Optional[int] = None
 
     @property
     def total_batch_idx(self) -> int:
@@ -278,6 +281,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
             StopIteration: When the epoch is canceled by the user returning -1
 
         """
+        # create a peekable iterator to look ahead without consuming the original data_fetcher
+        it1, self._peekable_iter = itertools.tee(data_fetcher.iterator)
+        data_fetcher.iterator = it1
+
         if self.restarting and self._should_check_val_fx(data_fetcher):
            if self.val_loop.restarted_mid_evaluation:
                # Go back and finish running validation
@@ -346,6 +353,38 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
                 if not using_dataloader_iter
                 else OrderedDict(any=dataloader_iter)
             )
+
+            # Count valid tokens across the global batch when using gradient accumulation with a cross-entropy loss.
+            # Only calculate this at the first batch of the accumulation window, then reuse it.
+            if (
+                trainer.lightning_module.automatic_optimization
+                and trainer.accumulate_grad_batches > 1
+                and batch_idx % trainer.accumulate_grad_batches == 0
+            ):
+                # require all batches in the accumulation window to be properly formatted
+                total_valid_tokens = 0
+                all_formatted_batches = True
+                # take the next N batches without consuming the original data_fetcher
+                peek_batches = list(islice(self._peekable_iter, trainer.accumulate_grad_batches))
+                for batch in peek_batches:
+                    # unwrap Lightning's list/tuple wrapper
+                    if isinstance(batch, (list, tuple)):
+                        batch = batch[0]
+                    # require the batch to be a dict that carries labels, otherwise stop counting
+                    if not isinstance(batch, dict):
+                        all_formatted_batches = False
+                        break
+                    labels = batch.get("labels")
+                    # stop counting if labels are missing or None
+                    if labels is None:
+                        all_formatted_batches = False
+                        break
+                    # safe to process
+                    labels = torch.as_tensor(labels)
+                    total_valid_tokens += int((labels != -100).sum().item())
+                self._num_global_valid_tokens = total_valid_tokens if all_formatted_batches else None
+
+            kwargs["num_global_valid_tokens"] = self._num_global_valid_tokens
             with trainer.profiler.profile("run_training_batch"):
                 if trainer.lightning_module.automatic_optimization:
                     # in automatic optimization, there can only be one optimizer
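
To make the intent of the counted value concrete, here is a hedged sketch, not part of the commit, of how a training step could use a global valid-token count to normalize a summed cross-entropy loss under gradient accumulation. The micro_batch_loss helper and its arguments are hypothetical and only assume the -100 ignore index used in the diff:

    import torch
    import torch.nn.functional as F

    def micro_batch_loss(
        logits: torch.Tensor, labels: torch.Tensor, num_global_valid_tokens: int
    ) -> torch.Tensor:
        """Sum cross-entropy over this micro-batch, then divide by the global token count."""
        loss_sum = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100,
            reduction="sum",
        )
        # Guard against an empty window; normalize by the count gathered across all
        # micro-batches in the accumulation window rather than this micro-batch alone.
        return loss_sum / max(num_global_valid_tokens, 1)

Dividing each micro-batch's summed loss by the global count, rather than by that micro-batch's own token count, keeps the accumulated gradient equal to the true per-token mean over the whole effective batch even when micro-batches contain different numbers of non-ignored tokens.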
