Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/callbacks/progress/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def total_train_batches(self) -> Union[int, float]:
dataloader is of infinite size.

"""
if self.trainer.max_epochs == -1 and self.trainer.max_steps is not None and self.trainer.max_steps > 0:
remaining_steps = self.trainer.max_steps - self.trainer.global_step
return min(self.trainer.num_training_batches, remaining_steps)
return self.trainer.num_training_batches

@property
Expand Down
69 changes: 69 additions & 0 deletions tests/tests_pytorch/loops/test_training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from unittest.mock import Mock

import pytest
import torch
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.demos.boring_classes import BoringModel
Expand Down Expand Up @@ -206,3 +208,70 @@ def test_should_stop_early_stopping_conditions_met(

assert (message in caplog.text) is raise_debug_msg
assert trainer.fit_loop._can_stop_early is early_stop


@pytest.mark.parametrize("max_steps", [7, 20])
def test_tqdm_total_steps_with_iterator_no_length(tmp_path, max_steps):
"""Test trainer with infinite iterator (no __len__)"""

batch_size = 4
model = BoringModel()

# Infinite generator (no __len__)
# NOTE: 32 for BoringModel
infinite_iter = (torch.randn(batch_size, 32, dtype=torch.float32) for _ in itertools.count(0))

trainer = Trainer(
default_root_dir=tmp_path,
max_steps=max_steps,
max_epochs=-1,
limit_val_batches=0,
enable_progress_bar=True,
enable_model_summary=False,
accelerator="cpu",
)

# Override train_dataloader with infinite iterator
model.train_dataloader = lambda: infinite_iter
trainer.fit(model)

# tqdm total steps should equal max_steps for iterator with no length
assert trainer.estimated_stepping_batches == max_steps


@pytest.mark.parametrize("max_steps", [10, 15])
def test_progress_bar_steps(tmp_path, max_steps):
batch_size = 4

model = BoringModel()
# Create dataloader here, outside the model
# NOTE: 32 for boring model
x = torch.randn(100, 32)

class SingleTensorDataset(torch.utils.data.IterableDataset):
def __init__(self, data):
super().__init__()
self.data = data

def __iter__(self):
yield from self.data # yield just a tensor, not a tuple

dataset = SingleTensorDataset(x)
dataloader = DataLoader(dataset, batch_size=batch_size)

# Patch model's train_dataloader method to return this dataloader
model.train_dataloader = lambda: dataloader

trainer = Trainer(
default_root_dir=tmp_path,
max_steps=max_steps,
max_epochs=-1,
limit_val_batches=0,
enable_progress_bar=True,
enable_model_summary=False,
accelerator="cpu",
)
trainer.fit(model)

# tqdm total steps should equal max_steps for iterator with no length
assert trainer.estimated_stepping_batches == max_steps
Loading