@@ -1,7 +1,6 @@
 import collections
 import inspect
 import os
-import warnings
 from abc import ABC, abstractmethod
 from argparse import Namespace
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence

@@ -20,6 +19,7 @@
 from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities import rank_zero_warn

 try:
     import torch_xla.core.xla_model as xm
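For context on the change itself (not part of this diff): `rank_zero_warn` behaves like `warnings.warn`, except it fires only on the rank-zero process, so a multi-GPU or TPU job emits each warning once instead of once per worker. A minimal sketch of the pattern, assuming the rank is taken from the `LOCAL_RANK` environment variable set by distributed launchers — an approximation of the helper, not necessarily the library's exact code:

    import os
    import warnings
    from functools import wraps

    def rank_zero_only(fn):
        # Run the wrapped callable only on the process whose rank is 0.
        @wraps(fn)
        def wrapped_fn(*args, **kwargs):
            if rank_zero_only.rank == 0:
                return fn(*args, **kwargs)
        return wrapped_fn

    # Assumption: distributed launchers export the process index as LOCAL_RANK.
    rank_zero_only.rank = int(os.environ.get('LOCAL_RANK', 0))

    @rank_zero_only
    def rank_zero_warn(*args, **kwargs):
        # Same signature as warnings.warn, so categories such as
        # DeprecationWarning pass straight through.
        warnings.warn(*args, **kwargs)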
@@ -225,7 +225,7 @@ def training_step(self, batch, batch_idx, hiddens):
         The loss value shown in the progress bar is smoothed (averaged) over the last values,
         so it differs from the actual loss returned in train/validation step.
         """
-        warnings.warn('`training_step` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn('`training_step` must be implemented to be used with the Lightning Trainer')

     def training_end(self, *args, **kwargs):
         """
@@ -1088,7 +1088,7 @@ def configure_optimizers(self):
             }

         """
-        warnings.warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer')

     def optimizer_step(
         self,
@@ -1291,16 +1291,16 @@ def train_dataloader(self):
             return loader

         """
-        warnings.warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')

     def tng_dataloader(self):  # todo: remove in v1.0.0
         """
         Warnings:
             Deprecated in v0.5.0. Use :meth:`train_dataloader` instead. Will be removed in 1.0.0.
         """
         output = self.train_dataloader()
-        warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0."
-                      " and this method will be removed in v1.0.0", DeprecationWarning)
+        rank_zero_warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0."
+                       " and this method will be removed in v1.0.0", DeprecationWarning)
         return output

     def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
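Because the deprecation shim above now routes through `rank_zero_warn`, the `DeprecationWarning` surfaces once per job rather than once per process. A quick check against the sketch after the import hunk (hypothetical; it relies on the mutable `rank` attribute from that sketch):

    import warnings

    rank_zero_only.rank = 1  # pretend this process is a non-zero worker

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        rank_zero_warn('`tng_dataloader` has been renamed to `train_dataloader`', DeprecationWarning)

    assert not caught  # nothing is emitted off rank zero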
@@ -1407,7 +1407,7 @@ def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
             Deprecated in version 0.7.0. You should use :meth:`load_from_checkpoint` instead.
             Will be removed in v0.9.0.
         """
-        warnings.warn(
+        rank_zero_warn(
            "`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
            " The deprecated method will be removed in v0.9.0.", DeprecationWarning
        )
@@ -1519,7 +1519,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule':
             is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace'
             hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
         else:
-            warnings.warn(
+            rank_zero_warn(
                 f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ "
                 f"contains argument 'hparams'. Will pass in an empty Namespace instead."
                 " Did you forget to store your model hyperparameters in self.hparams?"
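The warning in this last hunk guards the fallback path: the checkpoint stores no hyperparameters, yet the model's `__init__` declares an `hparams` argument, so loading proceeds with an empty `Namespace`. A hypothetical model illustrating what that fallback means in practice — construction succeeds, but every hyperparameter silently reverts to a default:

    from argparse import Namespace

    class LitModel:
        def __init__(self, hparams):
            # With an empty Namespace, every lookup falls back to its default value.
            self.learning_rate = getattr(hparams, 'learning_rate', 1e-3)

    model = LitModel(Namespace())  # what _load_model_state passes when 'hparams' is absent
    assert model.learning_rate == 1e-3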