Skip to content

Commit 238f2ac

Browse files
set tqdm as default progress (#144)
* set tqdm as default progress * Remove default progress bar settings completely * add unit tests * automatically set refresh rate to 10 on kaggle and colab * Update __init__.py * Update progress.py Update .gitignore Update progress.py Update progress.py Update progress.py Update progress.py Update progress.py * Update progress.py * Update test_trainer.py * Update trainer.py * Update trainer.py --------- Co-authored-by: Giovanni Volpe <[email protected]>
1 parent 75ed214 commit 238f2ac

File tree

5 files changed

+487
-23
lines changed

5 files changed

+487
-23
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Trial notebooks
2+
trial*.ipynb
3+
14
# Byte-compiled / optimized / DLL files
25
*.pyc
36

deeplay/callbacks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .progress import RichProgressBar
21
from .history import LogHistory
2+
from .progress import RichProgressBar, TQDMProgressBar

deeplay/callbacks/progress.py

Lines changed: 273 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,289 @@
1+
"""Enhanced progress bars for compatibility and customization.
2+
3+
This module defines enhanced progress bars for PyTorch Lightning, designed to
4+
improve compatibility and usability in various environments.
5+
The `TQDMProgressBar` and `RichProgressBar` classes extend the Lightning
6+
default implementations and provide safe refresh rate handling for platforms
7+
like Google Colab and Kaggle, which may crash with small refresh rates.
8+
9+
Key Features
10+
------------
11+
- **TQDM Progress Bar with Compatibility Enhancements**
12+
13+
The `TQDMProgressBar` class extends Lightning's `TQDMProgressBar`,
14+
providing a mechanism to adjust refresh rates based on the execution
15+
environment. This helps avoid issues caused by small refresh rates on Colab
16+
and Kaggle.
17+
18+
- **Rich Progress Bar with Customization Options**
19+
20+
The `RichProgressBar` class offers a visually appealing progress bar with
21+
customizable themes and console options. Similar to the `TQDMProgressBar`,
22+
it includes environment-based refresh rate adjustments to enhance
23+
stability.
24+
25+
Module Structure
26+
----------------
27+
Classes:
28+
29+
- `TQDMProgressBar`: Enhances Lightning TQDM progress bar.
30+
31+
Automatically modifies the refresh rate if the code is executed on
32+
platforms like Colab or Kaggle.
33+
34+
- `RichProgressBar`: Enhances Lightning Rich progress bar.
35+
36+
Supports configurable themes and console settings, and adjusts refresh
37+
rates when needed.
38+
39+
Examples
40+
--------
41+
This example demosntrate the use of the standard TQDM progress bar:
42+
43+
```python
44+
import deeplay as dl
45+
import torch
46+
47+
# Create training dataset.
48+
num_samples = 10 ** 4
49+
data = torch.randn(num_samples, 2)
50+
labels = (data.sum(dim=1) > 0).long()
51+
52+
dataset = torch.utils.data.TensorDataset(data, labels)
53+
dataloader = dl.DataLoader(dataset, batch_size=16, shuffle=True)
54+
55+
# Create neural network and classifier application.
56+
mlp = dl.MediumMLP(in_features=2, out_features=2)
57+
classifier = dl.Classifier(mlp, optimizer=dl.Adam(), num_classes=2).build()
58+
59+
# Train neural network with progress bar.
60+
tqdm_bar = dl.callbacks.TQDMProgressBar(refresh_rate=100)
61+
trainer = dl.Trainer(max_epochs=100, callbacks=[tqdm_bar])
62+
trainer.fit(classifier, dataloader)
63+
```
64+
65+
Alternatively, you can use the rich progress bar with:
66+
67+
```python
68+
rich_bar = dl.callbacks.RichProgressBar(refresh_rate=100)
69+
trainer = dl.Trainer(max_epochs=100, callbacks=[rich_bar])
70+
trainer.fit(classifier, dataloader)
71+
```
72+
73+
"""
74+
75+
from __future__ import annotations
76+
77+
import os
78+
179
from lightning.pytorch.callbacks.progress.rich_progress import (
2-
RichProgressBar as RPB,
3-
RichProgressBarTheme as RPBT,
80+
RichProgressBar as LightningRichProgressBar,
81+
RichProgressBarTheme as RPBTheme,
482
)
83+
from lightning.pytorch.callbacks.progress.tqdm_progress import (
84+
TQDMProgressBar as LightningTQDMProgressBar,
85+
)
86+
from lightning.pytorch.utilities.rank_zero import rank_zero_debug
87+
88+
89+
class TQDMProgressBar(LightningTQDMProgressBar):
90+
"""A progress bar for displaying training progress with TQDM.
91+
92+
This class enhances the standard Lightning TQDMProgressBar by providing
93+
environment-specific adjustments to prevent potential crashes on platforms
94+
like Colab and Kaggle.
595
96+
Parameters
97+
----------
98+
refresh_rate : int, optional
99+
The refresh rate of the progress bar, by default 1.
6100
7-
class RichProgressBar(RPB):
101+
Example
102+
-------
103+
This example demosntrate the use of the standard TQDM progress bar:
104+
105+
```python
106+
import deeplay as dl
107+
import torch
108+
109+
# Create training dataset.
110+
num_samples = 10 ** 4
111+
data = torch.randn(num_samples, 2)
112+
labels = (data.sum(dim=1) > 0).long()
113+
114+
dataset = torch.utils.data.TensorDataset(data, labels)
115+
dataloader = dl.DataLoader(dataset, batch_size=16, shuffle=True)
116+
117+
# Create neural network and classifier application.
118+
mlp = dl.MediumMLP(in_features=2, out_features=2)
119+
classifier = dl.Classifier(mlp, optimizer=dl.Adam(), num_classes=2).build()
120+
121+
# Train neural network with progress bar.
122+
tqdm_bar = dl.callbacks.TQDMProgressBar(refresh_rate=100)
123+
trainer = dl.Trainer(max_epochs=100, callbacks=[tqdm_bar])
124+
trainer.fit(classifier, dataloader)
125+
```
126+
127+
"""
8128

9129
def __init__(
10-
self,
130+
self: TQDMProgressBar,
131+
refresh_rate: int = 1,
132+
):
133+
"""Initialize the progress bar with a configurable refresh rate.
134+
135+
Parameters
136+
----------
137+
refresh_rate : int, optional
138+
The refresh rate of the progress bar, by default 1.
139+
140+
"""
141+
142+
super().__init__(refresh_rate=refresh_rate)
143+
144+
@staticmethod
145+
def _resolve_refresh_rate(refresh_rate: int) -> int:
146+
"""Resolve refresh rate for compatibility with Colab and Kaggle.
147+
148+
This method adjusts the refresh rate to a safe value to prevent crashes
149+
on platforms that are known to have issues with small refresh rates.
150+
151+
Parameters
152+
----------
153+
refresh_rate : int
154+
The desired refresh rate of the progress bar.
155+
156+
Returns
157+
-------
158+
int
159+
The adjusted refresh rate.
160+
161+
"""
162+
163+
# This should work both for Colab and Kaggle because Kaggle returns a
164+
# Colab session.
165+
if "COLAB_JUPYTER_IP" in os.environ and refresh_rate == 1:
166+
rank_zero_debug(
167+
"Small refresh rates can crash on Colab or Kaggle. "
168+
"Setting refresh_rate to 10.\n"
169+
"To manually set the refresh rate, "
170+
"call `trainer.tqdm_progress_bar(refresh_rate=10)`."
171+
)
172+
refresh_rate = 10
173+
174+
return LightningTQDMProgressBar._resolve_refresh_rate(refresh_rate)
175+
176+
177+
class RichProgressBar(LightningRichProgressBar):
178+
"""A progress bar for displaying training progress with Rich.
179+
180+
This class enhances the standard Lightning RichProgressBar by supporting
181+
customizable themes and console options. It includes an
182+
environment-specific adjustment to prevent potential crashes on platforms
183+
like Colab and Kaggle.
184+
185+
Parameters
186+
----------
187+
refresh_rate : int, optional
188+
The refresh rate of the progress bar, by default 1.
189+
leave : bool, optional
190+
Whether to leave the progress bar on the screen after completion,
191+
by default False.
192+
theme : RichProgressBarTheme, optional
193+
The theme used for the Rich progress bar,
194+
by default `RichProgressBarTheme(metrics_format=".3g")`.
195+
console_kwargs : dict, optional
196+
Additional keyword arguments for configuring the Rich console,
197+
by default None.
198+
199+
Example
200+
-------
201+
This example demosntrate the use of the standard TQDM progress bar:
202+
203+
```python
204+
import deeplay as dl
205+
import torch
206+
207+
# Create training dataset.
208+
num_samples = 10 ** 4
209+
data = torch.randn(num_samples, 2)
210+
labels = (data.sum(dim=1) > 0).long()
211+
212+
dataset = torch.utils.data.TensorDataset(data, labels)
213+
dataloader = dl.DataLoader(dataset, batch_size=16, shuffle=True)
214+
215+
# Create neural network and classifier application.
216+
mlp = dl.MediumMLP(in_features=2, out_features=2)
217+
classifier = dl.Classifier(mlp, optimizer=dl.Adam(), num_classes=2).build()
218+
219+
# Train neural network with progress bar.
220+
rich_bar = dl.callbacks.RichProgressBar(refresh_rate=100)
221+
trainer = dl.Trainer(max_epochs=100, callbacks=[rich_bar])
222+
trainer.fit(classifier, dataloader)
223+
```
224+
225+
"""
226+
227+
def __init__(
228+
self: RichProgressBar,
11229
refresh_rate: int = 1,
12230
leave: bool = False,
13-
theme: RPBT = RPBT(metrics_format=".3g"),
231+
theme: RPBTheme = RPBTheme(metrics_format=".3g"),
14232
console_kwargs=None,
15233
):
234+
"""Initialize the Rich progress bar with customizable settings.
235+
236+
Parameters
237+
----------
238+
refresh_rate : int, optional
239+
The refresh rate of the progress bar, by default 1.
240+
leave : bool, optional
241+
Whether to leave the progress bar displayed after completion,
242+
by default False.
243+
theme : RichProgressBarTheme, optional
244+
The theme of the progress bar,
245+
by default `RPBTheme(metrics_format=".3g")`.
246+
console_kwargs : dict, optional
247+
Additional keyword arguments to configure the Rich console,
248+
by default None.
249+
250+
"""
251+
16252
super().__init__(
17253
refresh_rate=refresh_rate,
18254
leave=leave,
19255
theme=theme,
20256
console_kwargs=console_kwargs,
21257
)
258+
259+
@staticmethod
260+
def _resolve_refresh_rate(refresh_rate: int) -> int:
261+
"""Resolve refresh rate for compatibility with Colab and Kaggle.
262+
263+
This method adjusts the refresh rate to a safe value to prevent crashes
264+
on platforms that are known to have issues with small refresh rates.
265+
266+
Parameters
267+
----------
268+
refresh_rate : int
269+
The desired refresh rate of the progress bar.
270+
271+
Returns
272+
-------
273+
int
274+
The adjusted refresh rate.
275+
276+
"""
277+
278+
# This should work both for Colab and Kaggle because Kaggle returns a
279+
# Colab session.
280+
if "COLAB_JUPYTER_IP" in os.environ and refresh_rate == 1:
281+
rank_zero_debug(
282+
"Small refresh rates can crash on Colab or Kaggle. "
283+
"Setting refresh_rate to 10.\n"
284+
"To manually set the refresh rate, "
285+
"call `trainer.rich_progress_bar(refresh_rate=10)`."
286+
)
287+
refresh_rate = 10
288+
289+
return LightningTQDMProgressBar._resolve_refresh_rate(refresh_rate)

deeplay/tests/test_trainer.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
from deeplay import DataLoader, Regressor
22
from deeplay.callbacks import LogHistory, RichProgressBar
3+
4+
from deeplay import Regressor, DataLoader
5+
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar
6+
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
7+
import unittest
8+
import torch.nn as nn
9+
310
from deeplay.trainer import Trainer
11+
412
import lightning as L
513
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar
614
import torch
@@ -12,7 +20,8 @@
1220
class TestTrainer(unittest.TestCase):
1321

1422
def setUp(self):
15-
self._patcher = patch("torch.backends.mps.is_available", return_value=False)
23+
self._patcher = patch("torch.backends.mps.is_available",
24+
return_value=False)
1625
self._mock = self._patcher.start()
1726

1827
def tearDown(self):
@@ -23,9 +32,50 @@ def test_trainer(self):
2332
self.assertIsInstance(trainer, Trainer)
2433
self.assertIsInstance(trainer.callbacks[0], LogHistory)
2534
self.assertIsInstance(
26-
trainer.callbacks[1], RichProgressBar
35+
trainer.callbacks[1], TQDMProgressBar
2736
) # should be added by default
2837

38+
def test_trainer(self):
39+
trainer = Trainer(callbacks=[LogHistory()])
40+
trainer.disable_progress_bar()
41+
self.assertIsInstance(trainer, Trainer)
42+
self.assertIsInstance(trainer.callbacks[0], LogHistory)
43+
44+
has_a_progress_bar = False
45+
for callback in trainer.callbacks:
46+
if isinstance(callback, ProgressBar):
47+
has_a_progress_bar = True
48+
break
49+
self.assertFalse(has_a_progress_bar)
50+
51+
def test_trainer_tqdm_progress_bar(self):
52+
trainer = Trainer(callbacks=[LogHistory()])
53+
trainer.tqdm_progress_bar()
54+
self.assertIsInstance(trainer, Trainer)
55+
self.assertIsInstance(trainer.callbacks[0], LogHistory)
56+
57+
num_progress_bars = 0
58+
for callback in trainer.callbacks:
59+
if isinstance(callback, ProgressBar):
60+
num_progress_bars += 1
61+
self.assertIsInstance(callback, TQDMProgressBar)
62+
63+
self.assertEqual(num_progress_bars, 1)
64+
65+
def test_trainer_rich_progress_bar(self):
66+
trainer = Trainer(callbacks=[LogHistory()])
67+
trainer.rich_progress_bar()
68+
self.assertIsInstance(trainer, Trainer)
69+
self.assertIsInstance(trainer.callbacks[0], LogHistory)
70+
71+
num_progress_bars = 0
72+
for callback in trainer.callbacks:
73+
if isinstance(callback, ProgressBar):
74+
num_progress_bars += 1
75+
self.assertIsInstance(callback, RichProgressBar)
76+
77+
self.assertEqual(num_progress_bars, 1)
78+
2979
def test_trainer_explicit_progress_bar(self):
3080
trainer = Trainer(callbacks=[LogHistory(), RichProgressBar()])
3181
self.assertIsInstance(trainer, Trainer)

0 commit comments

Comments
 (0)