Skip to content

Commit c8e83a1

Browse files
Samyak2Borda
authored andcommitted
Use high progress_bar_refresh_rate on Google Colab (#4654)
* Use high refresh rate on Google Colab (#3786) Automatically override progress_bar_refresh_rate when on Google Colab. Also added a constant IS_COLAB in utilities to check whether it is being run in colab or not. (#3786) * Show a warning instead of overriding when rate is low on colab * Change warning to suggestion and move it Moved warning to configure_progress_bar instead of on_trainer_init * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> * add a mock test Co-authored-by: chaton <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> (cherry picked from commit ccf38ce)
1 parent 278b9a9 commit c8e83a1

File tree

4 files changed

+27
-4
lines changed

4 files changed

+27
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1212
- Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647))
1313

1414

15+
- Added warning when progress bar refresh rate is less than 20 on Google Colab to prevent crashing ([#4654](https://github.com/PyTorchLightning/pytorch-lightning/pull/4654))
16+
17+
1518
- Added `F1` class metric ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656))
1619

1720

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from typing import Optional, Union
1516

16-
from typing import Union, Optional
17-
18-
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
17+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar, ProgressBarBase
1918
from pytorch_lightning.utilities import rank_zero_warn
2019
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2120

@@ -74,6 +73,14 @@ def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpo
7473
self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None))
7574

7675
def configure_progress_bar(self, refresh_rate=1, process_position=0):
76+
# smaller refresh rate on colab causes crashes, warn user about this
77+
if os.getenv('COLAB_GPU') and refresh_rate < 20:
78+
rank_zero_warn(
79+
"You have set progress_bar_refresh_rate < 20 on Google Colab. This"
80+
" may crash. Consider using progress_bar_refresh_rate >= 20 in Trainer.",
81+
UserWarning
82+
)
83+
7784
progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
7885
if len(progress_bars) > 1:
7986
raise MisconfigurationException(

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919

2020
from pytorch_lightning.utilities.apply_func import move_data_to_device
21-
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
21+
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
2222
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
2323

2424
try:

tests/callbacks/test_progress_bar.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
1415
import pytest
16+
from unittest import mock
1517

1618
from pytorch_lightning import Trainer
1719
from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar, ModelCheckpoint
@@ -239,3 +241,14 @@ def on_validation_epoch_end(self, trainer, pl_module):
239241
)
240242
trainer.fit(model)
241243
assert trainer.progress_bar_callback.val_progress_bar_total == expected
244+
245+
246+
@mock.patch.dict(os.environ, {'COLAB_GPU': '1'})
247+
def test_progress_bar_warning_on_colab(tmpdir):
248+
with pytest.warns(UserWarning, match='on Google Colab. This may crash.'):
249+
trainer = Trainer(
250+
default_root_dir=tmpdir,
251+
progress_bar_refresh_rate=19,
252+
)
253+
254+
assert trainer.progress_bar_callback.refresh_rate == 19

0 commit comments

Comments
 (0)