Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))


- Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))

---


Expand Down
16 changes: 12 additions & 4 deletions src/lightning/pytorch/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,9 @@ def on_train_start(self, *_: Any) -> None:
def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._leave:
self.train_progress_bar = self.init_train_tqdm()
self.train_progress_bar.reset(convert_inf(self.total_train_batches))
total = convert_inf(self.total_train_batches)
self.train_progress_bar.reset()
self.train_progress_bar.total = total
self.train_progress_bar.initial = 0
self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

Expand Down Expand Up @@ -306,7 +308,9 @@ def on_validation_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return

self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
total = convert_inf(self.total_val_batches_current_dataloader)
self.val_progress_bar.reset()
self.val_progress_bar.total = total
self.val_progress_bar.initial = 0
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
Expand Down Expand Up @@ -348,7 +352,9 @@ def on_test_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return

self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
total = convert_inf(self.total_test_batches_current_dataloader)
self.test_progress_bar.reset()
self.test_progress_bar.total = total
self.test_progress_bar.initial = 0
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

Expand Down Expand Up @@ -387,7 +393,9 @@ def on_predict_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return

self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
total = convert_inf(self.total_predict_batches_current_dataloader)
self.predict_progress_bar.reset()
self.predict_progress_bar.total = total
self.predict_progress_bar.initial = 0
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -812,3 +812,50 @@ def test_tqdm_leave(leave, tmp_path):
)
trainer.fit(model)
assert pbar.init_train_tqdm.call_count == (4 if leave else 1)


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_tqdm_progress_bar_reset_behavior(tmp_path):
    """Test that progress bars call reset() without parameters and set total separately."""
    model = BoringModel()

    # Tqdm test double that records the ``total`` argument of every reset()
    # call, so the test can assert that reset() is invoked with no explicit
    # total (i.e. ``None``) and that ``total`` is assigned separately afterwards.
    class ResetTrackingTqdm(MockTqdm):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # One entry per reset() call: the ``total`` value passed (None when omitted).
            self.reset_calls_with_params = []

        def reset(self, total=None):
            self.reset_calls_with_params.append(total)
            super().reset(total)

    trainer = Trainer(
        default_root_dir=tmp_path,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
    )

    # The default progress-bar callback created by the Trainer (TQDMProgressBar,
    # since rich is patched out above).
    pbar = trainer.progress_bar_callback

    # Substitute the tracking double for the Tqdm class used by the callback.
    with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", ResetTrackingTqdm):
        trainer.fit(model)

    train_bar = pbar.train_progress_bar
    # Verify that the train progress bar reset() was called without parameters
    assert None in train_bar.reset_calls_with_params, (
        f"train reset() should be called without parameters, got calls: {train_bar.reset_calls_with_params}"
    )
    # Verify that total was set separately to the expected value
    assert 2 in train_bar.total_values, (
        f"train total should be set to 2 after reset(), got total_values: {train_bar.total_values}"
    )
    # Verify that validation progress bar reset() was called without parameters
    val_bar = pbar.val_progress_bar
    assert None in val_bar.reset_calls_with_params, (
        f"validation reset() should be called without parameters, got calls: {val_bar.reset_calls_with_params}"
    )
    # Verify that total was set separately to the expected value
    assert 2 in val_bar.total_values, (
        f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
    )
Loading