Skip to content

Commit e3f86c8

Browse files
committed
asserting directly on progress bar proprty
1 parent 49f7f23 commit e3f86c8

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tests/tests_pytorch/loops/test_training_loop.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,11 @@ def test_tqdm_total_steps_with_iterator_no_length(tmp_path, max_steps):
233233

234234
# Override train_dataloader with infinite iterator
235235
model.train_dataloader = lambda: infinite_iter
236+
pbar = trainer.progress_bar_callback
236237
trainer.fit(model)
237238

238-
# tqdm total steps should equal max_steps for iterator with no length
239-
assert trainer.estimated_stepping_batches == max_steps
239+
# assert progress bar callback uses correct total steps
240+
assert pbar.train_progress_bar.total == max_steps
240241

241242

242243
@pytest.mark.parametrize("max_steps", [10, 15])
@@ -271,7 +272,8 @@ def __iter__(self):
271272
enable_model_summary=False,
272273
accelerator="cpu",
273274
)
275+
pbar = trainer.progress_bar_callback
274276
trainer.fit(model)
275-
276-
# tqdm total steps should equal max_steps for iterator with no length
277-
assert trainer.estimated_stepping_batches == max_steps
277+
278+
# assert progress bar callback uses correct total steps
279+
assert pbar.train_progress_bar.total == max_steps

0 commit comments

Comments
 (0)