Skip to content

Commit eb49d2c

Browse files
awaelchli and carmocca
authored and committed
[bugfix] Resolve metrics not being properly resetted on validation epoch end (#9717)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 06af920 commit eb49d2c

File tree

4 files changed

+62
-6
lines changed

4 files changed

+62
-6
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
- Moved the gradient unscaling in `NativeMixedPrecisionPlugin` from `pre_optimizer_step` to `post_backward` ([#9606](https://github.com/PyTorchLightning/pytorch-lightning/pull/9606))
1111
- Fixed gradient unscaling being called too late, causing gradient clipping and gradient norm tracking to be applied incorrectly ([#9606](https://github.com/PyTorchLightning/pytorch-lightning/pull/9606))
12-
13-
1412
- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))
13+
- Fixed `reset` metrics on validation epoch end ([#9717](https://github.com/PyTorchLightning/pytorch-lightning/pull/9717))
14+
1515

1616

1717
## [1.4.8] - 2021-09-22

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
208208
# summarize profile results
209209
self.trainer.profiler.describe()
210210

211-
# reset any `torchmetrics.Metric` and the logger connector state
212-
self.trainer.logger_connector.reset_results(metrics=True)
211+
# reset the logger connector state
212+
self.trainer.logger_connector.reset_results()
213213

214214
def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
215215
"""Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks"""

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,9 @@ def reset_metrics(self) -> None:
282282
self._logged_metrics = {}
283283
self._callback_metrics = {}
284284

285-
def reset_results(self, metrics: Optional[bool] = None) -> None:
285+
def reset_results(self) -> None:
286286
if self.trainer._results is not None:
287-
self.trainer._results.reset(metrics=metrics)
287+
self.trainer._results.reset()
288288

289289
self._batch_idx = None
290290
self._split_idx = None

tests/trainer/logging_/test_eval_loop_logging.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,3 +618,59 @@ def validation_step(self, batch, batch_idx):
618618
)
619619

620620
trainer.fit(model)
621+
622+
623+
@pytest.mark.parametrize("val_check_interval", [0.5, 1.0])
624+
def test_multiple_dataloaders_reset(val_check_interval, tmpdir):
625+
class TestModel(BoringModel):
626+
def training_step(self, batch, batch_idx):
627+
out = super().training_step(batch, batch_idx)
628+
value = 1 + batch_idx
629+
if self.current_epoch != 0:
630+
value *= 10
631+
self.log("batch_idx", value, on_step=True, on_epoch=True, prog_bar=True)
632+
return out
633+
634+
def training_epoch_end(self, outputs):
635+
metrics = self.trainer.progress_bar_metrics
636+
v = 15 if self.current_epoch == 0 else 150
637+
assert metrics["batch_idx_epoch"] == (v / 5.0)
638+
639+
def validation_step(self, batch, batch_idx, dataloader_idx):
640+
value = (1 + batch_idx) * (1 + dataloader_idx)
641+
if self.current_epoch != 0:
642+
value *= 10
643+
self.log("val_loss", value, on_step=False, on_epoch=True, prog_bar=True, logger=True)
644+
return value
645+
646+
def validation_epoch_end(self, outputs):
647+
if self.current_epoch == 0:
648+
assert sum(outputs[0]) / 5 == 3
649+
assert sum(outputs[1]) / 5 == 6
650+
else:
651+
assert sum(outputs[0]) / 5 == 30
652+
assert sum(outputs[1]) / 5 == 60
653+
654+
tot_loss = torch.mean(torch.tensor(outputs, dtype=torch.float))
655+
if self.current_epoch == 0:
656+
assert tot_loss == (3 + 6) / 2
657+
else:
658+
assert tot_loss == (30 + 60) / 2
659+
assert self.trainer._results["validation_step.val_loss.0"].cumulated_batch_size == 5
660+
assert self.trainer._results["validation_step.val_loss.1"].cumulated_batch_size == 5
661+
662+
def val_dataloader(self):
663+
return [super().val_dataloader(), super().val_dataloader()]
664+
665+
model = TestModel()
666+
trainer = Trainer(
667+
default_root_dir=tmpdir,
668+
limit_train_batches=5,
669+
limit_val_batches=5,
670+
num_sanity_val_steps=0,
671+
val_check_interval=val_check_interval,
672+
max_epochs=3,
673+
log_every_n_steps=1,
674+
weights_summary=None,
675+
)
676+
trainer.fit(model)

0 commit comments

Comments (0)