Skip to content

Commit 98eb736

Browse files
authored
Added getstate/setstate method for torch.save serialization (#4127)
* Added getstate/setstate methods for torch.save serialization. * Added additional Optional typing to the results object. * Added tests to ensure torch.save does not fail. * Added flags to ensure a compatible ddp cpu environment. * Removed the torch version check, since the minimum supported version is already 1.3; reduced epochs for speed. * Moved tests to a separate file. * Updated the accelerator test to ddp_spawn to prevent ddp hanging.
1 parent 01402e3 commit 98eb736

File tree

5 files changed

+99
-1
lines changed

5 files changed

+99
-1
lines changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,20 @@ def init_ddp_connection(
214214
torch_backend, rank=global_rank, world_size=world_size
215215
)
216216

217+
def __getstate__(self):
218+
return {
219+
'trainer': self.trainer,
220+
'nickname': self.nickname,
221+
'cluster_environment': self.cluster_environment,
222+
'dist': self.dist
223+
}
224+
225+
def __setstate__(self, d):
226+
self.trainer = d['trainer']
227+
self.nickname = d['nickname']
228+
self.cluster_environment = d['cluster_environment']
229+
self.dist = d['dist']
230+
217231

218232
# TODO: allow user to compare with string even though internally we shall use these Enums to prevent typos...
219233
class BackendType(Enum):

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(self, *args, **kwargs):
107107
# optionally can be set by user
108108
self._example_input_array = None
109109
self._datamodule = None
110-
self._results: Result = None
110+
self._results: Optional[Result] = None
111111
self._current_fx_name = ''
112112

113113
def optimizers(self):

pytorch_lightning/core/step_result.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
2424
from pytorch_lightning.metrics import Metric
2525

26+
2627
class Result(Dict):
2728
def __init__(
2829
self,
@@ -89,6 +90,12 @@ def __setattr__(self, key: str, val: Union[Tensor, Any]):
8990

9091
self[key] = val
9192

93+
def __getstate__(self):
94+
return self
95+
96+
def __setstate__(self, d):
97+
self.update(d)
98+
9299
def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
93100
if potential_metric is not None and not isinstance(potential_metric, bool):
94101
assert isinstance(potential_metric, Tensor), f'{name} must be a torch.Tensor'

tests/checkpointing/test_model_checkpoint.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from distutils.version import LooseVersion
1516
from unittest.mock import MagicMock, Mock
1617

1718
import yaml
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
import platform
16+
17+
import pytest
18+
import torch
19+
20+
from pytorch_lightning import Trainer
21+
from tests.base import EvalModelTemplate
22+
23+
24+
def test_model_torch_save(tmpdir):
25+
"""Test to ensure torch save does not fail for model and trainer."""
26+
model = EvalModelTemplate()
27+
num_epochs = 1
28+
trainer = Trainer(
29+
default_root_dir=tmpdir,
30+
max_epochs=num_epochs,
31+
)
32+
temp_path = os.path.join(tmpdir, 'temp.pt')
33+
trainer.fit(model)
34+
35+
# Ensure these do not fail
36+
torch.save(trainer.model, temp_path)
37+
torch.save(trainer, temp_path)
38+
39+
40+
@pytest.mark.skipif(platform.system() == "Windows",
41+
reason="Distributed training is not supported on Windows")
42+
def test_model_torch_save_ddp_cpu(tmpdir):
43+
"""Test to ensure torch save does not fail for model and trainer using cpu ddp."""
44+
model = EvalModelTemplate()
45+
num_epochs = 1
46+
trainer = Trainer(
47+
default_root_dir=tmpdir,
48+
max_epochs=num_epochs,
49+
accelerator="ddp_cpu",
50+
num_processes=2,
51+
)
52+
temp_path = os.path.join(tmpdir, 'temp.pt')
53+
trainer.fit(model)
54+
55+
# Ensure these do not fail
56+
torch.save(trainer.model, temp_path)
57+
torch.save(trainer, temp_path)
58+
59+
60+
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
61+
def test_model_torch_save_ddp_cuda(tmpdir):
62+
"""Test to ensure torch save does not fail for model and trainer using gpu ddp."""
63+
model = EvalModelTemplate()
64+
num_epochs = 1
65+
trainer = Trainer(
66+
default_root_dir=tmpdir,
67+
max_epochs=num_epochs,
68+
accelerator="ddp_spawn",
69+
gpus=2
70+
)
71+
temp_path = os.path.join(tmpdir, 'temp.pt')
72+
trainer.fit(model)
73+
74+
# Ensure these do not fail
75+
torch.save(trainer.model, temp_path)
76+
torch.save(trainer, temp_path)

0 commit comments

Comments
 (0)