Skip to content

Commit 55f5e2d

Browse files
rohitgr7 authored and carmocca committed
Fix TQDMProgressBar reset and update to show correct time estimation (#12889)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent ab7ad37 commit 55f5e2d

File tree

4 files changed

+30
-25
lines changed

4 files changed

+30
-25
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121
- Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891))
2222
- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
2323
- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965))
24+
- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))
2425

2526

2627
## [1.6.2] - 2022-04-27

pytorch_lightning/callbacks/progress/tqdm_progress.py

Lines changed: 14 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -262,13 +262,13 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
262262
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
263263
total_val_batches = total_val_batches * val_checks_per_epoch
264264
total_batches = total_train_batches + total_val_batches
265-
self.main_progress_bar.total = convert_inf(total_batches)
265+
self.main_progress_bar.reset(convert_inf(total_batches))
266266
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
267267

268268
def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
269269
current = self.train_batch_idx + self._val_processed
270270
if self._should_update(current, self.main_progress_bar.total):
271-
_update_n(self.main_progress_bar, current)
271+
_update_n(self.main_progress_bar, current, self.refresh_rate)
272272
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
273273

274274
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -288,17 +288,17 @@ def on_validation_batch_start(
288288
if not self.has_dataloader_changed(dataloader_idx):
289289
return
290290

291-
self.val_progress_bar.total = convert_inf(self.total_val_batches_current_dataloader)
291+
self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
292292
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
293293
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
294294

295295
def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
296296
if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
297-
_update_n(self.val_progress_bar, self.val_batch_idx)
297+
_update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate)
298298

299299
current = self.train_batch_idx + self._val_processed
300300
if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
301-
_update_n(self.main_progress_bar, current)
301+
_update_n(self.main_progress_bar, current, self.refresh_rate)
302302

303303
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
304304
if self._main_progress_bar is not None and trainer.state.fn == "fit":
@@ -315,12 +315,12 @@ def on_test_batch_start(
315315
if not self.has_dataloader_changed(dataloader_idx):
316316
return
317317

318-
self.test_progress_bar.total = convert_inf(self.total_test_batches_current_dataloader)
318+
self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
319319
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
320320

321321
def on_test_batch_end(self, *_: Any) -> None:
322322
if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
323-
_update_n(self.test_progress_bar, self.test_batch_idx)
323+
_update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate)
324324

325325
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
326326
self.test_progress_bar.close()
@@ -335,12 +335,12 @@ def on_predict_batch_start(
335335
if not self.has_dataloader_changed(dataloader_idx):
336336
return
337337

338-
self.predict_progress_bar.total = convert_inf(self.total_predict_batches_current_dataloader)
338+
self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
339339
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
340340

341341
def on_predict_batch_end(self, *_: Any) -> None:
342342
if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
343-
_update_n(self.predict_progress_bar, self.predict_batch_idx)
343+
_update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate)
344344

345345
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
346346
self.predict_progress_bar.close()
@@ -384,7 +384,10 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
384384
return x
385385

386386

387-
def _update_n(bar: _tqdm, value: int) -> None:
387+
def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None:
388388
if not bar.disable:
389-
bar.n = value
389+
total = bar.total
390+
leftover = current % refresh_rate
391+
advance = leftover if (current == total and leftover != 0) else refresh_rate
392+
bar.update(advance)
390393
bar.refresh()

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
numpy>=1.17.2
44
torch>=1.8.*
5-
tqdm>=4.41.0
5+
tqdm>=4.57.0
66
PyYAML>=5.4
77
fsspec[http]>=2021.05.0, !=2021.06.0
88
tensorboard>=2.2.0

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ def n(self):
5353
@n.setter
5454
def n(self, value):
5555
self.__n = value
56+
5657
# track the changes in the `n` value
5758
if not len(self.n_values) or value != self.n_values[-1]:
5859
self.n_values.append(value)
@@ -158,7 +159,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
158159
assert not pbar.val_progress_bar.leave
159160
assert trainer.num_sanity_val_batches == expected_sanity_steps
160161
assert pbar.val_progress_bar.total_values == expected_sanity_steps
161-
assert pbar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl
162+
assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl
162163
assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]
163164

164165
# fit
@@ -177,7 +178,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
177178

178179
# check val progress bar total
179180
assert pbar.val_progress_bar.total_values == m
180-
assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl
181+
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
181182
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
182183
assert not pbar.val_progress_bar.leave
183184

@@ -186,7 +187,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
186187
trainer.validate(model)
187188
assert trainer.num_val_batches == m
188189
assert pbar.val_progress_bar.total_values == m
189-
assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl
190+
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
190191
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
191192

192193
# test
@@ -195,7 +196,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
195196
assert pbar.test_progress_bar.leave
196197
k = trainer.num_test_batches
197198
assert pbar.test_progress_bar.total_values == k
198-
assert pbar.test_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl
199+
assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
199200
assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)]
200201
assert pbar.test_progress_bar.leave
201202

@@ -205,7 +206,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
205206
assert pbar.predict_progress_bar.leave
206207
k = trainer.num_predict_batches
207208
assert pbar.predict_progress_bar.total_values == k
208-
assert pbar.predict_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl
209+
assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
209210
assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)]
210211
assert pbar.predict_progress_bar.leave
211212

@@ -359,13 +360,13 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
359360
@pytest.mark.parametrize(
360361
"train_batches,val_batches,refresh_rate,train_updates,val_updates",
361362
[
362-
[2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]],
363+
[2, 3, 1, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3]],
363364
[0, 0, 3, None, None],
364-
[1, 0, 3, [1], None],
365-
[1, 1, 3, [2], [1]],
366-
[5, 0, 3, [3, 5], None],
367-
[5, 2, 3, [3, 6, 7], [2]],
368-
[5, 2, 6, [6, 7], [2]],
365+
[1, 0, 3, [0, 1], None],
366+
[1, 1, 3, [0, 2], [0, 1]],
367+
[5, 0, 3, [0, 3, 5], None],
368+
[5, 2, 3, [0, 3, 6, 7], [0, 2]],
369+
[5, 2, 6, [0, 6, 7], [0, 2]],
369370
],
370371
)
371372
def test_main_progress_bar_update_amount(
@@ -395,7 +396,7 @@ def test_main_progress_bar_update_amount(
395396
assert progress_bar.val_progress_bar.n_values == val_updates
396397

397398

398-
@pytest.mark.parametrize("test_batches,refresh_rate,updates", [[1, 3, [1]], [3, 1, [1, 2, 3]], [5, 3, [3, 5]]])
399+
@pytest.mark.parametrize("test_batches,refresh_rate,updates", [(1, 3, [0, 1]), (3, 1, [0, 1, 2, 3]), (5, 3, [0, 3, 5])])
399400
def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, updates: list):
400401
"""Test that test progress updates with the correct amount."""
401402
model = BoringModel()
@@ -566,7 +567,7 @@ def test_tqdm_progress_bar_can_be_pickled():
566567

567568
@pytest.mark.parametrize(
568569
["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
569-
[(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])],
570+
[(4, [0, 3, 6, 9, 12, 14], [0, 3, 6, 7]), (0.5, [0, 3, 6, 9, 12, 15, 18, 21], [0, 3, 6, 7])],
570571
)
571572
def test_progress_bar_max_val_check_interval(
572573
tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates

0 commit comments

Comments
 (0)