|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import contextlib |
| 15 | +import itertools |
15 | 16 | import math |
16 | 17 | import time |
17 | 18 | from collections import OrderedDict |
18 | 19 | from dataclasses import dataclass |
| 20 | +from itertools import islice |
19 | 21 | from typing import Any, Optional, Union |
20 | 22 |
|
21 | 23 | import torch |
@@ -94,6 +96,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s |
94 | 96 | self._batches_that_stepped: int = 0 |
95 | 97 | self._restart_stage = RestartStage.NONE |
96 | 98 | self._skip_next_val = False |
| 99 | + self._num_global_valid_tokens: Optional[int] = None |
97 | 100 |
|
98 | 101 | @property |
99 | 102 | def total_batch_idx(self) -> int: |
@@ -278,6 +281,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None: |
278 | 281 | StopIteration: When the epoch is canceled by the user returning -1 |
279 | 282 |
|
280 | 283 | """ |
| 284 | + # create a second view of the iterator with itertools.tee so we can peek ahead without consuming the original data_fetcher |
| 285 | + it1, self._peekable_iter = itertools.tee(data_fetcher.iterator) |
| 286 | + data_fetcher.iterator = it1 |
| 287 | + |
281 | 288 | if self.restarting and self._should_check_val_fx(data_fetcher): |
282 | 289 | if self.val_loop.restarted_mid_evaluation: |
283 | 290 | # Go back and finish running validation |
@@ -346,6 +353,38 @@ def advance(self, data_fetcher: _DataFetcher) -> None: |
346 | 353 | if not using_dataloader_iter |
347 | 354 | else OrderedDict(any=dataloader_iter) |
348 | 355 | ) |
| 356 | + |
| 357 | + # Count valid tokens across the global batch when using gradient accumulation with a cross-entropy loss |
| 358 | + # Only compute this on the first batch of each accumulation window, then reuse it for the rest of the window |
| 359 | + if ( |
| 360 | + trainer.lightning_module.automatic_optimization |
| 361 | + and trainer.accumulate_grad_batches > 1 |
| 362 | + and batch_idx % trainer.accumulate_grad_batches == 0 |
| 363 | + ): |
| 364 | + # require every batch in the accumulation window to be properly formatted (a dict carrying labels) |
| 365 | + total_valid_tokens = 0 |
| 366 | + all_formatted_batches = True |
| 367 | + # Peek at the next accumulate_grad_batches batches without consuming the original data_fetcher |
| 368 | + peek_batches = list(islice(self._peekable_iter, trainer.accumulate_grad_batches)) |
| 369 | + for batch in peek_batches: |
| 370 | + # unwrap Lightning's list/tuple wrapper |
| 371 | + if isinstance(batch, (list, tuple)): |
| 372 | + batch = batch[0] |
| 373 | + # require the batch to be a dict that has labels; otherwise abandon the count |
| 374 | + if not isinstance(batch, dict): |
| 375 | + all_formatted_batches = False |
| 376 | + break |
| 377 | + labels = batch.get("labels") |
| 378 | + # break if labels are missing or None |
| 379 | + if labels is None: |
| 380 | + all_formatted_batches = False |
| 381 | + break |
| 382 | + # safe to count: only tokens whose label is not the ignore index (-100) contribute |
| 383 | + labels = torch.as_tensor(labels) |
| 384 | + total_valid_tokens += int((labels != -100).sum().item()) |
| 385 | + self._num_global_valid_tokens = total_valid_tokens if all_formatted_batches else None |
| 386 | + |
| 387 | + kwargs["num_global_valid_tokens"] = self._num_global_valid_tokens |
349 | 388 | with trainer.profiler.profile("run_training_batch"): |
350 | 389 | if trainer.lightning_module.automatic_optimization: |
351 | 390 | # in automatic optimization, there can only be one optimizer |
|
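For readers following the diff, here is a minimal standalone sketch of the peek-and-count logic added above: it uses `itertools.tee` to look ahead at the next accumulation window without consuming the training iterator, and sums the labels that differ from the cross-entropy ignore index (-100). The helper name `count_global_valid_tokens` and the toy batches are illustrative assumptions, not part of the Lightning API.

```python
# Minimal, standalone sketch (assumption: the helper name and the toy batches
# below are illustrative only and are not part of Lightning's API).
import itertools
from typing import Iterator, Optional, Tuple

import torch

IGNORE_INDEX = -100  # labels equal to this value do not contribute to the cross-entropy loss


def count_global_valid_tokens(iterator: Iterator, accumulate_grad_batches: int) -> Tuple[Iterator, Optional[int]]:
    """Peek at the next ``accumulate_grad_batches`` batches and count labels != IGNORE_INDEX.

    Returns a replacement iterator (so the caller loses no batches) plus the token count,
    or ``None`` when any peeked batch is not a dict carrying a ``labels`` entry.
    """
    # tee() yields two independent views over the same underlying iterator:
    # the caller keeps ``main``; ``peek`` is consumed for look-ahead only.
    main, peek = itertools.tee(iterator)

    total_valid_tokens = 0
    for batch in itertools.islice(peek, accumulate_grad_batches):
        if isinstance(batch, (list, tuple)):
            batch = batch[0]  # unwrap list/tuple wrappers around the batch dict
        labels = batch.get("labels") if isinstance(batch, dict) else None
        if labels is None:
            return main, None  # window is not uniformly formatted; skip the count
        total_valid_tokens += int((torch.as_tensor(labels) != IGNORE_INDEX).sum().item())
    return main, total_valid_tokens


# Toy usage: two micro-batches forming one accumulation window.
batches = iter([
    {"labels": torch.tensor([[1, 2, IGNORE_INDEX]])},
    {"labels": torch.tensor([[IGNORE_INDEX, 5, 6]])},
])
batches, num_valid = count_global_valid_tokens(batches, accumulate_grad_batches=2)
print(num_valid)           # 4 valid tokens across the window
print(len(list(batches)))  # 2 -- the original micro-batches are still available
```

Because `tee` buffers items that one branch has consumed and the other has not, peeking an entire window keeps at most `accumulate_grad_batches` micro-batches in memory, which is why the diff only performs the count on the first batch of each accumulation window and caches the result.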