Commit 9445a84

rohitgr7 and carmocca authored
Fix epoch logging on train epoch end (Lightning-AI#13025)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 18cdfab · commit 9445a84

3 files changed (+41, -2 lines)

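In short, this commit moves the `self.epoch_progress.increment_completed()` call in `FitLoop.on_advance_end` so that it runs only after the train-epoch-end metrics have been logged, adds a regression test for the `epoch` and `step` values passed to the logger, and records the fix in the changelog.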

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -223,6 +223,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))


+- Fixed epoch logging on train epoch end ([#13025](https://github.com/PyTorchLightning/pytorch-lightning/pull/13025))
+
+
 - Fixed `materialize_module` setting a module's child recursively ([#12870](https://github.com/PyTorchLightning/pytorch-lightning/pull/12870))

pytorch_lightning/loops/fit_loop.py

Lines changed: 2 additions & 2 deletions
@@ -305,15 +305,15 @@ def on_advance_end(self) -> None:
         if self.epoch_loop._num_ready_batches_reached():
             self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)

-        self.epoch_progress.increment_completed()
-
         # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
         # even when the batch loop has finished
         self.epoch_loop._batches_that_stepped -= 1
         # log epoch metrics
         self.trainer._logger_connector.update_train_epoch_metrics()
         self.epoch_loop._batches_that_stepped += 1

+        self.epoch_progress.increment_completed()
+
         # if fault tolerant is enabled and process has been notified, exit.
         self.trainer._exit_gracefully_on_signal()
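Reading this change together with the new test below, the apparent issue is that the metrics logged at train epoch end carry the current epoch index, so bumping the completed-epoch counter before `update_train_epoch_metrics()` tagged them with an already-advanced epoch. The following is a minimal, self-contained sketch of that ordering effect; the class and function names are illustrative stand-ins, not Lightning's actual internals.

# Standalone sketch, not Lightning code: names are hypothetical stand-ins.
class EpochProgress:
    """Tiny stand-in for the loop's completed-epoch counter."""

    def __init__(self) -> None:
        self.completed = 0

    def increment_completed(self) -> None:
        self.completed += 1


def log_train_epoch_end(progress: EpochProgress, records: list) -> None:
    # the logger attaches the current epoch index to the epoch-end metrics
    records.append({"foo_epoch": 0.0, "epoch": progress.completed})


# previous order: increment first, then log -> epoch 0's metrics are tagged epoch 1
progress, records = EpochProgress(), []
progress.increment_completed()
log_train_epoch_end(progress, records)
assert records == [{"foo_epoch": 0.0, "epoch": 1}]

# order after this commit: log first, then increment -> tagged epoch 0
progress, records = EpochProgress(), []
log_train_epoch_end(progress, records)
progress.increment_completed()
assert records == [{"foo_epoch": 0.0, "epoch": 0}]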

tests/trainer/logging_/test_train_loop_logging.py

Lines changed: 36 additions & 0 deletions
@@ -16,6 +16,8 @@
 import collections
 import itertools
 from re import escape
+from unittest import mock
+from unittest.mock import call

 import numpy as np
 import pytest
@@ -748,3 +750,37 @@ def validation_epoch_end(self, *_) -> None:
     train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
     val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
     trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
+
+
+@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
+def test_log_metrics_epoch_step_values(mock_log_metrics, tmpdir):
+    """Tests the default epoch and step values logged."""
+
+    class MyModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            self.log("foo", 0.0, on_step=True, on_epoch=True)
+            return super().training_step(batch, batch_idx)
+
+    model = MyModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=0,
+        max_epochs=2,
+        log_every_n_steps=1,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model)
+
+    mock_log_metrics.assert_has_calls(
+        [
+            call(metrics={"foo_step": 0.0, "epoch": 0}, step=0),
+            call(metrics={"foo_step": 0.0, "epoch": 0}, step=1),
+            call(metrics={"foo_epoch": 0.0, "epoch": 0}, step=1),
+            call(metrics={"foo_step": 0.0, "epoch": 1}, step=2),
+            call(metrics={"foo_step": 0.0, "epoch": 1}, step=3),
+            call(metrics={"foo_epoch": 0.0, "epoch": 1}, step=3),
+        ]
+    )
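For reference, the new test can be run in isolation with `pytest tests/trainer/logging_/test_train_loop_logging.py::test_log_metrics_epoch_step_values` (assuming a development checkout with the test dependencies installed). The expected calls spell out the behaviour being locked in: with `on_step=True, on_epoch=True`, each training step logs `foo_step` at steps 0 through 3, and the epoch-end aggregation logs `foo_epoch` at the last step of each epoch (steps 1 and 3), with the `epoch` value matching the epoch that just ran (0 and 1 respectively).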
