Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.5.0] -

### Added

### Changed

- `batch_progress.increment_completed()` is now called immediately before `on_train_batch_end`, `on_validation_batch_end`, and `on_predict_batch_end`, so these hooks see up-to-date progress information for the completed batch and restarts no longer skew the completed-batch counts ([#TODO: add PR link]())

### Removed

### Fixed


## [2.4.0] - 2024-08-06

Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,10 @@ def _evaluation_step(
call._call_callback_hooks(trainer, hook_name, output, *hook_kwargs.values())
call._call_lightning_module_hook(trainer, hook_name, output, *hook_kwargs.values())

trainer._logger_connector.on_batch_end()

self.batch_progress.increment_completed()

trainer._logger_connector.on_batch_end()

if not trainer.sanity_checking:
# indicate the loop has run
self._has_run = True
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/prediction_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,11 @@ def _predict_step(
dataloader_idx = data_fetcher._dataloader_idx
hook_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx if self.num_dataloaders > 1 else None)

self.batch_progress.increment_completed()

call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
call._call_lightning_module_hook(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())

self.batch_progress.increment_completed()

if self._return_predictions or any_on_epoch:
self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu")))

Expand Down
13 changes: 8 additions & 5 deletions src/lightning/pytorch/loops/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,18 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
# update `is_last_batch` again after dataloader_iter was fetched in `training_step()`
self.batch_progress.is_last_batch = data_fetcher.done

# we increment prior to on_batch_end so checkpoints can see progress correctly
# failure to do this will lead to incorrect total batch completed progress counts
self.batch_progress.increment_completed()

if not self._should_accumulate():
# this is increased once per batch disregarding multiple optimizers on purpose for loggers
self._batches_that_stepped += 1

call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
call._call_lightning_module_hook(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
trainer._logger_connector.on_batch_end()

self.batch_progress.increment_completed()

# -----------------------------------------
# SAVE METRICS TO LOGGERS AND PROGRESS_BAR
# -----------------------------------------
Expand Down Expand Up @@ -299,9 +305,6 @@ def on_advance_end(self, data_fetcher: _DataFetcher) -> None:
# update plateau LR scheduler after metrics are logged
self.update_lr_schedulers("step", update_plateau_schedulers=True)

if not self._should_accumulate():
# this is increased once per batch disregarding multiple optimizers on purpose for loggers
self._batches_that_stepped += 1
# this will save based on the `batches_that_stepped` value
self._save_loggers_on_train_batch_end()

Expand Down
20 changes: 12 additions & 8 deletions tests/tests_pytorch/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,10 +570,12 @@ def test_fit_loop_reset(tmp_path):
assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 2
assert epoch_loop.batch_progress.total.processed == 2
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 1 # currents get set to the completed value
assert epoch_loop.batch_progress.current.processed == 1
assert epoch_loop.batch_progress.current.completed == 1
assert epoch_loop.batch_progress.total.completed == 2 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 2 # currents get set to the completed value
assert epoch_loop.batch_progress.current.processed == 2
assert epoch_loop.batch_progress.current.completed == 2

assert epoch_loop._batches_that_stepped == 2

assert optimizer_loop.restarting

Expand All @@ -600,10 +602,12 @@ def test_fit_loop_reset(tmp_path):
assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 4
assert epoch_loop.batch_progress.total.processed == 4
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 3 # currents get set to the completed value
assert epoch_loop.batch_progress.current.processed == 3
assert epoch_loop.batch_progress.current.completed == 3
assert epoch_loop.batch_progress.total.completed == 4 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 4 # currents get set to the completed value
assert epoch_loop.batch_progress.current.processed == 4
assert epoch_loop.batch_progress.current.completed == 4

assert epoch_loop._batches_that_stepped == 4


@pytest.mark.parametrize(
Expand Down
Loading