@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import torch
+from torch.nn import Module
+from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -33,47 +35,70 @@
 class DDPShardedPlugin(DDPPlugin):
     """Optimizer and gradient sharded training provided by FairScale."""
 
-    _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
+    _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23  # 8M
 
-    def configure_ddp(self) -> None:
-        self._wrap_optimizers()
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._precision = None
 
+    def configure_ddp(self) -> None:
+        trainer = self.lightning_module.trainer
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
-        self._model = ShardedDataParallel(
-            LightningShardedDataParallel(self.model),
-            sharded_optimizer=self.lightning_module.trainer.optimizers,
-            **self._ddp_kwargs
+        [self._model], optimizers = self._setup_models_and_optimizers(
+            models=[LightningShardedDataParallel(self.model)],
+            optimizers=trainer.optimizers,
         )
-        setattr(self._model, "require_backward_grad_sync", False)
+        trainer.optimizers = optimizers
+        trainer.convert_to_lightning_optimizers()
+
+    def _setup_models_and_optimizers(
+        self, models: List[Module], optimizers: List[Optimizer]
+    ) -> Tuple[List[Module], List[Optimizer]]:
| 60 | + """Wraps the model and optimizers with fairscale components. |
51 | 61 |
|
52 | | - def _reinit_optimizers_with_oss(self): |
53 | | - optimizers = self.lightning_module.trainer.optimizers |
| 62 | + Currently only one model can be setup at once. |
| 63 | +
|
| 64 | + Return: |
| 65 | + A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module |
| 66 | + and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`. |
| 67 | + """ |
+        if len(models) > 1:
+            raise ValueError(
+                "DDPSharded only supports setting up a single model with one or several optimizers."
+                f" Got {len(models)} models."
+            )
+
+        optimizers = self._wrap_optimizers(optimizers)
+        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
+        setattr(model, "require_backward_grad_sync", False)  # TODO: needed?
+        return [model], optimizers
+
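For reference, the fairscale pairing that `_setup_models_and_optimizers` produces can be sketched in isolation. This is a minimal, hypothetical example rather than Lightning code: it assumes a `torch.distributed` process group is already initialized (e.g. via `torchrun`), and uses a toy `nn.Linear` model with Adam as stand-ins:

    import torch
    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim import OSS

    # Assumes torch.distributed.init_process_group(...) has already run on every rank.
    model = torch.nn.Linear(32, 4)

    # OSS shards optimizer state across ranks; it wraps a regular optimizer class.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.Adam, lr=1e-3)

    # ShardedDataParallel ties the model to its sharded optimizer(s) and reduces
    # each gradient directly to the rank that owns the corresponding shard.
    model = ShardedDataParallel(model, sharded_optimizer=optimizer, reduce_buffer_size=0)

`reduce_buffer_size=0` mirrors the single-node default chosen in `configure_ddp`; multi-node runs use the bucketed 8M buffer instead.
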
+    def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
             if not isinstance(optimizer, OSS):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    precision = self.lightning_module.trainer.precision
+                    precision = self._precision or self.lightning_module.trainer.precision
                     is_fp16 = precision in ("mixed", 16)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
                     # the model performance.
                     zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
                 optimizers[x] = zero_optimizer
                 del optimizer
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = optimizers
-        trainer.convert_to_lightning_optimizers()
+        return optimizers
+
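The re-wrapping performed by `_reinit_optimizers_with_oss` can be illustrated on its own. A minimal sketch under the same assumption of an initialized process group; `broadcast_fp16` is only honored by the fairscale versions covered by the `_FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE` check:

    import torch
    from fairscale.optim import OSS

    model = torch.nn.Linear(32, 4)
    base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Rebuild the optimizer as OSS, reusing the original param groups and defaults,
    # which is exactly what the loop above does for every non-OSS optimizer.
    sharded = OSS(params=base.param_groups, optim=type(base), **base.defaults)

    # Compress shard broadcasts to fp16; only worthwhile across nodes, and per the
    # comment above it does not degrade model quality when PyTorch AMP is used.
    sharded.broadcast_fp16 = True
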
+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
 
-    def _wrap_optimizers(self):
-        if self.model.trainer.state.fn != TrainerFn.FITTING:
-            return
-        self._reinit_optimizers_with_oss()
+        return self._reinit_optimizers_with_oss(optimizers)
 
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):
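Finally, a usage sketch for the plugin as a whole. Not a definitive recipe: `MyLightningModule` is a placeholder for any `LightningModule`, and the exact Trainer arguments vary by Lightning version (newer releases select sharded DDP via the `strategy` argument):

    import pytorch_lightning as pl

    # MyLightningModule is a placeholder LightningModule defined elsewhere.
    trainer = pl.Trainer(
        gpus=2,
        num_nodes=1,
        precision=16,           # with num_nodes > 1 this also enables fp16 shard broadcasts
        plugins="ddp_sharded",  # resolves to DDPShardedPlugin
    )
    trainer.fit(MyLightningModule())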