Skip to content

Commit d15460f

Browse files
committed
refactor: revert some changes
1 parent 4da13b4 commit d15460f

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,8 @@ def on_validation_batch_start(
498498
visible=False,
499499
)
500500

501+
self.refresh()
502+
501503
def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID":
502504
assert self.progress is not None
503505
return self.progress.add_task(
@@ -512,22 +514,27 @@ def _initialize_train_progress_bar_id(self) -> None:
512514
self.train_progress_bar_id = self._add_task(total_batches, train_description)
513515

514516
def _update(
515-
self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True, refresh: bool = True
517+
self,
518+
progress_bar_id: Optional["TaskID"],
519+
current: int,
520+
visible: bool = True,
521+
hard: bool = False,
516522
) -> None:
517523
if self.progress is not None and self.is_enabled and progress_bar_id is not None:
518-
self.progress.update(progress_bar_id, completed=current, visible=visible, refresh=refresh)
524+
self.progress.update(progress_bar_id, completed=current, visible=visible)
525+
self.refresh(hard=hard)
519526

520527
@override
521528
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
522529
if self.is_enabled and self.val_progress_bar_id is not None and trainer.state.fn == "fit":
523530
assert self.progress is not None
524-
self.progress.update(self.val_progress_bar_id, advance=0, visible=False, refresh=True)
531+
self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
532+
self.refresh()
525533

526534
@override
527535
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
528536
if trainer.state.fn == "fit":
529537
self._update_metrics(trainer, pl_module)
530-
self.refresh()
531538
self.reset_dataloader_idx_tracker()
532539

533540
@override
@@ -554,6 +561,7 @@ def on_test_batch_start(
554561
assert self.progress is not None
555562
self.progress.update(self.test_progress_bar_id, advance=0, visible=False)
556563
self.test_progress_bar_id = self._add_task(self.total_test_batches_current_dataloader, self.test_description)
564+
self.refresh()
557565

558566
@override
559567
def on_predict_batch_start(
@@ -573,6 +581,7 @@ def on_predict_batch_start(
573581
self.predict_progress_bar_id = self._add_task(
574582
self.total_predict_batches_current_dataloader, self.predict_description
575583
)
584+
self.refresh()
576585

577586
@override
578587
def on_train_batch_end(
@@ -586,7 +595,7 @@ def on_train_batch_end(
586595
if not self.is_disabled and self.train_progress_bar_id is None:
587596
# can happen when resuming from a mid-epoch restart
588597
self._initialize_train_progress_bar_id()
589-
self._update(self.train_progress_bar_id, batch_idx + 1, refresh=False)
598+
self._update(self.train_progress_bar_id, batch_idx + 1)
590599
self._update_metrics(trainer, pl_module)
591600
self.refresh()
592601

@@ -609,12 +618,12 @@ def on_validation_batch_end(
609618
return
610619
if trainer.sanity_checking:
611620
if self.val_sanity_progress_bar_id is not None:
612-
self._update(self.val_sanity_progress_bar_id, batch_idx + 1, refresh=True)
621+
self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
613622
return
614623

615624
if self.val_progress_bar_id is None:
616625
return
617-
self._update(self.val_progress_bar_id, batch_idx + 1, refresh=True)
626+
self._update(self.val_progress_bar_id, batch_idx + 1)
618627

619628
@override
620629
def on_test_batch_end(
@@ -628,7 +637,7 @@ def on_test_batch_end(
628637
) -> None:
629638
if self.is_disabled or self.test_progress_bar_id is None:
630639
return
631-
self._update(self.test_progress_bar_id, batch_idx + 1, refresh=True)
640+
self._update(self.test_progress_bar_id, batch_idx + 1)
632641

633642
@override
634643
def on_predict_batch_end(
@@ -642,7 +651,7 @@ def on_predict_batch_end(
642651
) -> None:
643652
if self.is_disabled or self.predict_progress_bar_id is None:
644653
return
645-
self._update(self.predict_progress_bar_id, batch_idx + 1, refresh=True)
654+
self._update(self.predict_progress_bar_id, batch_idx + 1)
646655

647656
def _get_train_description(self, current_epoch: int) -> str:
648657
train_description = f"Epoch {current_epoch}"

0 commit comments

Comments
 (0)