Skip to content

Commit 6631bb8

Browse files
rohitgr7carmocca
authored andcommitted
Fix to avoid val progress bar disappear after validate (#11700)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 0bd69c9 commit 6631bb8

File tree

4 files changed

+8
-4
lines changed

4 files changed

+8
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1414
- Fixed an issue to make the `step` argument in `WandbLogger.log_image` work ([#11716](https://github.com/PyTorchLightning/pytorch-lightning/pull/11716))
1515
- Fixed `restore_optimizers` for mapping states ([#11757](https://github.com/PyTorchLightning/pytorch-lightning/pull/11757))
1616
- With `DPStrategy`, the batch is not explictly moved to the device ([#11780](https://github.com/PyTorchLightning/pytorch-lightning/pull/11780))
17-
17+
- Fixed an issue to avoid val bar disappear after `trainer.validate()` ([#11700](https://github.com/PyTorchLightning/pytorch-lightning/pull/11700))
1818

1919

2020
## [1.5.9] - 2022-01-18

_notebooks

pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,12 @@ def init_predict_tqdm(self) -> Tqdm:
183183
def init_validation_tqdm(self) -> Tqdm:
184184
"""Override this to customize the tqdm bar for validation."""
185185
# The main progress bar doesn't exist in `trainer.validate()`
186-
has_main_bar = self.main_progress_bar is not None
186+
has_main_bar = self.trainer.state.fn != "validate"
187187
bar = Tqdm(
188188
desc="Validating",
189189
position=(2 * self.process_position + has_main_bar),
190190
disable=self.is_disabled,
191-
leave=False,
191+
leave=not has_main_bar,
192192
dynamic_ncols=True,
193193
file=sys.stdout,
194194
)

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,12 @@ def test_tqdm_progress_bar_totals(tmpdir):
104104
m = bar.total_val_batches
105105
assert len(trainer.train_dataloader) == n
106106
assert bar.main_progress_bar.total == n + m
107+
assert bar.main_progress_bar.leave
107108

108109
# check val progress bar total
109110
assert sum(len(loader) for loader in trainer.val_dataloaders) == m
110111
assert bar.val_progress_bar.total == m
112+
assert not bar.val_progress_bar.leave
111113

112114
# main progress bar should have reached the end (train batches + val batches)
113115
assert bar.main_progress_bar.n == n + m
@@ -126,13 +128,15 @@ def test_tqdm_progress_bar_totals(tmpdir):
126128
assert bar.val_progress_bar.total == m
127129
assert bar.val_progress_bar.n == m
128130
assert bar.val_batch_idx == m
131+
assert bar.val_progress_bar.leave
129132

130133
trainer.test(model)
131134

132135
# check test progress bar total
133136
k = bar.total_test_batches
134137
assert sum(len(loader) for loader in trainer.test_dataloaders) == k
135138
assert bar.test_progress_bar.total == k
139+
assert bar.test_progress_bar.leave
136140

137141
# test progress bar should have reached the end
138142
assert bar.test_progress_bar.n == k

0 commit comments

Comments
 (0)