diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 03664c8e2d1ad..4c3f0412dfe8d 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105)) +- Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147)) + --- diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 942ba3627efc0..74abb8ecd850c 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -265,7 +265,9 @@ def on_train_start(self, *_: Any) -> None: def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: if self._leave: self.train_progress_bar = self.init_train_tqdm() - self.train_progress_bar.reset(convert_inf(self.total_train_batches)) + total = convert_inf(self.total_train_batches) + self.train_progress_bar.reset() + self.train_progress_bar.total = total self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") @@ -306,7 +308,9 @@ def on_validation_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader)) + total = convert_inf(self.total_val_batches_current_dataloader) + self.val_progress_bar.reset() + self.val_progress_bar.total = total self.val_progress_bar.initial = 0 desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") @@ -348,7 +352,9 @@ def on_test_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader)) + total = convert_inf(self.total_test_batches_current_dataloader) + self.test_progress_bar.reset() + self.test_progress_bar.total = total self.test_progress_bar.initial = 0 self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") @@ -387,7 +393,9 @@ def on_predict_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader)) + total = convert_inf(self.total_predict_batches_current_dataloader) + self.predict_progress_bar.reset() + self.predict_progress_bar.total = total self.predict_progress_bar.initial = 0 self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index 538f1bce57ce0..0bd29b998c598 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -812,3 +812,50 @@ def test_tqdm_leave(leave, tmp_path): ) trainer.fit(model) assert pbar.init_train_tqdm.call_count == (4 if leave else 1) + + +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) +def test_tqdm_progress_bar_reset_behavior(tmp_path): + """Test that progress bars call reset() without parameters and set total separately.""" + model = BoringModel() + + class ResetTrackingTqdm(MockTqdm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.reset_calls_with_params = [] + + def reset(self, total=None): + self.reset_calls_with_params.append(total) + super().reset(total) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + logger=False, + enable_checkpointing=False, + ) + + pbar = trainer.progress_bar_callback + + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", ResetTrackingTqdm): + trainer.fit(model) + + train_bar = pbar.train_progress_bar + assert None in train_bar.reset_calls_with_params, ( + f"train reset() should be called without parameters, got calls: {train_bar.reset_calls_with_params}" + ) + # Verify that total was set separately to the expected value + assert 2 in train_bar.total_values, ( + f"train total should be set to 2 after reset(), got total_values: {train_bar.total_values}" + ) + # Verify that validation progress bar reset() was called without parameters + val_bar = pbar.val_progress_bar + assert None in val_bar.reset_calls_with_params, ( + f"validation reset() should be called without parameters, got calls: {val_bar.reset_calls_with_params}" + ) + # Verify that total was set separately to the expected value + assert 2 in val_bar.total_values, ( + f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}" + )