Skip to content

Commit 8ebd28f

Browse files
awaelchli authored and Borda committed
fix incomplete progress bar when refresh_rate > num batches (#4577)
* fix progress bar overshoot
* fix updates for partially incomplete main progress bar when val loop starts
* add tests
* chlog

(cherry picked from commit 89e8796)
1 parent 1ecfa6f commit 8ebd28f

File tree

3 files changed

+99
-8
lines changed

3 files changed

+99
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4343

4444

4545

46+
- Fixed incomplete progress bars when total batches not divisible by refresh rate ([#4577](https://github.com/PyTorchLightning/pytorch-lightning/pull/4577))
4647

4748
## [1.0.7] - 2020-11-17
4849

pytorch_lightning/callbacks/progress.py

Lines changed: 21 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -334,21 +334,22 @@ def on_epoch_start(self, trainer, pl_module):
334334

335335
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
336336
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
337-
if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
338-
self.main_progress_bar.update(self.refresh_rate)
337+
if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
338+
self._update_bar(self.main_progress_bar)
339339
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
340340

341341
def on_validation_start(self, trainer, pl_module):
342342
super().on_validation_start(trainer, pl_module)
343343
if not trainer.running_sanity_check:
344+
self._update_bar(self.main_progress_bar) # fill up remaining
344345
self.val_progress_bar = self.init_validation_tqdm()
345346
self.val_progress_bar.total = convert_inf(self.total_val_batches)
346347

347348
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
348349
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
349-
if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0:
350-
self.val_progress_bar.update(self.refresh_rate)
351-
self.main_progress_bar.update(self.refresh_rate)
350+
if self._should_update(self.val_batch_idx, self.total_val_batches):
351+
self._update_bar(self.val_progress_bar)
352+
self._update_bar(self.main_progress_bar)
352353

353354
def on_validation_end(self, trainer, pl_module):
354355
super().on_validation_end(trainer, pl_module)
@@ -366,13 +367,26 @@ def on_test_start(self, trainer, pl_module):
366367

367368
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
368369
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
369-
if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0:
370-
self.test_progress_bar.update(self.refresh_rate)
370+
if self._should_update(self.test_batch_idx, self.total_test_batches):
371+
self._update_bar(self.test_progress_bar)
371372

372373
def on_test_end(self, trainer, pl_module):
373374
super().on_test_end(trainer, pl_module)
374375
self.test_progress_bar.close()
375376

377+
def _should_update(self, current, total):
378+
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
379+
380+
def _update_bar(self, bar):
381+
""" Updates the bar by the refresh rate without overshooting. """
382+
if bar.total is not None:
383+
delta = min(self.refresh_rate, bar.total - bar.n)
384+
else:
385+
# infinite / unknown size
386+
delta = self.refresh_rate
387+
if delta > 0:
388+
bar.update(delta)
389+
376390

377391
def convert_inf(x):
378392
""" The tqdm doesn't support inf values. We have to convert it to None. """

tests/callbacks/test_progress_bar.py

Lines changed: 77 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,13 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from unittest.mock import Mock, call
16+
1517
import pytest
1618
from unittest import mock
1719

1820
from pytorch_lightning import Trainer
1921
from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint
2022
from pytorch_lightning.utilities.exceptions import MisconfigurationException
21-
from tests.base import EvalModelTemplate
23+
from tests.base import EvalModelTemplate, BoringModel
2224

2325

2426
@pytest.mark.parametrize('callbacks,refresh_rate', [
@@ -252,3 +254,77 @@ def test_progress_bar_warning_on_colab(tmpdir):
252254
)
253255

254256
assert trainer.progress_bar_callback.refresh_rate == 19
257+
258+
259+
class MockedUpdateProgressBars(ProgressBar):
260+
""" Mocks the update method once bars get initializied. """
261+
262+
def _mock_bar_update(self, bar):
263+
bar.update = Mock(wraps=bar.update)
264+
return bar
265+
266+
def init_train_tqdm(self):
267+
bar = super().init_train_tqdm()
268+
return self._mock_bar_update(bar)
269+
270+
def init_validation_tqdm(self):
271+
bar = super().init_validation_tqdm()
272+
return self._mock_bar_update(bar)
273+
274+
def init_test_tqdm(self):
275+
bar = super().init_test_tqdm()
276+
return self._mock_bar_update(bar)
277+
278+
279+
@pytest.mark.parametrize("train_batches,val_batches,refresh_rate,train_deltas,val_deltas", [
280+
[2, 3, 1, [1, 1, 1, 1, 1], [1, 1, 1]],
281+
[0, 0, 3, [], []],
282+
[1, 0, 3, [1], []],
283+
[1, 1, 3, [2], [1]],
284+
[5, 0, 3, [3, 2], []],
285+
[5, 2, 3, [3, 3, 1], [2]],
286+
[5, 2, 6, [6, 1], [2]],
287+
])
288+
def test_main_progress_bar_update_amount(tmpdir, train_batches, val_batches, refresh_rate, train_deltas, val_deltas):
289+
"""
290+
Test that the main progress updates with the correct amount together with the val progress. At the end of
291+
the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate.
292+
"""
293+
model = BoringModel()
294+
progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate)
295+
trainer = Trainer(
296+
default_root_dir=tmpdir,
297+
max_epochs=1,
298+
limit_train_batches=train_batches,
299+
limit_val_batches=val_batches,
300+
callbacks=[progress_bar],
301+
logger=False,
302+
checkpoint_callback=False,
303+
)
304+
trainer.fit(model)
305+
progress_bar.main_progress_bar.update.assert_has_calls([call(delta) for delta in train_deltas])
306+
if val_batches > 0:
307+
progress_bar.val_progress_bar.update.assert_has_calls([call(delta) for delta in val_deltas])
308+
309+
310+
@pytest.mark.parametrize("test_batches,refresh_rate,test_deltas", [
311+
[1, 3, [1]],
312+
[3, 1, [1, 1, 1]],
313+
[5, 3, [3, 2]],
314+
])
315+
def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
316+
"""
317+
Test that test progress updates with the correct amount.
318+
"""
319+
model = BoringModel()
320+
progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate)
321+
trainer = Trainer(
322+
default_root_dir=tmpdir,
323+
max_epochs=1,
324+
limit_test_batches=test_batches,
325+
callbacks=[progress_bar],
326+
logger=False,
327+
checkpoint_callback=False,
328+
)
329+
trainer.test(model)
330+
progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])

0 commit comments

Comments
 (0)