Skip to content

Commit 5f616ac

Browse files
rohitgr7carmoccaBorda
authored andcommitted
Fix rich progress bar metric render on epoch end (#11689)
Co-authored-by: Carlos Mocholi <[email protected]> Co-authored-by: Jirka <[email protected]>
1 parent 8524d03 commit 5f616ac

File tree

4 files changed

+77
-0
lines changed

4 files changed

+77
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010

1111
- Fixed the format of the configuration saved automatically by the CLI's `SaveConfigCallback` ([#11532](https://github.com/PyTorchLightning/pytorch-lightning/pull/11532))
1212
- Fixed an issue to avoid validation loop run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
13+
- The Rich progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
1314

1415

1516
## [1.5.9] - 2022-01-18

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from datetime import timedelta
1717
from typing import Any, Optional, Union
1818

19+
import pytorch_lightning as pl
1920
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
2021
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2122
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
@@ -379,6 +380,10 @@ def on_validation_epoch_end(self, trainer, pl_module):
379380
if self.val_progress_bar_id is not None:
380381
self._update(self.val_progress_bar_id, visible=False)
381382

383+
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
384+
if trainer.state.fn == "fit":
385+
self._update_metrics(trainer, pl_module)
386+
382387
def on_test_epoch_start(self, trainer, pl_module):
383388
super().on_train_epoch_start(trainer, pl_module)
384389
self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
@@ -392,6 +397,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
392397
self._update(self.main_progress_bar_id)
393398
self._update_metrics(trainer, pl_module)
394399

400+
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
401+
self._update_metrics(trainer, pl_module)
402+
395403
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
396404
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
397405
if trainer.sanity_checking:

tests/callbacks/test_rich_progress_bar.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from collections import defaultdict
1415
from unittest import mock
1516
from unittest.mock import DEFAULT, Mock
1617

@@ -201,3 +202,68 @@ def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):
201202

202203
trainer.fit(model)
203204
assert progress_bar.progress.tasks[0].completed == min(num_sanity_val_steps, limit_val_batches)
205+
206+
207+
@RunIf(rich=True)
208+
def test_rich_progress_bar_correct_value_epoch_end(tmpdir):
209+
"""Rich counterpart to test_tqdm_progress_bar::test_tqdm_progress_bar_correct_value_epoch_end."""
210+
211+
class MockedProgressBar(RichProgressBar):
212+
calls = defaultdict(list)
213+
214+
def get_metrics(self, trainer, pl_module):
215+
items = super().get_metrics(trainer, model)
216+
del items["v_num"]
217+
del items["loss"]
218+
# this is equivalent to mocking `set_postfix` as this method gets called every time
219+
self.calls[trainer.state.fn].append(
220+
(trainer.state.stage, trainer.current_epoch, trainer.global_step, items)
221+
)
222+
return items
223+
224+
class MyModel(BoringModel):
225+
def training_step(self, batch, batch_idx):
226+
self.log("a", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
227+
return super().training_step(batch, batch_idx)
228+
229+
def validation_step(self, batch, batch_idx):
230+
self.log("b", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
231+
return super().validation_step(batch, batch_idx)
232+
233+
def test_step(self, batch, batch_idx):
234+
self.log("c", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
235+
return super().test_step(batch, batch_idx)
236+
237+
model = MyModel()
238+
pbar = MockedProgressBar()
239+
trainer = Trainer(
240+
default_root_dir=tmpdir,
241+
limit_train_batches=2,
242+
limit_val_batches=2,
243+
limit_test_batches=2,
244+
max_epochs=2,
245+
enable_model_summary=False,
246+
enable_checkpointing=False,
247+
log_every_n_steps=1,
248+
callbacks=pbar,
249+
)
250+
251+
trainer.fit(model)
252+
assert pbar.calls["fit"] == [
253+
("sanity_check", 0, 0, {"b": 0}),
254+
("train", 0, 0, {}),
255+
("train", 0, 1, {}),
256+
("validate", 0, 1, {"b": 1}), # validation end
257+
# epoch end over, `on_epoch=True` metrics are computed
258+
("train", 0, 2, {"a": 1, "b": 1}), # training epoch end
259+
("train", 1, 2, {"a": 1, "b": 1}),
260+
("train", 1, 3, {"a": 1, "b": 1}),
261+
("validate", 1, 3, {"a": 1, "b": 3}), # validation end
262+
("train", 1, 4, {"a": 3, "b": 3}), # training epoch end
263+
]
264+
265+
trainer.validate(model, verbose=False)
266+
assert pbar.calls["validate"] == []
267+
268+
trainer.test(model, verbose=False)
269+
assert pbar.calls["test"] == []

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,8 @@ def test_tqdm_progress_bar_main_bar_resume():
611611

612612

613613
def test_tqdm_progress_bar_correct_value_epoch_end(tmpdir):
614+
"""TQDM counterpart to test_rich_progress_bar::test_rich_progress_bar_correct_value_epoch_end."""
615+
614616
class MockedProgressBar(TQDMProgressBar):
615617
calls = defaultdict(list)
616618

0 commit comments

Comments
 (0)