
Commit a5f82f5

Fix initialization of optimizers in DDP Strategy (#11952)

ananthsub authored and committed
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>

1 parent f89b181 · commit a5f82f5

File tree: 11 files changed (+151, −48 lines)

CHANGELOG.md
Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed torchelastic detection with non-distributed installations ([#13142](https://github.com/PyTorchLightning/pytorch-lightning/pull/13142))
 - Fixed logging's step values when multiple dataloaders are used during evaluation ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))
 - Fixed epoch logging on train epoch end ([#13025](https://github.com/PyTorchLightning/pytorch-lightning/pull/13025))
+- Fixed `DDPStrategy` and `DDPSpawnStrategy` to initialize optimizers only after moving the module to the device ([#11952](https://github.com/PyTorchLightning/pytorch-lightning/pull/11952))


 ## [1.6.3] - 2022-05-03
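
The core of the change is an ordering constraint: a strategy should move the module to its target device before it asks the LightningModule for its optimizers, so the optimizers are built against device-resident parameters from the start. A minimal, standalone PyTorch sketch of that ordering (an illustration only, not Lightning code):

    import torch
    from torch import nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. place the module on its target device first
    model = nn.Linear(8, 8).to(device)

    # 2. only then build the optimizer, so its param_groups reference the
    #    device-resident parameters from the start
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)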

pytorch_lightning/strategies/bagua.py
Lines changed: 31 additions & 3 deletions

@@ -16,9 +16,11 @@
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _BAGUA_AVAILABLE
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.seed import reset_seed

 if _BAGUA_AVAILABLE:
@@ -152,6 +154,33 @@ def _set_node_environment_variables(self) -> None:
         os.environ["WORLD_SIZE"] = str(self.world_size)
         os.environ["LOCAL_RANK"] = str(self.local_rank)

+    def setup(self, trainer: "pl.Trainer") -> None:
+        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        if self._should_run_deadlock_detection():
+            self._share_information_to_prevent_deadlock()
+
+        self.accelerator.setup(trainer)
+
+        # move the model to the correct device
+        self.model_to_device()
+
+        trainer_fn = trainer.state.fn
+
+        if trainer_fn == TrainerFn.FITTING:
+            if self._layer_sync and self.model:
+                self.model = self._layer_sync.apply(self.model)
+
+        self.setup_precision_plugin()
+
+        if trainer_fn == TrainerFn.FITTING:
+            # set up optimizers after the module has been moved to the device
+            # but before the module has been wrapped
+            self.setup_optimizers(trainer)
+            optimizers_to_device(self.optimizers, self.root_device)
+
+            # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+            self._configure_bagua_model(trainer)
+
     def _check_qadam_optimizer(self) -> None:
         has_qadam_optimizer = any([isinstance(opt, QAdamOptimizer) for opt in self.optimizers])
@@ -160,13 +189,12 @@ def _check_qadam_optimizer(self) -> None:

         self._bagua_kwargs["q_adam_optimizer"] = self.optimizers[0]

-    def configure_ddp(self) -> None:
+    def _configure_bagua_model(self, trainer: "pl.Trainer") -> None:
         model = LightningBaguaModule(self.model)  # type: ignore[arg-type]
         self._model = self._setup_model(model)

         # start the background communication for async algorithm
-        assert self.lightning_module.trainer is not None
-        if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
+        if trainer.training and self._bagua_algorithm == "async":
             self.model.bagua_algorithm.resume(self.model)  # type: ignore

     def _setup_model(self, model: Module) -> BaguaDistributedDataParallel:
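
In the new `setup`, `optimizers_to_device` runs right after `setup_optimizers` so that any optimizer state created or restored earlier (for example from a checkpoint) ends up on the same device as the parameters. A simplified sketch of what such a helper has to do (an illustration, not the Lightning utility itself):

    from typing import Iterable

    import torch


    def move_optimizer_state(optimizers: Iterable[torch.optim.Optimizer], device: torch.device) -> None:
        """Move every tensor held in each optimizer's state dict to ``device``."""
        for optimizer in optimizers:
            for state in optimizer.state.values():
                for name, value in state.items():
                    if isinstance(value, torch.Tensor):
                        state[name] = value.to(device)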

pytorch_lightning/strategies/ddp.py
Lines changed: 20 additions & 12 deletions

@@ -56,6 +56,7 @@
     _TORCH_GREATER_EQUAL_1_10,
     _TORCH_GREATER_EQUAL_1_11,
 )
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -152,24 +153,37 @@ def setup_environment(self) -> None:
         super().setup_environment()

     def setup(self, trainer: "pl.Trainer") -> None:
-        super().setup(trainer)
         # share ddp pids to all processes
         self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
         if self._should_run_deadlock_detection():
             self._share_information_to_prevent_deadlock()

+        self.accelerator.setup(trainer)
+
         # move the model to the correct device
         self.model_to_device()

         # skip wrapping the model if we are not fitting as no gradients need to be exchanged
         trainer_fn = trainer.state.fn
-        if trainer_fn != TrainerFn.FITTING:
-            return

-        if self._layer_sync:
-            self.model = self._layer_sync.apply(self.model)
+        if trainer_fn == TrainerFn.FITTING:
+            if self._layer_sync:
+                self.model = self._layer_sync.apply(self.model)
+
+        self.setup_precision_plugin()
+
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()

-        self.configure_ddp()
+            # set up optimizers after the wrapped module has been moved to the device
+            self.setup_optimizers(trainer)
+            optimizers_to_device(self.optimizers, self.root_device)
+
+            if _TORCH_GREATER_EQUAL_1_10 and trainer_fn == TrainerFn.FITTING:
+                import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+
+                if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
+                    self._enable_model_averaging()

     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
@@ -223,12 +237,6 @@ def _register_ddp_hooks(self) -> None:
             ddp_comm_wrapper=self._ddp_comm_wrapper,
         )

-        if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
-            import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
-
-            if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
-                self._enable_model_averaging()
-
     def _enable_model_averaging(self) -> None:
         # Only called when PyTorch version >= 1.10
         log.detail(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")

pytorch_lightning/strategies/ddp_spawn.py
Lines changed: 15 additions & 8 deletions

@@ -42,6 +42,7 @@
     sync_ddp_if_available,
 )
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -122,20 +123,22 @@ def _configure_launcher(self):

     def setup(self, trainer: "pl.Trainer") -> None:
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
-        super().setup(trainer)
+
+        self.accelerator.setup(trainer)

         # move the model to the correct device
         self.model_to_device()

-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn != TrainerFn.FITTING:
-            return
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            if self._layer_sync:
+                self.model = self._layer_sync.apply(self.model)

-        if self._layer_sync:
-            self.model = self._layer_sync.apply(self.model)
+        self.setup_precision_plugin()

-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        self.configure_ddp()
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()

     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
@@ -186,6 +189,10 @@ def configure_ddp(self) -> None:
         self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()

+        # set up optimizers after the wrapped module has been moved to the device
+        self.setup_optimizers(self.lightning_module.trainer)
+        optimizers_to_device(self.optimizers, self.root_device)
+
     def determine_ddp_device_ids(self):
         if self.root_device.type == "cpu":
             return None
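
From the user's perspective, the spawn strategy now calls `configure_optimizers` only for `fit`, after the module has been placed on its device and wrapped. An illustrative toy run, not taken from the PR (`ToyModel` and the random dataset are made up; assumes a Lightning version with the `strategy`/`accelerator`/`devices` Trainer arguments):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class ToyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self.layer(x).sum()

        def configure_optimizers(self):
            # with this fix, the parameters are already on self.device here
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
        trainer = pl.Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2, max_epochs=1)
        trainer.fit(ToyModel(), train_dataloaders=data)
        # trainer.validate(...) would now skip optimizer setup and DDP wrapping entirely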

pytorch_lightning/strategies/fully_sharded.py
Lines changed: 10 additions & 5 deletions

@@ -139,14 +139,16 @@ def setup_distributed(self) -> None:
     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)

-        if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
-            self.model = self._layer_sync.apply(self.model)
+        if trainer.state.fn == TrainerFn.FITTING:
+            self.setup_optimizers(trainer)
+            optimizers_to_device(self.optimizers, self.root_device)

+            if self._layer_sync:
+                self.model = self._layer_sync.apply(self.model)
+
+        self.setup_precision_plugin()
         self.configure_ddp()
         self.barrier()
-        self.setup_optimizers(trainer)
-        optimizers_to_device(self.optimizers, self.root_device)
-        self.setup_precision_plugin()

     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
@@ -183,6 +185,9 @@ def configure_ddp(self) -> None:
         # (TODO: need to figure out solution)
         self.model_to_device()

+        # setup optimizers after fully sharded has wrapped the lightning module
+        self.setup_optimizers(self.lightning_module.trainer)
+
     def model_to_device(self) -> None:
         log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
         # ensure we update the device type in the lightning module
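
For the fully sharded strategy the ordering matters for an additional reason: wrapping can replace or flatten parameters, and a torch optimizer keeps references to the parameter objects it was constructed with, not to the module. A small plain-PyTorch illustration of that pitfall, simulating a wrapper that re-creates a parameter:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # simulate a wrapper that re-creates the parameter (as flattening/sharding wrappers do)
    model.weight = nn.Parameter(model.weight.detach().clone())

    stale = optimizer.param_groups[0]["params"][0]
    print(stale is model.weight)  # False: the optimizer would keep updating the orphaned tensor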

pytorch_lightning/strategies/sharded.py
Lines changed: 38 additions & 12 deletions

@@ -25,6 +25,7 @@
 from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_only

 if _FAIRSCALE_AVAILABLE:
@@ -40,16 +41,41 @@ class DDPShardedStrategy(DDPStrategy):
     strategy_name = "ddp_sharded"
     _REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23  # 8M

-    def configure_ddp(self) -> None:
-        trainer = self.lightning_module.trainer
-        if "reduce_buffer_size" not in self._ddp_kwargs:
-            # For multi-node training, enabling bucketing will improve performance.
-            self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
+    def setup(self, trainer: "pl.Trainer") -> None:
+        # share ddp pids to all processes
+        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
+        if self._should_run_deadlock_detection():
+            self._share_information_to_prevent_deadlock()
+
+        self.accelerator.setup(trainer)
+
+        # move the model to the correct device
+        self.model_to_device()
+
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = trainer.state.fn
+        if trainer_fn == TrainerFn.FITTING:
+            if self._layer_sync:
+                self.model = self._layer_sync.apply(self.model)

+        self.setup_precision_plugin()
+
+        if trainer_fn == TrainerFn.FITTING:
+            self.configure_ddp()
+
+    def configure_ddp(self) -> None:
+        self._set_ddp_kwargs()
+        self.setup_optimizers(self.model.trainer)
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model),
-            optimizers=trainer.optimizers,
+            optimizers=self.optimizers,
         )
+        optimizers_to_device(self.optimizers, self.root_device)
+
+    def _set_ddp_kwargs(self) -> None:
+        if "reduce_buffer_size" not in self._ddp_kwargs:
+            # For multi-node training, enabling bucketing will improve performance.
+            self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.
@@ -62,6 +88,12 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
         model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
         return model, optimizers

+    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
+        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
+            return optimizers
+
+        return self._reinit_optimizers_with_oss(optimizers)
+
     def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
@@ -79,12 +111,6 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin
                 del optimizer
         return optimizers

-    def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
-            return optimizers
-
-        return self._reinit_optimizers_with_oss(optimizers)
-
     def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
         if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer
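
`DDPShardedStrategy.configure_ddp` now creates the base optimizers itself and immediately re-wraps them for sharding. Outside of Lightning, the fairscale pattern it builds on looks roughly like this single-process sketch (assuming fairscale is installed; the constructor arguments shown are the commonly documented ones, and the process-group wiring is an assumption for a local run):

    import os

    import torch
    import torch.distributed as dist
    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim import OSS
    from torch import nn

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = nn.Linear(16, 4)
    # OSS shards optimizer state across ranks; it wraps a regular torch optimizer class
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
    # reduce_buffer_size=0 disables bucketing (the strategy's single-node default)
    model = ShardedDataParallel(model, sharded_optimizer=optimizer, reduce_buffer_size=0)

    dist.destroy_process_group()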

pytorch_lightning/strategies/sharded_spawn.py
Lines changed: 4 additions & 0 deletions

@@ -23,6 +23,7 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
+from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_only

 if _FAIRSCALE_AVAILABLE:
@@ -38,9 +39,12 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy):
     strategy_name = "ddp_sharded_spawn"

     def configure_ddp(self) -> None:
+        # set up optimizers after the wrapped module has been moved to the device
+        self.setup_optimizers(self.lightning_module.trainer)
         self.model, self.optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model), optimizers=self.optimizers
         )
+        optimizers_to_device(self.optimizers, self.root_device)

     def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.

pytorch_lightning/strategies/tpu_spawn.py
Lines changed: 6 additions & 5 deletions

@@ -26,6 +26,7 @@
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.strategies.launchers.xla_spawn import _XLASpawnLauncher
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
+from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import ReduceOp
@@ -126,9 +127,6 @@ def _configure_launcher(self):
     def setup(self, trainer: "pl.Trainer") -> None:
         self.start_method = "fork"
         self.accelerator.setup(trainer)
-        self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
-        optimizers_to_device(self.optimizers, self.root_device)

         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)
@@ -140,8 +138,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
         else:
             set_shared_parameters(self.model.module, shared_params)

-        self.setup_optimizers(trainer)
-        self.precision_plugin.connect(self.model, None, None)
+        self.setup_precision_plugin()
+
+        if trainer.state.fn == TrainerFn.FITTING:
+            self.setup_optimizers(trainer)
+            optimizers_to_device(self.optimizers, self.root_device)

     def _setup_model(self, model: Module) -> Module:
         return model

tests/strategies/test_bagua_strategy.py
Lines changed: 3 additions & 3 deletions

@@ -85,9 +85,9 @@ def test_configuration(algorithm, tmpdir):
     ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
         if algorithm == "qadam":
             with pytest.raises(MisconfigurationException, match="Bagua QAdam can only accept one QAdamOptimizer"):
-                trainer.strategy.configure_ddp()
+                trainer.strategy._configure_bagua_model(trainer)
         else:
-            trainer.strategy.configure_ddp()
+            trainer.strategy._configure_bagua_model(trainer)


 @RunIf(min_gpus=1, bagua=True)
@@ -109,7 +109,7 @@ def test_qadam_configuration(tmpdir):
     with mock.patch(
         "bagua.torch_api.data_parallel.bagua_distributed.BaguaDistributedDataParallel.__init__", return_value=None
     ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
-        trainer.strategy.configure_ddp()
+        trainer.strategy._configure_bagua_model(trainer)


 def test_bagua_not_available(monkeypatch):
