Commit 5f4a5fe

carmocca authored and lexierule committed

Fix to_torchscript() causing false positive deprecation warnings (#10470)

1 parent 53ff840

File tree

2 files changed: +15 −5 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed sampler replacement logic with `overfit_batches` to only replace the sample when `SequentialSampler` is not used ([#10486](https://github.com/PyTorchLightning/pytorch-lightning/issues/10486))


+- Fixed `to_torchscript()` causing false positive deprecation warnings ([#10470](https://github.com/PyTorchLightning/pytorch-lightning/issues/10470))
+
+
 - Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))


pytorch_lightning/core/lightning.py

Lines changed: 12 additions & 5 deletions
@@ -116,6 +116,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._param_requires_grad_state = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
         self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
+        # TODO: remove after the 1.6 release
+        self._running_torchscript = False

         self._register_sharded_tensor_state_dict_hooks_if_available()

@@ -1962,6 +1964,8 @@ def to_torchscript(
         """
         mode = self.training

+        self._running_torchscript = True
+
         if method == "script":
             torchscript_module = torch.jit.script(self.eval(), **kwargs)
         elif method == "trace":

@@ -1987,6 +1991,8 @@ def to_torchscript(
         with fs.open(file_path, "wb") as f:
             torch.jit.save(torchscript_module, f)

+        self._running_torchscript = False
+
         return torchscript_module

     @property

@@ -1996,11 +2002,12 @@ def model_size(self) -> float:
         Note:
             This property will not return correct value for Deepspeed (stage 3) and fully-sharded training.
         """
-        rank_zero_deprecation(
-            "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
-            " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
-            stacklevel=5,
-        )
+        if not self._running_torchscript:  # remove with the deprecation removal
+            rank_zero_deprecation(
+                "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
+                " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.",
+                stacklevel=5,
+            )
        return get_model_size_mb(self)

     def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
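The fix above uses a flag-guard pattern: an internal boolean (`_running_torchscript`) is set while `to_torchscript()` runs, and the deprecated `model_size` property skips its warning when that flag is set, so only direct user access warns. The sketch below illustrates the same pattern with a hypothetical stdlib-only `Model` class (the names `_running_export`, `export`, and the returned value are illustrative, not PyTorch Lightning's API); it also wraps the flag reset in `try`/`finally` as a defensive choice, which the commit itself does not do.

```python
import warnings


class Model:
    """Hypothetical sketch of the flag-guard pattern from the commit."""

    def __init__(self):
        # Mirrors `_running_torchscript` in the diff: True only while an
        # internal caller is allowed to touch the deprecated property.
        self._running_export = False

    @property
    def model_size(self):
        # Warn only on direct user access, not during internal export.
        if not self._running_export:
            warnings.warn("`model_size` is deprecated", DeprecationWarning)
        return 42.0

    def export(self):
        # Internal caller: suppress the warning for the duration of the call.
        self._running_export = True
        try:
            return self.model_size  # silent: flag is set
        finally:
            self._running_export = False  # always restored, even on error


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    m = Model()
    m.export()             # internal access: no warning recorded
    assert len(caught) == 0
    m.model_size           # direct access: warning recorded
    assert len(caught) == 1
```

TorchScript compilation walks module attributes and properties, which is why scripting a model could trigger the deprecation warning even when the user never touched `model_size`; gating the warning on the flag removes that false positive.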
