Commit 17f58d2

add rank warning (#1428)
* add rank warning
* changelog
* use rank_zero_warn
* use trainer_init
* replace warnings
* fix test
* flake8
* docs
* changelog
* bug lol
1 parent b4eb388 commit 17f58d2

Some content is hidden: large commits have some of their files hidden by default, so only a subset of the 41 changed files appears below.
41 files changed: +213 −187 lines
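The new helper, `rank_zero_warn`, is imported everywhere below from `pytorch_lightning.utilities`; its definition sits in one of the files hidden in this view. A minimal sketch of what such a rank-gated warning helper can look like, assuming the launcher exposes the process rank through an environment variable like `LOCAL_RANK` (the utility's exact code may differ):

```python
import os
import warnings
from functools import wraps


def rank_zero_only(fn):
    """Run the wrapped function only in the rank-0 process."""
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
    return wrapped_fn


# Assumption: distributed launchers such as torch.distributed.launch
# export the process rank as an environment variable.
rank_zero_only.rank = int(os.environ.get('LOCAL_RANK', 0))


@rank_zero_only
def rank_zero_warn(*args, **kwargs):
    warnings.warn(*args, **kwargs)
```

Because the helper forwards `*args, **kwargs` straight to `warnings.warn`, every call site in the diffs below keeps its original `(message, category)` arguments and changes only the function name.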

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
```diff
@@ -4,6 +4,17 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [0.7.3] - 2020-04-09
+
+### Added
+
+- Added `rank_zero_warn` for warning only in rank 0 ([#1428](https://github.com/PyTorchLightning/pytorch-lightning/pull/1428))
+
+### Fixed
+
+- Fixed default `DistributedSampler` for DDP training ([#1425](https://github.com/PyTorchLightning/pytorch-lightning/pull/1425))
+- Fixed workers warning not on windows ([#1430](https://github.com/PyTorchLightning/pytorch-lightning/pull/1430))
+
 ## [0.7.2] - 2020-04-07
 
 ### Added
```

pytorch_lightning/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 """Root package info."""
 
-__version__ = '0.7.2'
+__version__ = '0.7.3rc1'
 __author__ = 'William Falcon et al.'
 __author_email__ = '[email protected]'
 __license__ = 'Apache-2.0'
```

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -6,12 +6,11 @@
 
 """
 
-import warnings
-
 import numpy as np
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 class EarlyStopping(Callback):
@@ -80,7 +79,7 @@ def check_metrics(self, logs):
             if self.strict:
                 raise RuntimeError(error_msg)
             if self.verbose > 0:
-                warnings.warn(error_msg, RuntimeWarning)
+                rank_zero_warn(error_msg, RuntimeWarning)
 
             return False
 
@@ -113,6 +112,6 @@ def on_epoch_end(self, trainer, pl_module):
 
     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
-            warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
-                          ' but will start from "0" in v0.8.0.', DeprecationWarning)
+            rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
+                           ' but will start from "0" in v0.8.0.', DeprecationWarning)
         log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
```
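The same mechanical swap repeats across the remaining files. The motivation: under DDP each of the N processes would otherwise emit its own copy of a warning, so gating on rank 0 silences the N−1 duplicates. A brief before/after illustration, reusing a message from this diff (`rank_zero_warn` is the import this commit adds everywhere):

```python
import warnings
from pytorch_lightning.utilities import rank_zero_warn

# Before: with 4 DDP processes, 4 copies of the warning reach the console.
warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
              ' but will start from "0" in v0.8.0.', DeprecationWarning)

# After: identical arguments, only the function name changes;
# the warning is emitted once, by the rank-0 process alone.
rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
               ' but will start from "0" in v0.8.0.', DeprecationWarning)
```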

pytorch_lightning/callbacks/gradient_accumulation_scheduler.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -6,9 +6,8 @@
 
 """
 
-import warnings
-
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 class GradientAccumulationScheduler(Callback):
@@ -46,8 +45,8 @@ def __init__(self, scheduling: dict):
             raise TypeError("All epoches and accumulation factor must be integers")
 
         minimal_epoch = min(scheduling.keys())
-        warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
-                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
+        rank_zero_warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
+                       ' but will start from "0" in v0.8.0.', DeprecationWarning)
         if minimal_epoch < 1:
             msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
             raise IndexError(msg)
```

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 5 additions & 8 deletions
```diff
@@ -7,14 +7,13 @@
 """
 
 import os
-import shutil
-import warnings
 import re
 
 import numpy as np
 
-from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning import _logger as log
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 class ModelCheckpoint(Callback):
@@ -83,7 +82,7 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = Fal
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
         if save_top_k > 0 and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
-            warnings.warn(
+            rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
             )
@@ -115,9 +114,7 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = Fal
         }
 
         if mode not in mode_dict:
-            warnings.warn(
-                f'ModelCheckpoint mode {mode} is unknown, '
-                'fallback to auto mode.', RuntimeWarning)
+            rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, fallback to auto mode.', RuntimeWarning)
             mode = 'auto'
 
         self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
@@ -206,7 +203,7 @@ def on_validation_end(self, trainer, pl_module):
             current = metrics.get(self.monitor)
 
             if current is None:
-                warnings.warn(f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning)
+                rank_zero_warn(f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning)
             elif self.check_monitor_top_k(current):
                 self._do_check_save(filepath, current, epoch)
             elif self.verbose > 0:
```

pytorch_lightning/core/decorators.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,4 +1,4 @@
-import warnings
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 def data_loader(fn):
@@ -7,7 +7,7 @@ def data_loader(fn):
     Warnings:
         This decorator deprecated in v0.7.0 and it will be removed v0.9.0.
     """
-    warnings.warn('`data_loader` decorator deprecated in v0.7.0. Will be removed v0.9.0', DeprecationWarning)
+    rank_zero_warn('`data_loader` decorator deprecated in v0.7.0. Will be removed v0.9.0', DeprecationWarning)
 
     def inner_fx(self):
         return fn(self)
```

pytorch_lightning/core/lightning.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -1,7 +1,6 @@
 import collections
 import inspect
 import os
-import warnings
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
@@ -20,6 +19,7 @@
 from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities import rank_zero_warn
 
 try:
     import torch_xla.core.xla_model as xm
@@ -225,7 +225,7 @@ def training_step(self, batch, batch_idx, hiddens):
         The loss value shown in the progress bar is smoothed (averaged) over the last values,
         so it differs from the actual loss returned in train/validation step.
         """
-        warnings.warn('`training_step` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn('`training_step` must be implemented to be used with the Lightning Trainer')
 
     def training_end(self, *args, **kwargs):
         """
@@ -1088,7 +1088,7 @@ def configure_optimizers(self):
             }
 
         """
-        warnings.warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer')
 
     def optimizer_step(
             self,
@@ -1291,16 +1291,16 @@ def train_dataloader(self):
             return loader
 
         """
-        warnings.warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
 
     def tng_dataloader(self):  # todo: remove in v1.0.0
         """
         Warnings:
             Deprecated in v0.5.0. Use :meth:`train_dataloader` instead. Will be removed in 1.0.0.
         """
         output = self.train_dataloader()
-        warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0."
-                      " and this method will be removed in v1.0.0", DeprecationWarning)
+        rank_zero_warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0."
+                       " and this method will be removed in v1.0.0", DeprecationWarning)
         return output
 
     def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
@@ -1407,7 +1407,7 @@ def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
             Deprecated in version 0.7.0. You should use :meth:`load_from_checkpoint` instead.
             Will be removed in v0.9.0.
         """
-        warnings.warn(
+        rank_zero_warn(
             "`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
             " The deprecated method will be removed in v0.9.0.", DeprecationWarning
         )
@@ -1519,7 +1519,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'Ligh
             is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace'
             hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
         else:
-            warnings.warn(
+            rank_zero_warn(
                 f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
                 f"contains argument 'hparams'. Will pass in an empty Namespace instead."
                 " Did you forget to store your model hyperparameters in self.hparams?"
```

pytorch_lightning/core/model_saving.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -3,9 +3,9 @@
 The deprecated module name will be removed in v0.8.0.
 """
 
-import warnings
+from pytorch_lightning.utilities import rank_zero_warn
 
-warnings.warn("`model_saving` module has been renamed to `saving` since v0.6.0."
-              " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
+rank_zero_warn("`model_saving` module has been renamed to `saving` since v0.6.0."
+               " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
 
 from pytorch_lightning.core.saving import *  # noqa: F403
```

pytorch_lightning/core/root_module.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -3,9 +3,9 @@
 The deprecated module name will be removed in v0.8.0.
 """
 
-import warnings
+from pytorch_lightning.utilities import rank_zero_warn
 
-from pytorch_lightning.core.lightning import *  # noqa: F403
+rank_zero_warn("`root_module` module has been renamed to `lightning` since v0.6.0."
+               " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
 
-warnings.warn("`root_module` module has been renamed to `lightning` since v0.6.0."
-              " The deprecated module name will be removed in v0.8.0.", DeprecationWarning)
+from pytorch_lightning.core.lightning import *  # noqa: F403
```

pytorch_lightning/logging/__init__.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -3,10 +3,10 @@
 The deprecated package name will be removed in v0.9.0.
 """
 
-import warnings
+from pytorch_lightning.utilities import rank_zero_warn
 
-warnings.warn("`logging` package has been renamed to `loggers` since v0.7.0"
-              " The deprecated package name will be removed in v0.9.0.", DeprecationWarning)
+rank_zero_warn("`logging` package has been renamed to `loggers` since v0.7.0"
+               " The deprecated package name will be removed in v0.9.0.", DeprecationWarning)
 
 from pytorch_lightning.loggers import *  # noqa: F403
 from pytorch_lightning.loggers import base, tensorboard  # noqa: F403
```
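The last three files (`model_saving.py`, `root_module.py`, and `logging/__init__.py`) all follow the same deprecation-shim pattern: warn at import time, then re-export everything under the old name. A generic sketch of that pattern, with hypothetical `old_name`/`new_name` modules standing in for the real ones:

```python
# old_name.py -- kept only so `import package.old_name` keeps working.
from pytorch_lightning.utilities import rank_zero_warn

rank_zero_warn("`old_name` module has been renamed to `new_name`."
               " The deprecated module name will be removed in a future release.",
               DeprecationWarning)

from package.new_name import *  # noqa: F403,E402 -- re-export under the old name
```

Note the commit also moves the warning above the star import in `root_module.py`, bringing it in line with the other two shims so the deprecation notice fires before any symbols are pulled in.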
