Commit f89b181

rohitgr7 authored and carmocca committed

Fix epoch logging on train epoch end (#13025)

Co-authored-by: Carlos Mocholí <[email protected]>

1 parent 902774a · commit f89b181

File tree: 3 files changed, +39 −2 lines changed

CHANGELOG.md
Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed issue where the CLI could not pass a `Profiler` to the `Trainer` ([#13084](https://github.com/PyTorchLightning/pytorch-lightning/pull/13084))
 - Fixed torchelastic detection with non-distributed installations ([#13142](https://github.com/PyTorchLightning/pytorch-lightning/pull/13142))
 - Fixed logging's step values when multiple dataloaders are used during evaluation ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))
+- Fixed epoch logging on train epoch end ([#13025](https://github.com/PyTorchLightning/pytorch-lightning/pull/13025))


 ## [1.6.3] - 2022-05-03

pytorch_lightning/loops/fit_loop.py
Lines changed: 2 additions & 2 deletions

@@ -305,15 +305,15 @@ def on_advance_end(self) -> None:
         if self.epoch_loop._num_ready_batches_reached():
             self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)

-        self.epoch_progress.increment_completed()
-
         # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
         # even when the batch loop has finished
         self.epoch_loop._batches_that_stepped -= 1
         # log epoch metrics
         self.trainer._logger_connector.update_train_epoch_metrics()
         self.epoch_loop._batches_that_stepped += 1

+        self.epoch_progress.increment_completed()
+
         # if fault tolerant is enabled and process has been notified, exit.
         self.trainer._exit_gracefully_on_signal()

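The hunk above moves `self.epoch_progress.increment_completed()` to after `update_train_epoch_metrics()`, so the epoch counter still refers to the epoch that just finished when epoch-end metrics are logged, while the temporary decrement of `_batches_that_stepped` keeps the logged step equal to the last batch's step. A minimal standalone sketch of that pattern (the `Progress` and `Logger` classes here are hypothetical stand-ins, not Lightning's actual implementation):

```python
class Progress:
    """Hypothetical stand-in for Lightning's epoch progress tracker."""

    def __init__(self):
        self.current = 0

    def increment_completed(self):
        self.current += 1


class Logger:
    """Hypothetical stand-in that records (metrics, step) pairs."""

    def __init__(self):
        self.records = []

    def log(self, metrics, step):
        self.records.append((metrics, step))


def end_of_epoch(progress, logger, batches_that_stepped):
    # Loggers expect epoch-end metrics to reuse the step of the last batch,
    # so the step counter is temporarily decremented, as in the diff above.
    batches_that_stepped -= 1
    logger.log({"epoch": progress.current}, step=batches_that_stepped)
    batches_that_stepped += 1
    # The fix: mark the epoch completed only *after* logging, so the logged
    # "epoch" value is the epoch that just ran, not the next one.
    progress.increment_completed()
    return batches_that_stepped


progress, logger = Progress(), Logger()
steps = 0
for _ in range(2):  # two epochs of two batches each
    steps += 2
    steps = end_of_epoch(progress, logger, steps)

# logger.records is [({"epoch": 0}, 1), ({"epoch": 1}, 3)], matching the
# epoch/step pairs asserted for "foo_epoch" in the test below.
```

Before the fix, `increment_completed()` ran first, so the epoch-end metrics would have been logged with an already-incremented epoch value.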
tests/trainer/logging_/test_train_loop_logging.py
Lines changed: 36 additions & 0 deletions

@@ -16,6 +16,8 @@
 import collections
 import itertools
 from re import escape
+from unittest import mock
+from unittest.mock import call

 import numpy as np
 import pytest
@@ -747,3 +749,37 @@ def validation_epoch_end(self, *_) -> None:
     train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
     val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
     trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
+
+
+@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
+def test_log_metrics_epoch_step_values(mock_log_metrics, tmpdir):
+    """Tests the default epoch and step values logged."""
+
+    class MyModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            self.log("foo", 0.0, on_step=True, on_epoch=True)
+            return super().training_step(batch, batch_idx)
+
+    model = MyModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=0,
+        max_epochs=2,
+        log_every_n_steps=1,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model)
+
+    mock_log_metrics.assert_has_calls(
+        [
+            call(metrics={"foo_step": 0.0, "epoch": 0}, step=0),
+            call(metrics={"foo_step": 0.0, "epoch": 0}, step=1),
+            call(metrics={"foo_epoch": 0.0, "epoch": 0}, step=1),
+            call(metrics={"foo_step": 0.0, "epoch": 1}, step=2),
+            call(metrics={"foo_step": 0.0, "epoch": 1}, step=3),
+            call(metrics={"foo_epoch": 0.0, "epoch": 1}, step=3),
+        ]
+    )

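The new test patches `TensorBoardLogger.log_metrics` and verifies the recorded calls with `assert_has_calls`, which checks that the expected calls appear in order (extra calls in between are tolerated). A small self-contained sketch of that mocking pattern, using a hypothetical `Sink` class in place of the real logger:

```python
from unittest import mock
from unittest.mock import call


class Sink:
    """Hypothetical stand-in for a logger backend."""

    def log_metrics(self, metrics, step):
        raise RuntimeError("real implementation; should be patched out in the test")


# patch.object replaces the method on the class with a MagicMock for the
# duration of the block, so no real logging happens.
with mock.patch.object(Sink, "log_metrics") as log_metrics:
    sink = Sink()
    sink.log_metrics(metrics={"foo": 1.0}, step=0)
    sink.log_metrics(metrics={"foo": 2.0}, step=1)

    # assert_has_calls passes when the expected calls occur in this order;
    # unlike assert_called_once_with, it does not require them to be the
    # only calls made.
    log_metrics.assert_has_calls(
        [
            call(metrics={"foo": 1.0}, step=0),
            call(metrics={"foo": 2.0}, step=1),
        ]
    )
```

In the actual test this technique captures every `log_metrics` call across two epochs, letting the assertions pin down both the `step` value and the `epoch` key that the fix in `fit_loop.py` corrects.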