Skip to content

Commit 0cb2953

Browse files
committed
test(progress): update test_rich_progress_bar_with_refresh_rate to test_rich_progress_bar_update_counts
1 parent d15460f commit 0cb2953

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -221,30 +221,27 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmp_path):
221221

222222
@RunIf(rich=True)
223223
@pytest.mark.parametrize(
224-
("refresh_rate", "train_batches", "val_batches", "expected_call_count"),
224+
("train_batches", "val_batches", "expected_call_count"),
225225
[
226226
# note: there is always one extra update at the very end (+1)
227-
(3, 6, 6, 2 + 2 + 1),
228-
(4, 6, 6, 2 + 2 + 1),
229-
(7, 6, 6, 1 + 1 + 1),
230-
(1, 2, 3, 2 + 3 + 1),
231-
(1, 0, 0, 0 + 0),
232-
(3, 1, 0, 1 + 0),
233-
(3, 1, 1, 1 + 1 + 1),
234-
(3, 5, 0, 2 + 0),
235-
(3, 5, 2, 2 + 1 + 1),
236-
(6, 5, 2, 1 + 1 + 1),
227+
(6, 6, 6 + 6 + 1),
228+
(2, 3, 2 + 3 + 1),
229+
(0, 0, 0 + 0),
230+
(1, 0, 1 + 0),
231+
(1, 1, 1 + 1 + 1),
232+
(5, 0, 5 + 0),
233+
(5, 2, 5 + 2 + 1),
237234
],
238235
)
239-
def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batches, val_batches, expected_call_count):
236+
def test_rich_progress_bar_update_counts(tmp_path, train_batches, val_batches, expected_call_count):
240237
model = BoringModel()
241238
trainer = Trainer(
242239
default_root_dir=tmp_path,
243240
num_sanity_val_steps=0,
244241
limit_train_batches=train_batches,
245242
limit_val_batches=val_batches,
246243
max_epochs=1,
247-
callbacks=RichProgressBar(refresh_rate=refresh_rate),
244+
callbacks=RichProgressBar(),
248245
)
249246

250247
trainer.progress_bar_callback.on_train_start(trainer, model)

0 commit comments

Comments
 (0)