Gradient accumulation fix in cross entropy loss #21386
base: master
Changes from 5 commits: 385fd56, 9b7aa6f, 7f5f88c, 95d467d, fb7dbc8, 01fcf62
```diff
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+import itertools
 import math
 import time
 from collections import OrderedDict
 from dataclasses import dataclass
+from itertools import islice
 from typing import Any, Optional, Union

 import torch
```
|
|
```diff
@@ -94,6 +96,7 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
         self._batches_that_stepped: int = 0
         self._restart_stage = RestartStage.NONE
         self._skip_next_val = False
+        self._num_global_valid_tokens: Optional[int] = None

     @property
     def total_batch_idx(self) -> int:
```
|
|
```diff
@@ -278,6 +281,12 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
             StopIteration: When the epoch is canceled by the user returning -1

         """
+        # create a peekable iterator to look ahead without consuming the original data_fetcher
+        iterator = data_fetcher.iterator
+        assert iterator is not None
+        it1, self._peekable_iter = itertools.tee(iterator)
+        data_fetcher.iterator = it1
+
         if self.restarting and self._should_check_val_fx(data_fetcher):
             if self.val_loop.restarted_mid_evaluation:
                 # Go back and finish running validation
```
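As an aside, here is a minimal sketch of the lookahead pattern used in this hunk, with a plain `range` iterator standing in for the data fetcher's iterator (variable names are illustrative only): `itertools.tee` yields two independent iterators over the same stream, so one can replace the original while the other is used purely for peeking.

```python
import itertools
from itertools import islice

# Toy stand-in for the data fetcher's underlying iterator (illustrative only).
stream = iter(range(10))

# tee() returns two independent iterators over the same underlying stream.
main_iter, peek_iter = itertools.tee(stream)

# Peek at the next 4 items without advancing main_iter.
peeked = list(islice(peek_iter, 4))
print(peeked)           # [0, 1, 2, 3]
print(next(main_iter))  # 0 -- the main iterator is unaffected by the peek
```

Note that `tee` buffers whatever one iterator has consumed ahead of the other, so peeking `accumulate_grad_batches` batches ahead keeps up to that many batches buffered until the main iterator catches up.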
|
|
```diff
@@ -346,6 +355,38 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
                 if not using_dataloader_iter
                 else OrderedDict(any=dataloader_iter)
             )
+
+            # Count valid tokens across global batch when using grad accumulation when using cross entropy loss
+            # Only calculate at the first batch of accumulation window and then reuse
+            if (
+                trainer.lightning_module.automatic_optimization
+                and trainer.accumulate_grad_batches > 1
+                and batch_idx % trainer.accumulate_grad_batches == 0
+            ):
+                # require all batches in accumulation window to be properly formatted
+                total_valid_tokens = 0
+                all_formatted_batches = True
+                # Take next N batches without consuming the original data_fetcher
+                peek_batches = list(islice(self._peekable_iter, trainer.accumulate_grad_batches))
+                for batch in peek_batches:
+                    # unwrap Lightning's list/tuple wrapper
+                    if isinstance(batch, (list, tuple)):
+                        batch = batch[0]
+                    # require batch to be instance of dict and has labels, otherwise break
+                    if not isinstance(batch, dict):
+                        all_formatted_batches = False
+                        break
+                    labels = batch.get("labels")
```
|
Collaborator: So the core assumption to get this working is that the user has formatted their batches such that each batch has a labels tensor?

Contributor (Author): Yes, this will be listed in the relevant docs (once the approach is confirmed), and it's the same as
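For context, a rough sketch of the batch format this assumes: each batch is a dict carrying a `labels` tensor in which ignored positions are set to `-100`, the default `ignore_index` of `torch.nn.CrossEntropyLoss`. The collate function and field names below are illustrative, not part of this PR.

```python
import torch

def collate(examples, pad_token_id=0):
    # Hypothetical collate_fn producing the dict-style batch assumed above:
    # padded "input_ids" plus a "labels" tensor where padded positions are -100.
    max_len = max(len(example) for example in examples)
    input_ids = torch.full((len(examples), max_len), pad_token_id, dtype=torch.long)
    labels = torch.full((len(examples), max_len), -100, dtype=torch.long)
    for i, example in enumerate(examples):
        tokens = torch.tensor(example, dtype=torch.long)
        input_ids[i, : len(example)] = tokens
        labels[i, : len(example)] = tokens
    return {"input_ids": input_ids, "labels": labels}

batch = collate([[5, 6, 7], [8, 9]])
print((batch["labels"] != -100).sum().item())  # 5 valid tokens in this batch
```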
```diff
+                    # break if labels missing or None
+                    if labels is None:
+                        all_formatted_batches = False
+                        break
+                    # safe to process
+                    labels = torch.as_tensor(labels)
+                    total_valid_tokens += int((labels != -100).sum().item())
+                self._num_global_valid_tokens = total_valid_tokens if all_formatted_batches else None
```
|
Comment on lines +386 to +387

Collaborator: I do wonder if this is such a special case that it is better for the user to provide the information compared to Lightning trying to calculate it. For example, what if the masking token is not -100?

Contributor (Author): For this, I thought that if a user is passing
```diff
+
+            kwargs["num_global_valid_tokens"] = self._num_global_valid_tokens
             with trainer.profiler.profile("run_training_batch"):
                 if trainer.lightning_module.automatic_optimization:
                     # in automatic optimization, there can only be one optimizer
```
This definitely seems like the smart way to implement this kind of solution.
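To make the intended usage concrete, here is a hedged sketch of a `training_step` that consumes such a count: it sums the per-token cross entropy and divides by the number of valid tokens in the whole accumulation window rather than taking a per-microbatch mean. The `num_global_valid_tokens` parameter is an assumption about how this PR surfaces the value to `training_step`, and the model and batch fields are placeholders, not an established Lightning API.

```python
import torch
import torch.nn.functional as F
import lightning.pytorch as pl


class CausalLMModule(pl.LightningModule):
    # Hypothetical module; the extra `num_global_valid_tokens` argument is an
    # assumption about how the count injected by this PR reaches training_step.
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx, num_global_valid_tokens=None):
        logits = self.model(batch["input_ids"])  # (batch, seq_len, vocab)
        labels = batch["labels"]                 # (batch, seq_len), -100 marks ignored tokens
        # Sum the per-token losses instead of averaging within the microbatch.
        loss_sum = F.cross_entropy(
            logits.flatten(0, 1), labels.flatten(), ignore_index=-100, reduction="sum"
        )
        if num_global_valid_tokens:
            # Divide by the valid-token count of the whole accumulation window so every
            # token contributes equally, however tokens are split across microbatches.
            return loss_sum / num_global_valid_tokens
        # Fallback: ordinary mean over this microbatch's valid tokens.
        return loss_sum / (labels != -100).sum().clamp(min=1)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)
```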