
Commit 4ea72a9

awaelchli, kaushikb11, and tchaton authored
Update setup logic in training type plugins (sharded) [4 / 4] (#10028)
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
1 parent 84706a2 commit 4ea72a9

File tree: 3 files changed, 82 insertions(+), 34 deletions(-)

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -213,6 +213,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
 * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
 * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
+* Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
 * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))

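For orientation before the file diffs: both sharded plugins now expose the same setup hook. The sketch below is a self-contained, hypothetical stand-in (class name and pass-through body are invented) that only illustrates the call shape introduced by the diffs that follow; the real implementations wrap the model and optimizers with fairscale components.

    from typing import List, Tuple

    from torch.nn import Module
    from torch.optim import Optimizer


    class ShardedSetupSketch:
        """Hypothetical stand-in showing the hook shape used by both sharded plugins."""

        def _setup_models_and_optimizers(
            self, models: List[Module], optimizers: List[Optimizer]
        ) -> Tuple[List[Module], List[Optimizer]]:
            # The real plugins wrap models[0] in ShardedDataParallel and each
            # optimizer in fairscale's OSS; this sketch simply passes them through.
            return models, optimizers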
pytorch_lightning/plugins/training_type/sharded.py

Lines changed: 44 additions & 19 deletions
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple, Union

 import torch
+from torch.nn import Module
+from torch.optim import Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -33,47 +35,70 @@
 class DDPShardedPlugin(DDPPlugin):
     """Optimizer and gradient sharded training provided by FairScale."""

-    _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
+    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M

-    def configure_ddp(self) -> None:
-        self._wrap_optimizers()
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._precision = None

+    def configure_ddp(self) -> None:
+        trainer = self.lightning_module.trainer
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()
+
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.

-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+        Currently only one model can be setup at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
+        """
+        if len(models) > 1:
+            raise ValueError(
+                "DDPSharded only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
+    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    precision = self.lightning_module.trainer.precision
+                    precision = self._precision or self.lightning_module.trainer.precision
                     is_fp16 = precision in ("mixed", 16)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
                     # the model performance.
                     zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        return optimizers
+
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers

-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+        return self._reinit_optimizers_with_oss(optimizers)

     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):

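A minimal sketch, outside of Lightning, of the wrapping that the new _setup_models_and_optimizers performs, written against raw fairscale APIs. It assumes fairscale is installed; the single-process "gloo" process group, the toy Linear model, and the SGD optimizer are placeholders for illustration only.

    import torch
    import torch.distributed as dist
    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim import OSS

    # OSS and ShardedDataParallel require an initialized process group; a
    # single-process group is enough for this illustration.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    model = torch.nn.Linear(32, 2)
    base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Mirror _reinit_optimizers_with_oss: rebuild the optimizer as a sharded OSS
    # optimizer with the same optimizer class, param groups, and defaults.
    sharded_optimizer = OSS(
        params=base_optimizer.param_groups,
        optim=type(base_optimizer),
        **base_optimizer.defaults,
    )

    # Mirror the ShardedDataParallel call in the plugin: wrap the model so gradient
    # reduction is coordinated with the sharded optimizer state.
    sharded_model = ShardedDataParallel(model, sharded_optimizer=[sharded_optimizer])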
pytorch_lightning/plugins/training_type/sharded_spawn.py

Lines changed: 37 additions & 15 deletions
@@ -13,9 +13,11 @@
 # limitations under the License.
 from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple

 import torch
+from torch.nn import Module
+from torch.optim import Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
@@ -36,29 +38,49 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""

     def configure_ddp(self) -> None:
-        self._wrap_optimizers()
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        trainer = self.lightning_module.trainer
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer.optimizers = optimizers
+
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
+        """Wraps the model and optimizers with fairscale components.

-    def _reinit_optimizers_with_oss(self):
-        optimizers = self.lightning_module.trainer.optimizers
+        Currently only one model can be setup at once.
+
+        Return:
+            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
+        """
+        if len(models) > 1:
+            raise ValueError(
+                f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
+    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
+        return optimizers
+
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers

-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+        return self._reinit_optimizers_with_oss(optimizers)

     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, OSS):

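A hedged usage sketch for selecting these plugins from the Trainer. The aliases match the plugin names in this commit, but the exact entry point (the strategy argument versus the older plugins argument) depends on the Lightning version in use, and MyLightningModule is a hypothetical user model, not something defined in this repository.

    import pytorch_lightning as pl

    # "ddp_sharded" selects DDPShardedPlugin, "ddp_sharded_spawn" selects
    # DDPSpawnShardedPlugin. Older releases accepted the same aliases via
    # the `plugins=` argument instead of `strategy=`.
    trainer = pl.Trainer(
        strategy="ddp_sharded",
        gpus=2,
    )
    # trainer.fit(MyLightningModule())  # MyLightningModule is hypothetical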