
Commit 22bd118

awaelchli authored and Rohit Gupta committed
Reset metrics before each task starts (#9410)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent a69b940 commit 22bd118

File tree: 5 files changed, +76 -16 lines


pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 2 additions & 1 deletion
@@ -93,6 +93,7 @@ def on_skip(self) -> List:
     def on_run_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs the ``on_evaluation_model_eval``, ``on_evaluation_start`` and ``on_evaluation_epoch_start`` hooks"""
         void(*args, **kwargs)
+
         # hook
         self.on_evaluation_model_eval()
         self.trainer.lightning_module.zero_grad()
@@ -208,7 +209,7 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         self.trainer.profiler.describe()
 
         # reset any `torchmetrics.Metric` and the logger connector state
-        self.trainer.logger_connector.reset(metrics=True)
+        self.trainer.logger_connector.reset_results(metrics=True)
 
     def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks"""

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 9 additions & 7 deletions
@@ -277,13 +277,15 @@ def should_reset_tensors(self, fx: str) -> bool:
         is_first_batch = self._batch_idx + self._split_idx == 0
         return is_different_fx and is_first_batch
 
-    def reset(self, metrics: Optional[bool] = None) -> None:
-        if self.trainer.sanity_checking:
-            # reset metrics
-            self._progress_bar_metrics = {}
-            self._logged_metrics = {}
-            self._callback_metrics = {}
-        self.trainer._results.reset(metrics=metrics)
+    def reset_metrics(self) -> None:
+        self._progress_bar_metrics = {}
+        self._logged_metrics = {}
+        self._callback_metrics = {}
+
+    def reset_results(self, metrics: Optional[bool] = None) -> None:
+        if self.trainer._results is not None:
+            self.trainer._results.reset(metrics=metrics)
+
         self._batch_idx = None
         self._split_idx = None
         self._current_fx = None
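The old `reset(metrics=...)` entry point is split in two: `reset_metrics()` clears the cached progress-bar/logged/callback metric dictionaries, while `reset_results()` clears the `ResultCollection` and the batch/fx bookkeeping. Call sites that previously used `reset()` now invoke both, as the `trainer.py` hunks below do. A minimal sketch of that call pattern follows; the helper name is illustrative and not part of this commit:

def _reset_logger_connector(trainer, metrics: bool = True) -> None:
    # illustrative helper: drop ResultCollection state (and torchmetrics.Metric state when requested)
    trainer.logger_connector.reset_results(metrics=metrics)
    # drop the cached callback / logged / progress-bar metric dicts
    trainer.logger_connector.reset_metrics()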

pytorch_lightning/trainer/trainer.py

Lines changed: 13 additions & 4 deletions
@@ -903,6 +903,11 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         # ----------------------------
         # TRAIN
         # ----------------------------
+
+        # reset logger connector
+        self.logger_connector.reset_results()
+        self.logger_connector.reset_metrics()
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             self.call_hook("on_fit_start")
@@ -1103,8 +1108,11 @@ def _run_sanity_check(self, ref_model):
         stage = self.state.stage
         self.sanity_checking = True
 
-        # hook and callback
-        self.on_sanity_check_start()
+        # reset logger connector
+        self.logger_connector.reset_results()
+        self.logger_connector.reset_metrics()
+
+        self.call_hook("on_sanity_check_start")
 
         # reload dataloaders
         self._evaluation_loop.reload_evaluation_dataloaders()
@@ -1115,8 +1123,9 @@ def _run_sanity_check(self, ref_model):
 
         self.on_sanity_check_end()
 
-        # reset validation metrics
-        self.logger_connector.reset()
+        # reset logger connector
+        self.logger_connector.reset_results()
+        self.logger_connector.reset_metrics()
 
         # reset the seed to what it was before sanity check
         # prevents sanity check to affect random sampling in training

tests/trainer/logging_/test_eval_loop_logging.py

Lines changed: 6 additions & 4 deletions
@@ -550,6 +550,12 @@ def test_step(self, batch, batch_idx):
     # hp_metric + 2 steps + epoch + 2 steps + epoch
     expected_num_calls = 1 + 2 + 1 + 2 + 1
 
+    assert set(trainer.callback_metrics) == {
+        "train_loss",
+        "valid_loss_0_epoch",
+        "valid_loss_0",
+        "valid_loss_1",
+    }
     assert len(mock_log_metrics.mock_calls) == expected_num_calls
     assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0)
 
@@ -583,10 +589,6 @@ def get_metrics_at_idx(idx):
 
     results = trainer.test(model)
    assert set(trainer.callback_metrics) == {
-        "train_loss",
-        "valid_loss_0_epoch",
-        "valid_loss_0",
-        "valid_loss_1",
         "test_loss",
     }
     assert set(results[0]) == {"test_loss"}
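For context, a minimal sketch of the behaviour this test now asserts: metrics logged during `fit` no longer linger in `trainer.callback_metrics` once a new task such as `test` starts. The model and import path below are illustrative (assuming the repo's `BoringModel` test helper), not part of this commit:

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # assumed test-helper location

class LoggingModel(BoringModel):  # illustrative model for this sketch
    def training_step(self, batch, batch_idx):
        out = super().training_step(batch, batch_idx)
        self.log("train_loss", out["loss"])  # logged during fit only
        return out

model = LoggingModel()
trainer = Trainer(fast_dev_run=True)
trainer.fit(model)   # "train_loss" is now in trainer.callback_metrics
trainer.test(model)  # the logger connector is reset before the test task starts
assert "train_loss" not in trainer.callback_metrics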

tests/trainer/test_trainer.py

Lines changed: 46 additions & 0 deletions
@@ -1897,3 +1897,49 @@ def current_memory():
     trainer_2.fit(model)
 
     assert current_memory() <= initial
+
+
+def test_trainer_metrics_reset_before_each_task(tmpdir):
+    """Test that callback, logged and progress bar metrics are reset before each task starts."""
+
+    class TestMetricRestartCallback(Callback):
+        def _make_assertions(self, trainer):
+            assert trainer.callback_metrics == {}
+            assert trainer.progress_bar_metrics == {}
+            assert trainer.logged_metrics == {}
+
+        def on_train_start(self, trainer, *args, **kwargs):
+            self._make_assertions(trainer)
+
+        def on_validation_start(self, trainer, *args, **kwargs):
+            if trainer.state.fn == TrainerFn.VALIDATING:
+                self._make_assertions(trainer)
+
+        def on_test_start(self, trainer, *args, **kwargs):
+            self._make_assertions(trainer)
+
+        def on_predict_start(self, trainer, *args, **kwargs):
+            self._make_assertions(trainer)
+
+    class CustomBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+
+        def training_step(self, *args, **kwargs):
+            self.log("train/metric", 7.0)
+            return super().training_step(*args, **kwargs)
+
+        def validation_step(self, *args, **kwargs):
+            self.log("val/metric", 14.0)
+            return super().validation_step(*args, **kwargs)
+
+        def test_step(self, *args, **kwargs):
+            self.log("test/metric", 21.0)
+            return super().test_step(*args, **kwargs)
+
+    model = CustomBoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=4, callbacks=[TestMetricRestartCallback()])
+    trainer.fit(model)
+    trainer.validate(model)
+    trainer.test(model)
+    trainer.predict(model)
