Skip to content

Commit e6e178f

Browse files
committed
add testing
1 parent ccf63c3 commit e6e178f

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batch
246246
with mock.patch.object(
247247
trainer.progress_bar_callback.progress, "update", wraps=trainer.progress_bar_callback.progress.update
248248
) as progress_update:
249+
metrics_update = mock.MagicMock()
250+
trainer.progress_bar_callback._update_metrics = metrics_update
251+
249252
trainer.fit(model)
250253
assert progress_update.call_count == expected_call_count
251254

@@ -260,6 +263,9 @@ def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batch
260263
assert fit_val_bar.total == val_batches
261264
assert not fit_val_bar.visible
262265

266+
# one call for each train batch + one at the end of training epoch + one for validation end
267+
assert metrics_update.call_count == train_batches + (1 if train_batches > 0 else 0) + (1 if val_batches > 0 else 0)
268+
263269

264270
@RunIf(rich=True)
265271
@pytest.mark.parametrize("limit_val_batches", [1, 5])

0 commit comments

Comments
 (0)