diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py
index 7cf6993b4414b..4c965038cb294 100644
--- a/src/lightning/pytorch/callbacks/progress/progress_bar.py
+++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py
@@ -85,6 +85,9 @@ def total_train_batches(self) -> Union[int, float]:
         dataloader is of infinite size.
 
         """
+        if self.trainer.max_epochs == -1 and self.trainer.max_steps is not None and self.trainer.max_steps > 0:
+            remaining_steps = self.trainer.max_steps - self.trainer.global_step
+            return min(self.trainer.num_training_batches, remaining_steps)
         return self.trainer.num_training_batches
 
     @property
diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py
index 29afd1ba1a250..e3a4c37f6a284 100644
--- a/tests/tests_pytorch/loops/test_training_loop.py
+++ b/tests/tests_pytorch/loops/test_training_loop.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import logging
 from unittest.mock import Mock
 
 import pytest
 import torch
+from torch.utils.data import DataLoader
 
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -206,3 +208,72 @@ def test_should_stop_early_stopping_conditions_met(
 
     assert (message in caplog.text) is raise_debug_msg
     assert trainer.fit_loop._can_stop_early is early_stop
+
+
+@pytest.mark.parametrize("max_steps", [7, 20])
+def test_tqdm_total_steps_with_iterator_no_length(tmp_path, max_steps):
+    """Test the trainer with an infinite iterator (no __len__)."""
+
+    batch_size = 4
+    model = BoringModel()
+
+    # Infinite generator (no __len__)
+    # NOTE: feature size 32 matches BoringModel
+    infinite_iter = (torch.randn(batch_size, 32, dtype=torch.float32) for _ in itertools.count(0))
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_steps=max_steps,
+        max_epochs=-1,
+        limit_val_batches=0,
+        enable_progress_bar=True,
+        enable_model_summary=False,
+        accelerator="cpu",
+    )
+
+    # Override train_dataloader with the infinite iterator
+    model.train_dataloader = lambda: infinite_iter
+    pbar = trainer.progress_bar_callback
+    trainer.fit(model)
+
+    # Assert the progress bar callback uses the correct total step count
+    assert pbar.train_progress_bar.total == max_steps
+
+
+@pytest.mark.parametrize("max_steps", [10, 15])
+def test_progress_bar_steps(tmp_path, max_steps):
+    batch_size = 4
+
+    model = BoringModel()
+    # Create the dataloader here, outside the model
+    # NOTE: feature size 32 matches BoringModel
+    x = torch.randn(100, 32)
+
+    class SingleTensorDataset(torch.utils.data.IterableDataset):
+        def __init__(self, data):
+            super().__init__()
+            self.data = data
+
+        def __iter__(self):
+            yield from self.data  # yield just a tensor, not a tuple
+
+    dataset = SingleTensorDataset(x)
+    dataloader = DataLoader(dataset, batch_size=batch_size)
+
+    # Patch the model's train_dataloader method to return this dataloader
+    model.train_dataloader = lambda: dataloader
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_steps=max_steps,
+        max_epochs=-1,
+        limit_val_batches=0,
+        enable_progress_bar=True,
+        enable_model_summary=False,
+        accelerator="cpu",
+    )
+    pbar = trainer.progress_bar_callback
+    trainer.fit(model)
+
+    # Assert the progress bar callback uses the correct total step count
+    assert pbar.train_progress_bar.total == max_steps
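
For reference, a minimal standalone sketch of the capping logic the patch adds to total_train_batches. FakeTrainer and the free function below are hypothetical stand-ins for the attributes the real ProgressBar property reads from self.trainer; they are not Lightning API.

# Reviewer sketch (not part of the patch): the capping logic above, isolated.
# FakeTrainer is a hypothetical stand-in, not the real lightning.pytorch.Trainer.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class FakeTrainer:
    num_training_batches: Union[int, float]  # float("inf") when the dataloader has no __len__
    max_steps: int
    max_epochs: Optional[int]
    global_step: int


def total_train_batches(trainer: FakeTrainer) -> Union[int, float]:
    # Step-bounded training (max_epochs=-1, max_steps > 0): cap the progress-bar
    # total at the steps that remain instead of reporting inf.
    if trainer.max_epochs == -1 and trainer.max_steps is not None and trainer.max_steps > 0:
        remaining_steps = trainer.max_steps - trainer.global_step
        return min(trainer.num_training_batches, remaining_steps)
    return trainer.num_training_batches


# With an infinite dataloader and max_steps=7, the total becomes 7 rather than inf.
assert total_train_batches(FakeTrainer(float("inf"), max_steps=7, max_epochs=-1, global_step=0)) == 7
# With fewer batches than remaining steps would allow, min() keeps the smaller value.
assert total_train_batches(FakeTrainer(25, max_steps=100, max_epochs=-1, global_step=90)) == 10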