
Commit bd50b26

rohitgr7, justusschock, and carmocca authored and committed
Fix logging's step values when multiple dataloaders are used during evaluation (#12184)
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>

1 parent 29b9963 · commit bd50b26

File tree

7 files changed: +79 -29 lines changed
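For context: before this commit the logger connector kept a single `_val_log_step`/`_test_log_step` counter that was shared across all evaluation dataloaders and never reset, so with several dataloaders each metric series was logged against one interleaved, ever-growing step counter. The diff below moves the bookkeeping into the evaluation epoch loop, with one counter per dataloader. A minimal pure-Python sketch of the behavioral difference (not Lightning code; all names here are illustrative):

```python
# Contrast the old shared step counter with the new per-dataloader counters.
num_dataloaders, batches = 2, 3

# Before: one counter shared across dataloaders, so step values interleave.
shared_step = 0
old_steps = {dl: [] for dl in range(num_dataloaders)}
for dl_idx in range(num_dataloaders):
    for _ in range(batches):
        old_steps[dl_idx].append(shared_step)
        shared_step += 1
print(old_steps)  # {0: [0, 1, 2], 1: [3, 4, 5]}

# After: one counter per dataloader, so every series starts at step 0.
dl_batch_idx = [0] * num_dataloaders
new_steps = {dl: [] for dl in range(num_dataloaders)}
for dl_idx in range(num_dataloaders):
    for _ in range(batches):
        new_steps[dl_idx].append(dl_batch_idx[dl_idx])
        dl_batch_idx[dl_idx] += 1
print(new_steps)  # {0: [0, 1, 2], 1: [0, 1, 2]}
```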

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `materialize_module` setting a module's child recursively ([#12870](https://github.com/PyTorchLightning/pytorch-lightning/pull/12870))
 - Fixed issue where the CLI could not pass a `Profiler` to the `Trainer` ([#13084](https://github.com/PyTorchLightning/pytorch-lightning/pull/13084))
 - Fixed torchelastic detection with non-distributed installations ([#13142](https://github.com/PyTorchLightning/pytorch-lightning/pull/13142))
+- Fixed logging's step values when multiple dataloaders are used during evaluation ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))
 
 
 ## [1.6.3] - 2022-05-03

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 5 additions & 0 deletions

@@ -234,10 +234,15 @@ def _get_max_batches(self) -> List[int]:
 
     def _reload_evaluation_dataloaders(self) -> None:
         """Reloads dataloaders if necessary."""
+        dataloaders = None
         if self.trainer.testing:
             self.trainer.reset_test_dataloader()
+            dataloaders = self.trainer.test_dataloaders
         elif self.trainer.val_dataloaders is None or self.trainer._data_connector._should_reload_val_dl:
             self.trainer.reset_val_dataloader()
+            dataloaders = self.trainer.val_dataloaders
+        if dataloaders is not None:
+            self.epoch_loop._reset_dl_batch_idx(len(dataloaders))
 
     def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_start`` hooks."""

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 8 additions & 1 deletion

@@ -49,6 +49,7 @@ def __init__(self) -> None:
         self._dl_max_batches = 0
         self._data_fetcher: Optional[AbstractDataFetcher] = None
         self._dataloader_state_dict: Dict[str, Any] = {}
+        self._dl_batch_idx = [0]
 
     @property
     def done(self) -> bool:
@@ -135,7 +136,10 @@ def advance(  # type: ignore[override]
         self.batch_progress.increment_completed()
 
         # log batch metrics
-        self.trainer._logger_connector.update_eval_step_metrics()
+        if not self.trainer.sanity_checking:
+            dataloader_idx = kwargs.get("dataloader_idx", 0)
+            self.trainer._logger_connector.update_eval_step_metrics(self._dl_batch_idx[dataloader_idx])
+            self._dl_batch_idx[dataloader_idx] += 1
 
         # track epoch level outputs
         if self._should_track_batch_outputs_for_epoch_end() and output is not None:
@@ -287,3 +291,6 @@ def _should_track_batch_outputs_for_epoch_end(self) -> bool:
         if self.trainer.testing:
             return is_overridden("test_epoch_end", model)
         return is_overridden("validation_epoch_end", model)
+
+    def _reset_dl_batch_idx(self, num_dataloaders: int) -> None:
+        self._dl_batch_idx = [0] * num_dataloaders
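Taken together with the `evaluation_loop.py` change above, the bookkeeping reduces to a small amount of state. A condensed standalone sketch (only `_dl_batch_idx` and `_reset_dl_batch_idx` mirror the diff; the surrounding class is a hypothetical stand-in for the loop machinery):

```python
class EvaluationEpochLoopSketch:
    def __init__(self) -> None:
        self._dl_batch_idx = [0]  # one step counter per evaluation dataloader

    def _reset_dl_batch_idx(self, num_dataloaders: int) -> None:
        # called when the evaluation dataloaders are (re)loaded
        self._dl_batch_idx = [0] * num_dataloaders

    def advance(self, dataloader_idx: int = 0) -> int:
        # log at the current per-dataloader step, then increment it
        step = self._dl_batch_idx[dataloader_idx]
        self._dl_batch_idx[dataloader_idx] += 1
        return step


loop = EvaluationEpochLoopSketch()
loop._reset_dl_batch_idx(2)
assert [loop.advance(0) for _ in range(3)] == [0, 1, 2]
assert [loop.advance(1) for _ in range(3)] == [0, 1, 2]  # independent of dataloader 0
```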

pytorch_lightning/loops/fit_loop.py

Lines changed: 3 additions & 2 deletions

@@ -204,8 +204,9 @@ def on_run_start(self) -> None:  # type: ignore[override]
         if not self._iteration_based_training():
             self.epoch_progress.current.completed = self.epoch_progress.current.processed
 
-        # reset train dataloader and val dataloader
-        self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
+        self.trainer.reset_train_dataloader(self.trainer.lightning_module)
+        # reload the evaluation dataloaders too for proper display in the progress bar
+        self.epoch_loop.val_loop._reload_evaluation_dataloaders()
 
         data_fetcher_cls = _select_data_fetcher(self.trainer)
         self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
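The fit loop previously reset train and val dataloaders in one call; it now resets the train dataloader itself and routes evaluation-dataloader reloading through the val loop, so the per-dataloader counters are reset along the way. A hypothetical, heavily simplified sketch of that ordering (stub classes, not Lightning's):

```python
class ValLoopStub:
    def __init__(self) -> None:
        self.counters_reset = False

    def _reload_evaluation_dataloaders(self) -> None:
        # in the real code this also calls epoch_loop._reset_dl_batch_idx(...)
        self.counters_reset = True


class FitLoopStub:
    def __init__(self, val_loop: ValLoopStub) -> None:
        self.val_loop = val_loop

    def on_run_start(self) -> None:
        # previously a single trainer.reset_train_val_dataloaders(...) call
        print("reset train dataloader")
        self.val_loop._reload_evaluation_dataloaders()


fit_loop = FitLoopStub(ValLoopStub())
fit_loop.on_run_start()
assert fit_loop.val_loop.counters_reset
```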

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 25 deletions

@@ -20,7 +20,6 @@
 from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
-from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import memory
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
@@ -37,8 +36,6 @@ def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None)
             "Please monitor GPU stats with the `DeviceStatsMonitor` callback directly instead."
         )
         self.log_gpu_memory = log_gpu_memory
-        self._val_log_step: int = 0
-        self._test_log_step: int = 0
         self._progress_bar_metrics: _PBAR_DICT = {}
         self._logged_metrics: _OUT_DICT = {}
         self._callback_metrics: _OUT_DICT = {}
@@ -134,35 +131,15 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
     Evaluation metric updates
     """
 
-    @property
-    def _eval_log_step(self) -> Optional[int]:
-        if self.trainer.state.stage is RunningStage.VALIDATING:
-            return self._val_log_step
-        if self.trainer.state.stage is RunningStage.TESTING:
-            return self._test_log_step
-        return None
-
-    def _increment_eval_log_step(self) -> None:
-        if self.trainer.state.stage is RunningStage.VALIDATING:
-            self._val_log_step += 1
-        elif self.trainer.state.stage is RunningStage.TESTING:
-            self._test_log_step += 1
-
     def _evaluation_epoch_end(self) -> None:
         results = self.trainer._results
         assert results is not None
         results.dataloader_idx = None
 
-    def update_eval_step_metrics(self) -> None:
+    def update_eval_step_metrics(self, step: int) -> None:
         assert not self._epoch_end_reached
-        if self.trainer.sanity_checking:
-            return
-
         # logs user requested information to logger
-        self.log_metrics(self.metrics["log"], step=self._eval_log_step)
-
-        # increment the step even if nothing was logged
-        self._increment_eval_log_step()
+        self.log_metrics(self.metrics["log"], step=step)
 
     def update_eval_epoch_metrics(self) -> _OUT_DICT:
         assert self._epoch_end_reached
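With the counters gone from the connector, `update_eval_step_metrics` becomes a plain forwarder: the evaluation epoch loop now owns both the sanity-checking guard and the step value it passes in. A tiny sketch of the new calling convention (hypothetical stand-ins, not the real connector):

```python
from typing import Callable, Dict

def update_eval_step_metrics_sketch(
    log_metrics: Callable[..., None], metrics: Dict[str, float], step: int
) -> None:
    # the connector no longer derives the step from trainer state
    log_metrics(metrics, step=step)


logged = []
update_eval_step_metrics_sketch(
    lambda metrics, step: logged.append((metrics, step)),
    {"val_loss": 0.25},
    step=3,  # supplied by the loop's _dl_batch_idx[dataloader_idx]
)
assert logged == [({"val_loss": 0.25}, 3)]
```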

tests/trainer/logging_/test_eval_loop_logging.py

Lines changed: 59 additions & 0 deletions

@@ -973,3 +973,62 @@ def test_rich_print_results(inputs, expected):
     EvaluationLoop._print_results(*inputs)
     expected = expected[1:]  # remove the initial line break from the """ string
     assert capture.get() == expected.lstrip()
+
+
+@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
+@pytest.mark.parametrize("num_dataloaders", (1, 2))
+def test_eval_step_logging(mock_log_metrics, tmpdir, num_dataloaders):
+    """Test that eval step during fit/validate/test is updated correctly."""
+
+    class CustomBoringModel(BoringModel):
+        def validation_step(self, batch, batch_idx, dataloader_idx=None):
+            self.log(f"val_log_{self.trainer.state.fn}", batch_idx, on_step=True, on_epoch=False)
+
+        def test_step(self, batch, batch_idx, dataloader_idx=None):
+            self.log("test_log", batch_idx, on_step=True, on_epoch=False)
+
+        def val_dataloader(self):
+            return [super().val_dataloader()] * num_dataloaders
+
+        def test_dataloader(self):
+            return [super().test_dataloader()] * num_dataloaders
+
+        validation_epoch_end = None
+        test_epoch_end = None
+
+    limit_batches = 4
+    max_epochs = 3
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=max_epochs,
+        limit_train_batches=1,
+        limit_val_batches=limit_batches,
+        limit_test_batches=limit_batches,
+    )
+    model = CustomBoringModel()
+
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+
+    def get_suffix(dl_idx):
+        return f"/dataloader_idx_{dl_idx}" if num_dataloaders == 2 else ""
+
+    eval_steps = range(limit_batches)
+    fit_calls = [
+        call(metrics={f"val_log_fit{get_suffix(dl_idx)}": float(step)}, step=step + (limit_batches * epoch))
+        for epoch in range(max_epochs)
+        for dl_idx in range(num_dataloaders)
+        for step in eval_steps
+    ]
+    validate_calls = [
+        call(metrics={f"val_log_validate{get_suffix(dl_idx)}": float(val)}, step=val)
+        for dl_idx in range(num_dataloaders)
+        for val in eval_steps
+    ]
+    test_calls = [
+        call(metrics={f"test_log{get_suffix(dl_idx)}": float(val)}, step=val)
+        for dl_idx in range(num_dataloaders)
+        for val in eval_steps
+    ]
+    assert mock_log_metrics.mock_calls == fit_calls + validate_calls + test_calls
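A quick arithmetic check of the `fit` expectations above: with `limit_batches = 4` and `max_epochs = 3`, each dataloader's logged steps continue across epochs via `step + (limit_batches * epoch)`:

```python
limit_batches, max_epochs = 4, 3
fit_steps = [
    step + limit_batches * epoch
    for epoch in range(max_epochs)
    for step in range(limit_batches)
]
print(fit_steps)  # [0, 1, 2, ..., 11]: per-dataloader steps keep increasing across epochs
```

During `validate` and `test`, by contrast, the expected steps simply restart at 0 for every dataloader.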

tests/trainer/test_dataloaders.py

Lines changed: 1 addition & 1 deletion

@@ -1243,7 +1243,7 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir):
 
     assert tracker.mock_calls == [
         call.reset_val_dataloader(),
-        call.reset_train_dataloader(model=model),
+        call.reset_train_dataloader(model),
         call.reset_test_dataloader(),
     ]
