LR updates are called at the end of a skipped epoch #21307
base: master
Changes from 4 commits
adb8780
4fa5c05
4a713e5
71d2ebd
60a7cd3
10d18d9
024933b
a8d9f8b
222754d
e115b84
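For context on the diff below (the `advance` method of the training epoch loop): returning `-1` from the `on_train_batch_start` hook is how a LightningModule asks Lightning to skip the rest of the current training epoch. A minimal sketch of that pattern follows; the module name and the batch cut-off are made up for illustration and are not part of this PR.

import lightning.pytorch as pl


class SkippingModule(pl.LightningModule):  # hypothetical example module
    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 asks Lightning to skip the remaining batches of the epoch.
        # Previously the loop raised StopIteration immediately, before the
        # epoch-level update_lr_schedulers("epoch", ...) call; the diff below
        # defers it so that call still runs for the skipped epoch.
        if batch_idx >= 3:  # arbitrary cut-off, only to exercise the skip path
            return -1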
@@ -325,30 +325,33 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         trainer._logger_connector.on_batch_start(batch)

         batch_output: _BATCH_OUTPUTS_TYPE = None  # for mypy
+        should_skip_rest_of_epoch = False

         if batch is None and not using_dataloader_iter:
             self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
         else:
             # hook
             call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
             response = call._call_lightning_module_hook(trainer, "on_train_batch_start", batch, batch_idx)
             call._call_strategy_hook(trainer, "on_train_batch_start", batch, batch_idx)
-            if response == -1:
-                self.batch_progress.increment_processed()
-                raise StopIteration
-
-            self.batch_progress.increment_started()
-
-            kwargs = (
-                self._build_kwargs(OrderedDict(), batch, batch_idx)
-                if not using_dataloader_iter
-                else OrderedDict(any=dataloader_iter)
-            )
-            with trainer.profiler.profile("run_training_batch"):
-                if trainer.lightning_module.automatic_optimization:
-                    # in automatic optimization, there can only be one optimizer
-                    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
-                else:
-                    batch_output = self.manual_optimization.run(kwargs)
+            should_skip_rest_of_epoch = response == -1
+            # Signal this is the last batch for the current epoch
+            if should_skip_rest_of_epoch:
+                self.batch_progress.increment_by(0, is_last_batch=True)
Inline review comment on the line above: What is the logic here for changing from …

Reply: I'll do it.
+            else:
+                self.batch_progress.increment_started()
+
+                kwargs = (
+                    self._build_kwargs(OrderedDict(), batch, batch_idx)
+                    if not using_dataloader_iter
+                    else OrderedDict(any=dataloader_iter)
+                )
+                with trainer.profiler.profile("run_training_batch"):
+                    if trainer.lightning_module.automatic_optimization:
+                        # in automatic optimization, there can only be one optimizer
+                        batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
+                    else:
+                        batch_output = self.manual_optimization.run(kwargs)

         self.batch_progress.increment_processed()
@@ -358,6 +361,10 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         if self._num_ready_batches_reached():
             self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

+        if should_skip_rest_of_epoch:
+            # Only raise StopIteration now so that the training epoch loop can finish
+            raise StopIteration
+
         if using_dataloader_iter:
             # update the hook kwargs now that the step method might have consumed the iterator
             batch = data_fetcher._batch
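A rough way to see the effect of deferring the `StopIteration` (this is an illustrative sketch, not a test from this PR; the module, data shapes, and `Trainer` flags are assumptions): with a scheduler configured at `"interval": "epoch"`, the learning rate should now be decayed once per epoch even when every epoch is cut short by the hook.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl


class TinyModule(pl.LightningModule):  # illustrative module, not from the PR
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def on_train_batch_start(self, batch, batch_idx):
        if batch_idx >= 1:
            return -1  # skip the rest of every epoch after its first batch

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=1.0)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
    trainer = pl.Trainer(max_epochs=2, logger=False, enable_checkpointing=False, enable_progress_bar=False)
    trainer.fit(TinyModule(), data)
    # With the fix, update_lr_schedulers("epoch", ...) still runs for the skipped
    # epochs, so StepLR has stepped twice and the LR should read 0.01 here.
    print(trainer.optimizers[0].param_groups[0]["lr"])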