Skip to content

Commit 59b0ccb

Browse files
authored
Fix logging to loggers with multiple eval dataloaders (#12454)
1 parent 83bca4c commit 59b0ccb

File tree

6 files changed

+26
-12
lines changed

6 files changed

+26
-12
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
924924
- Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267))
925925

926926

927+
- Fixed logging to loggers with multiple eval dataloaders ([#12454](https://github.com/PyTorchLightning/pytorch-lightning/pull/12454))
928+
929+
927930
## [1.5.10] - 2022-02-08
928931

929932
### Fixed

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import os
1515
import shutil
16-
from collections import OrderedDict
16+
from collections import ChainMap, OrderedDict
1717
from functools import partial
1818
from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union
1919

@@ -181,11 +181,13 @@ def on_run_end(self) -> List[_OUT_DICT]:
181181
logged_outputs, self._logged_outputs = self._logged_outputs, [] # free memory
182182
# include any logged outputs on epoch_end
183183
epoch_end_logged_outputs = self.trainer._logger_connector.update_eval_epoch_metrics()
184+
all_logged_outputs = dict(ChainMap(*logged_outputs)) # list[dict] -> dict
185+
all_logged_outputs.update(epoch_end_logged_outputs)
184186
for dl_outputs in logged_outputs:
185187
dl_outputs.update(epoch_end_logged_outputs)
186188

187189
# log metrics
188-
self.trainer._logger_connector.log_eval_end_metrics()
190+
self.trainer._logger_connector.log_eval_end_metrics(all_logged_outputs)
189191

190192
# hook
191193
self._on_evaluation_end()

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,13 @@ def update_eval_epoch_metrics(self) -> _OUT_DICT:
174174
self._logged_metrics.update(metrics["log"])
175175
return metrics["log"]
176176

177-
def log_eval_end_metrics(self) -> None:
177+
def log_eval_end_metrics(self, metrics: _OUT_DICT) -> None:
178178
assert self._epoch_end_reached
179179
if self.trainer.sanity_checking:
180180
return
181181

182182
# log all the metrics as a single dict
183-
self.log_metrics(self.metrics["log"])
183+
self.log_metrics(metrics)
184184

185185
"""
186186
Train metric updates

pytorch_lightning/utilities/imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
8989

9090
_IS_WINDOWS = platform.system() == "Windows"
9191
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
92+
_PYTHON_GREATER_EQUAL_3_8_0 = _compare_version("python", operator.ge, "3.8.0")
9293
_TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
9394
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
9495
_TORCH_GREATER_EQUAL_1_9_1 = _compare_version("torch", operator.ge, "1.9.1")

tests/loops/test_evaluation_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
4848
def test_log_epoch_metrics_before_on_evaluation_end(update_eval_epoch_metrics_mock, tmpdir):
4949
"""Test that the epoch metrics are logged before the `on_evaluation_end` hook is fired."""
5050
order = []
51-
update_eval_epoch_metrics_mock.side_effect = lambda: order.append("log_epoch_metrics")
51+
update_eval_epoch_metrics_mock.side_effect = lambda _: order.append("log_epoch_metrics")
5252

5353
class LessBoringModel(BoringModel):
5454
def on_validation_end(self):

tests/trainer/logging_/test_eval_loop_logging.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pytorch_lightning.loops.dataloader import EvaluationLoop
2929
from pytorch_lightning.trainer.states import RunningStage
3030
from pytorch_lightning.utilities.exceptions import MisconfigurationException
31+
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
3132
from tests.helpers import BoringModel, RandomDataset
3233
from tests.helpers.runif import RunIf
3334

@@ -573,10 +574,8 @@ def test_step(self, batch, batch_idx):
573574
assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0)
574575

575576
def get_metrics_at_idx(idx):
576-
mock_calls = list(mock_log_metrics.mock_calls)
577-
if isinstance(mock_calls[idx].kwargs, dict):
578-
return mock_calls[idx].kwargs["metrics"]
579-
return mock_calls[idx][2]["metrics"]
577+
mock_call = mock_log_metrics.mock_calls[idx]
578+
return mock_call.kwargs["metrics"] if _PYTHON_GREATER_EQUAL_3_8_0 else mock_call[2]["metrics"]
580579

581580
expected = {"valid_loss_0_step", "valid_loss_2"}
582581
assert set(get_metrics_at_idx(1)) == expected
@@ -755,7 +754,8 @@ def test_dataloader(self):
755754
}
756755

757756

758-
def test_logging_multi_dataloader_on_epoch_end(tmpdir):
757+
@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
758+
def test_logging_multi_dataloader_on_epoch_end(mock_log_metrics, tmpdir):
759759
class CustomBoringModel(BoringModel):
760760
def test_step(self, batch, batch_idx, dataloader_idx):
761761
self.log("foo", dataloader_idx + 1)
@@ -765,13 +765,21 @@ def test_epoch_end(self, outputs) -> None:
765765
self.log("foobar", sum(sum(o) for o in outputs))
766766

767767
def test_dataloader(self):
768-
return [super().val_dataloader(), super().val_dataloader()]
768+
return [super().test_dataloader(), super().test_dataloader()]
769769

770770
model = CustomBoringModel()
771-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
771+
trainer = Trainer(default_root_dir=tmpdir, limit_test_batches=1)
772772
results = trainer.test(model)
773+
773774
# what's logged in `test_epoch_end` gets included in the results of each dataloader
774775
assert results == [{"foo/dataloader_idx_0": 1, "foobar": 3}, {"foo/dataloader_idx_1": 2, "foobar": 3}]
776+
cb_metrics = set(trainer.callback_metrics)
777+
assert cb_metrics == {"foo/dataloader_idx_0", "foo/dataloader_idx_1", "foobar"}
778+
779+
mock_call = mock_log_metrics.mock_calls[0]
780+
logged_metrics = mock_call.kwargs["metrics"] if _PYTHON_GREATER_EQUAL_3_8_0 else mock_call[2]["metrics"]
781+
cb_metrics.add("epoch")
782+
assert set(logged_metrics) == cb_metrics
775783

776784

777785
inputs0 = ([{"log": torch.tensor(5)}, {"no_log": torch.tensor(6)}], RunningStage.TESTING)

0 commit comments

Comments (0)