Skip to content

Commit 5d178d0

Browse files
Support TQDM_MINITERS env variable (#19381)
Co-authored-by: awaelchli <[email protected]>
1 parent c346f4d commit 5d178d0

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2727
- Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file ([#19213](https://github.com/Lightning-AI/lightning/pull/19213))
2828

2929

30+
- The TQDM progress bar now respects the env variable `TQDM_MINITERS` for setting the refresh rate ([#19381](https://github.com/Lightning-AI/lightning/pull/19381))
31+
32+
3033
### Changed
3134

3235
- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,9 @@ def _resolve_refresh_rate(refresh_rate: int) -> int:
431431
# smaller refresh rate on colab causes crashes, choose a higher value
432432
rank_zero_debug("Using a higher refresh rate on Colab. Setting it to `20`")
433433
return 20
434+
# Support TQDM_MINITERS environment variable, which sets the minimum refresh rate
435+
if "TQDM_MINITERS" in os.environ:
436+
return max(int(os.environ["TQDM_MINITERS"]), refresh_rate)
434437
return refresh_rate
435438

436439

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,19 @@ def test_tqdm_progress_bar_value_on_colab(tmp_path):
334334
assert trainer.progress_bar_callback.refresh_rate == 19
335335

336336

337+
@pytest.mark.parametrize(("refresh_rate", "env_value", "expected"), [
338+
(0, 1, 1),
339+
(1, 0, 1),
340+
(1, 1, 1),
341+
(2, 1, 2),
342+
(1, 2, 2),
343+
])
344+
def test_tqdm_progress_bar_refresh_rate_via_env_variable(refresh_rate, env_value, expected):
345+
with mock.patch.dict(os.environ, {"TQDM_MINITERS": str(env_value)}):
346+
bar = TQDMProgressBar(refresh_rate=refresh_rate)
347+
assert bar.refresh_rate == expected
348+
349+
337350
@pytest.mark.parametrize(
338351
("train_batches", "val_batches", "refresh_rate", "train_updates", "val_updates"),
339352
[

0 commit comments

Comments
 (0)