Skip to content

Commit 1776963

Browse files
Fix progress bar display to correctly handle iterable dataset and max_steps during training (#20869)
* changes to show correct progress bar numbers when using max_steps --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7433dc2 commit 1776963

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

src/lightning/pytorch/callbacks/progress/progress_bar.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def total_train_batches(self) -> Union[int, float]:
8585
dataloader is of infinite size.
8686
8787
"""
88+
if self.trainer.max_epochs == -1 and self.trainer.max_steps is not None and self.trainer.max_steps > 0:
89+
remaining_steps = self.trainer.max_steps - self.trainer.global_step
90+
return min(self.trainer.num_training_batches, remaining_steps)
8891
return self.trainer.num_training_batches
8992

9093
@property

tests/tests_pytorch/loops/test_training_loop.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import itertools
1415
import logging
1516
from unittest.mock import Mock
1617

1718
import pytest
1819
import torch
20+
from torch.utils.data import DataLoader
1921

2022
from lightning.pytorch import Trainer, seed_everything
2123
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -206,3 +208,72 @@ def test_should_stop_early_stopping_conditions_met(
206208

207209
assert (message in caplog.text) is raise_debug_msg
208210
assert trainer.fit_loop._can_stop_early is early_stop
211+
212+
213+
@pytest.mark.parametrize("max_steps", [7, 20])
214+
def test_tqdm_total_steps_with_iterator_no_length(tmp_path, max_steps):
215+
"""Test trainer with infinite iterator (no __len__)"""
216+
217+
batch_size = 4
218+
model = BoringModel()
219+
220+
# Infinite generator (no __len__)
221+
# NOTE: 32 for BoringModel
222+
infinite_iter = (torch.randn(batch_size, 32, dtype=torch.float32) for _ in itertools.count(0))
223+
224+
trainer = Trainer(
225+
default_root_dir=tmp_path,
226+
max_steps=max_steps,
227+
max_epochs=-1,
228+
limit_val_batches=0,
229+
enable_progress_bar=True,
230+
enable_model_summary=False,
231+
accelerator="cpu",
232+
)
233+
234+
# Override train_dataloader with infinite iterator
235+
model.train_dataloader = lambda: infinite_iter
236+
pbar = trainer.progress_bar_callback
237+
trainer.fit(model)
238+
239+
# assert progress bar callback uses correct total steps
240+
assert pbar.train_progress_bar.total == max_steps
241+
242+
243+
@pytest.mark.parametrize("max_steps", [10, 15])
244+
def test_progress_bar_steps(tmp_path, max_steps):
245+
batch_size = 4
246+
247+
model = BoringModel()
248+
# Create dataloader here, outside the model
249+
# NOTE: 32 for boring model
250+
x = torch.randn(100, 32)
251+
252+
class SingleTensorDataset(torch.utils.data.IterableDataset):
253+
def __init__(self, data):
254+
super().__init__()
255+
self.data = data
256+
257+
def __iter__(self):
258+
yield from self.data # yield just a tensor, not a tuple
259+
260+
dataset = SingleTensorDataset(x)
261+
dataloader = DataLoader(dataset, batch_size=batch_size)
262+
263+
# Patch model's train_dataloader method to return this dataloader
264+
model.train_dataloader = lambda: dataloader
265+
266+
trainer = Trainer(
267+
default_root_dir=tmp_path,
268+
max_steps=max_steps,
269+
max_epochs=-1,
270+
limit_val_batches=0,
271+
enable_progress_bar=True,
272+
enable_model_summary=False,
273+
accelerator="cpu",
274+
)
275+
pbar = trainer.progress_bar_callback
276+
trainer.fit(model)
277+
278+
# assert progress bar callback uses correct total steps
279+
assert pbar.train_progress_bar.total == max_steps

0 commit comments

Comments
 (0)