Skip to content

Commit 34fcb87

Browse files
kaushikb11, tchaton, carmocca
authored
Add leave argument to RichProgressBar (#10301)
* Add display_every_n_epochs argument to RichProgressBar
* Add tests
* Update test
* Update test
* Update changelog
* use leave argument instead
* Update pytorch_lightning/callbacks/progress/rich_progress.py

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 9e844d9 commit 34fcb87

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8383
* Added Rich progress bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929), [#9559](https://github.com/PyTorchLightning/pytorch-lightning/pull/9559))
8484
* Added Support for iterable datasets ([#9734](https://github.com/PyTorchLightning/pytorch-lightning/pull/9734))
8585
* Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))
86+
* Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))
87+
* Added `leave` argument to `RichProgressBar` ([#10301](https://github.com/PyTorchLightning/pytorch-lightning/pull/10301))
8688
- Added input validation logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
8789
- Added support for CPU AMP autocast ([#9084](https://github.com/PyTorchLightning/pytorch-lightning/pull/9084))
8890
- Added `on_exception` callback hook ([#9183](https://github.com/PyTorchLightning/pytorch-lightning/pull/9183))
@@ -128,7 +130,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
128130
- Added support for `devices="auto"` ([#10264](https://github.com/PyTorchLightning/pytorch-lightning/pull/10264))
129131
- Added a `filename` argument in `ModelCheckpoint.format_checkpoint_name` ([#9818](https://github.com/PyTorchLightning/pytorch-lightning/pull/9818))
130132
- Added support for empty `gpus` list to run on CPU ([#10246](https://github.com/PyTorchLightning/pytorch-lightning/pull/10246))
131-
- Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))
132133
- Added a warning if multiple batch sizes are found from ambiguous batch ([#10247](https://github.com/PyTorchLightning/pytorch-lightning/pull/10247))
133134

134135

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ class RichProgressBar(ProgressBarBase):
195195
196196
Args:
197197
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
198+
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
198199
theme: Contains styles used to stylize the progress bar.
199200
200201
Raises:
@@ -205,6 +206,7 @@ class RichProgressBar(ProgressBarBase):
205206
def __init__(
206207
self,
207208
refresh_rate_per_second: int = 10,
209+
leave: bool = False,
208210
theme: RichProgressBarTheme = RichProgressBarTheme(),
209211
) -> None:
210212
if not _RICH_AVAILABLE:
@@ -213,6 +215,7 @@ def __init__(
213215
)
214216
super().__init__()
215217
self._refresh_rate_per_second: int = refresh_rate_per_second
218+
self._leave: bool = leave
216219
self._enabled: bool = True
217220
self.progress: Optional[Progress] = None
218221
self.val_sanity_progress_bar_id: Optional[int] = None
@@ -323,9 +326,15 @@ def on_train_epoch_start(self, trainer, pl_module):
323326
total_batches = total_train_batches + total_val_batches
324327

325328
train_description = self._get_train_description(trainer.current_epoch)
329+
if self.main_progress_bar_id is not None and self._leave:
330+
self._stop_progress()
331+
self._init_progress(trainer, pl_module)
326332
if self.main_progress_bar_id is None:
327333
self.main_progress_bar_id = self._add_task(total_batches, train_description)
328-
self.progress.reset(self.main_progress_bar_id, total=total_batches, description=train_description)
334+
else:
335+
self.progress.reset(
336+
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
337+
)
329338

330339
def on_validation_epoch_start(self, trainer, pl_module):
331340
super().on_validation_epoch_start(trainer, pl_module)

tests/callbacks/test_rich_progress_bar.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def on_train_start(self) -> None:
144144

145145

146146
@RunIf(rich=True)
147-
def test_rich_progress_bar_configure_columns(tmpdir):
147+
def test_rich_progress_bar_configure_columns():
148148
from rich.progress import TextColumn
149149

150150
custom_column = TextColumn("[progress.description]Testing Rich!")
@@ -159,3 +159,24 @@ def configure_columns(self, trainer, pl_module):
159159

160160
assert progress_bar.progress.columns[0] == custom_column
161161
assert len(progress_bar.progress.columns) == 1
162+
163+
164+
@RunIf(rich=True)
165+
@pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 5)]))
166+
def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
167+
# Calling `reset` means continuing on the same progress bar.
168+
model = BoringModel()
169+
170+
with mock.patch(
171+
"pytorch_lightning.callbacks.progress.rich_progress.Progress.reset", autospec=True
172+
) as mock_progress_reset:
173+
progress_bar = RichProgressBar(leave=leave)
174+
trainer = Trainer(
175+
default_root_dir=tmpdir,
176+
num_sanity_val_steps=0,
177+
limit_train_batches=1,
178+
max_epochs=6,
179+
callbacks=progress_bar,
180+
)
181+
trainer.fit(model)
182+
assert mock_progress_reset.call_count == reset_call_count

0 commit comments

Comments (0)