Commit d17d4f4

Never pickle the Trainer with the LightningModule (#17133)
1 parent c575efb commit d17d4f4

File tree: 3 files changed, +4 -17 lines

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Pickling the `LightningModule` no longer pickles the `Trainer` ([#17133](https://github.com/Lightning-AI/lightning/pull/17133))
 
 
 ### Deprecated
@@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+
 
 
 ## [2.0.0] - 2023-03-15
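
To make the new CHANGELOG entry concrete, here is a minimal sketch of the behaviour after this commit. It is not part of the change itself; it assumes the BoringModel demo module shipped under lightning.pytorch.demos.boring_classes and a throwaway fast_dev_run Trainer.

import pickle

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(fast_dev_run=True)
trainer.fit(model)  # attaches the Trainer to the module

# __getstate__ now always drops the Trainer, so the pickled payload
# no longer drags the whole Trainer (and everything it holds) along.
restored = pickle.loads(pickle.dumps(model))
assert restored._trainer is None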

src/lightning/pytorch/callbacks/stochastic_weight_avg.py

Lines changed: 1 addition & 2 deletions
@@ -150,8 +150,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
             raise MisconfigurationException("SWA does not currently support sharded models.")
 
         # copy the model before moving it to accelerator device.
-        with pl_module._prevent_trainer_and_dataloaders_deepcopy():
-            self._average_model = deepcopy(pl_module)
+        self._average_model = deepcopy(pl_module)
 
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if len(trainer.optimizers) != 1:
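
The removed context manager is no longer needed: copy.deepcopy() routes through __reduce_ex__ and therefore through the module's __getstate__, which after this commit always drops the Trainer reference. A minimal sketch of that effect, again assuming the BoringModel demo helper and not part of the commit itself:

from copy import deepcopy

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
Trainer(fast_dev_run=True).fit(model)

# What SWA's setup() now does: a plain deepcopy. The copy comes out
# detached from the Trainer because __getstate__ nulls the reference.
average_model = deepcopy(model)
assert average_model._trainer is None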

src/lightning/pytorch/core/module.py

Lines changed: 1 addition & 13 deletions
@@ -124,7 +124,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._automatic_optimization: bool = True
         self._param_requires_grad_state: Dict[str, bool] = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
-        self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
         self._register_sharded_tensor_state_dict_hooks_if_available()
         self._compiler_ctx: Optional[Dict[str, Any]] = None
 
@@ -1539,20 +1538,9 @@ def load_from_checkpoint(
         )
         return cast(Self, loaded)
 
-    @contextmanager
-    def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
-        self._should_prevent_trainer_and_dataloaders_deepcopy = True
-        yield
-        self._should_prevent_trainer_and_dataloaders_deepcopy = False
-
     def __getstate__(self) -> Dict[str, Any]:
         state = dict(self.__dict__)
-        if self._should_prevent_trainer_and_dataloaders_deepcopy:
-            state["_trainer"] = None
-            state.pop("train_dataloader", None)
-            state.pop("val_dataloader", None)
-            state.pop("test_dataloader", None)
-            state.pop("predict_dataloader", None)
+        state["_trainer"] = None
         return state
 
     def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
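
One subtlety of the simplified __getstate__ above: it returns a copy of __dict__, so only the serialized state loses the Trainer while the live module stays attached. The following self-contained sketch shows the same pattern in plain Python with a hypothetical ToyModule stand-in (no Lightning required):

import pickle
from typing import Any, Dict


class ToyModule:
    """Hypothetical stand-in for a LightningModule, only to show the pattern."""

    def __init__(self) -> None:
        self._trainer: Any = object()  # stands in for an attached Trainer

    def __getstate__(self) -> Dict[str, Any]:
        state = dict(self.__dict__)  # copy, so the live instance is untouched
        state["_trainer"] = None     # only the pickled payload loses the reference
        return state


module = ToyModule()
restored = pickle.loads(pickle.dumps(module))
assert module._trainer is not None  # original stays attached
assert restored._trainer is None    # round-tripped copy comes back detached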
