
Commit 6aca974

rohitgr7 authored and carmocca committed
Run main progress bar independent of val progress bar in TQDMProgressBar (#12563)
Co-authored-by: carmocca <[email protected]>
1 parent 23d3d46 · commit 6aca974

File tree

3 files changed, +66 -19 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Run main progress bar updates independent of val progress bar updates in `TQDMProgressBar` ([#12563](https://github.com/PyTorchLightning/pytorch-lightning/pull/12563))
+
+
 - Avoid calling `average_parameters` multiple times per optimizer step ([#12452](https://github.com/PyTorchLightning/pytorch-lightning/pull/12452))
 
 

pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 12 additions & 9 deletions

@@ -266,8 +266,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
 
     def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
-        if self._should_update(self.train_batch_idx, self.total_train_batches):
-            _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
+        current = self.train_batch_idx + self._val_processed
+        if self._should_update(current, self.main_progress_bar.total):
+            _update_n(self.main_progress_bar, current)
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

@@ -292,10 +293,12 @@ def on_validation_batch_start(
         self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
 
     def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
-        if self._should_update(self.val_batch_idx, self.total_val_batches_current_dataloader):
+        if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
             _update_n(self.val_progress_bar, self.val_batch_idx)
-            if trainer.state.fn == "fit":
-                _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
+
+        current = self.train_batch_idx + self._val_processed
+        if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
+            _update_n(self.main_progress_bar, current)
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._main_progress_bar is not None and trainer.state.fn == "fit":

@@ -316,7 +319,7 @@ def on_test_batch_start(
         self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
 
     def on_test_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.test_batch_idx, self.total_test_batches_current_dataloader):
+        if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
             _update_n(self.test_progress_bar, self.test_batch_idx)
 
     def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

@@ -336,7 +339,7 @@ def on_predict_batch_start(
         self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
 
     def on_predict_batch_end(self, *_: Any) -> None:
-        if self._should_update(self.predict_batch_idx, self.total_predict_batches_current_dataloader):
+        if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
             _update_n(self.predict_progress_bar, self.predict_batch_idx)
 
     def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

@@ -359,8 +362,8 @@ def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
             s = sep.join(map(str, args))
             active_progress_bar.write(s, **kwargs)
 
-    def _should_update(self, current: int, total: Union[int, float]) -> bool:
-        return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total)
+    def _should_update(self, current: int, total: int) -> bool:
+        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
 
     @staticmethod
     def _resolve_refresh_rate(refresh_rate: int) -> int:
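
For illustration only (not part of this commit): a standalone sketch of the update schedule the new code produces. It assumes a refresh rate of 3, 7 train and 7 val batches, and a single validation check after train batch 4, mirroring the first parametrized case of the new test below; the should_update helper here mirrors the refresh check in _should_update with is_enabled taken as true.

REFRESH_RATE = 3
TRAIN_BATCHES = VAL_BATCHES = 7
VAL_CHECK_BATCH = 4  # validation runs once, after the 4th train batch


def should_update(current: int, total: int) -> bool:
    # mirrors TQDMProgressBar._should_update after this commit (is_enabled assumed True)
    return current % REFRESH_RATE == 0 or current == total


main_total = TRAIN_BATCHES + VAL_BATCHES  # the main bar counts train + val batches
main_updates, val_updates = [], []
val_processed = 0

for train_batch_idx in range(1, TRAIN_BATCHES + 1):
    # on_train_batch_end: the main bar tracks train batches plus processed val batches
    current = train_batch_idx + val_processed
    if should_update(current, main_total):
        main_updates.append(current)
    if train_batch_idx == VAL_CHECK_BATCH:
        for val_batch_idx in range(1, VAL_BATCHES + 1):
            val_processed += 1
            # on_validation_batch_end: the two bars are now throttled independently
            if should_update(val_batch_idx, VAL_BATCHES):
                val_updates.append(val_batch_idx)
            current = train_batch_idx + val_processed
            if should_update(current, main_total):
                main_updates.append(current)

print(main_updates)  # [3, 6, 9, 12, 14]
print(val_updates)   # [3, 6, 7]

Because the main bar is now throttled against its own total (train plus val batches) rather than piggybacking on the val bar's refreshes, it advances on its own refresh-rate boundaries even while the validation loop is running.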

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 51 additions & 10 deletions

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import os
 import pickle
 import sys

@@ -361,10 +362,10 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
         [2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]],
         [0, 0, 3, None, None],
         [1, 0, 3, [1], None],
-        [1, 1, 3, [1, 2], [1]],
+        [1, 1, 3, [2], [1]],
         [5, 0, 3, [3, 5], None],
-        [5, 2, 3, [3, 5, 7], [2]],
-        [5, 2, 6, [5, 7], [2]],
+        [5, 2, 3, [3, 6, 7], [2]],
+        [5, 2, 6, [6, 7], [2]],
     ],
 )
 def test_main_progress_bar_update_amount(

@@ -563,16 +564,56 @@ def test_tqdm_progress_bar_can_be_pickled():
     pickle.dumps(bar)
 
 
-@RunIf(min_gpus=2, standalone=True)
 @pytest.mark.parametrize(
-    ["total_train_samples", "train_batch_size", "total_val_samples", "val_batch_size", "val_check_interval"],
-    [(8, 4, 2, 1, 0.2), (8, 4, 2, 1, 0.5)],
+    ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
+    [(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])],
 )
 def test_progress_bar_max_val_check_interval(
-    tmpdir, total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval
+    tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
 ):
+    limit_batches = 7
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        max_epochs=1,
+        enable_model_summary=False,
+        val_check_interval=val_check_interval,
+        limit_train_batches=limit_batches,
+        limit_val_batches=limit_batches,
+        callbacks=TQDMProgressBar(refresh_rate=3),
+    )
+    with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
+        trainer.fit(model)
+
+    pbar = trainer.progress_bar_callback
+    assert pbar.main_progress_bar.n_values == main_progress_bar_updates
+    assert pbar.val_progress_bar.n_values == val_progress_bar_updates
+
+    val_check_batch = (
+        max(1, int(limit_batches * val_check_interval)) if isinstance(val_check_interval, float) else val_check_interval
+    )
+    assert trainer.val_check_batch == val_check_batch
+    val_checks_per_epoch = math.ceil(limit_batches // val_check_batch)
+    pbar_callback = trainer.progress_bar_callback
+    total_val_batches = limit_batches * val_checks_per_epoch
+
+    assert pbar_callback.val_progress_bar.n == limit_batches
+    assert pbar_callback.val_progress_bar.total == limit_batches
+    assert pbar_callback.main_progress_bar.n == limit_batches + total_val_batches
+    assert pbar_callback.main_progress_bar.total == limit_batches + total_val_batches
+    assert pbar_callback.is_enabled
+
+
+@RunIf(min_gpus=2, standalone=True)
+@pytest.mark.parametrize("val_check_interval", [0.2, 0.5])
+def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
     world_size = 2
-    train_data = DataLoader(RandomDataset(32, total_train_samples), batch_size=train_batch_size)
+    total_train_samples = 16
+    train_batch_size = 4
+    total_val_samples = 2
+    val_batch_size = 1
+    train_data = DataLoader(RandomDataset(32, 8), batch_size=train_batch_size)
     val_data = DataLoader(RandomDataset(32, total_val_samples), batch_size=val_batch_size)
 
     model = BoringModel()

@@ -599,8 +640,8 @@ def test_progress_bar_max_val_check_interval(
     assert pbar_callback.val_progress_bar.n == total_val_batches
     assert pbar_callback.val_progress_bar.total == total_val_batches
     total_val_batches = total_val_batches * val_checks_per_epoch
-    assert pbar_callback.main_progress_bar.n == total_train_batches + total_val_batches
-    assert pbar_callback.main_progress_bar.total == total_train_batches + total_val_batches
+    assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
+    assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
     assert pbar_callback.is_enabled
 
 
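
For reference (not part of the commit), the parametrized expectations above line up with the test's own arithmetic. Pulled out as a standalone snippet with limit_batches = 7, it yields the main-bar totals 14 and 21 that close the two expected main_progress_bar_updates lists:

import math

limit_batches = 7  # matches limit_train_batches and limit_val_batches in the test

for val_check_interval in (4, 0.5):
    # same computation the test performs to recover trainer.val_check_batch
    val_check_batch = (
        max(1, int(limit_batches * val_check_interval))
        if isinstance(val_check_interval, float)
        else val_check_interval
    )
    val_checks_per_epoch = math.ceil(limit_batches // val_check_batch)
    total_val_batches = limit_batches * val_checks_per_epoch
    # main bar total = train batches + all val batches run during the epoch
    print(val_check_interval, limit_batches + total_val_batches)

# prints "4 14" and "0.5 21", matching the final entries of
# [3, 6, 9, 12, 14] and [3, 6, 9, 12, 15, 18, 21]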

0 commit comments