diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py
index 7ce7682d82c76..941bc7241c961 100644
--- a/examples/pytorch/domain_templates/generative_adversarial_net.py
+++ b/examples/pytorch/domain_templates/generative_adversarial_net.py
@@ -29,6 +29,7 @@
 from lightning.pytorch import cli_lightning_logo
 from lightning.pytorch.core import LightningModule
 from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
+from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
 from lightning.pytorch.trainer import Trainer
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 
@@ -36,6 +37,14 @@
     import torchvision
 
 
+def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list:
+    layers = [nn.Linear(in_feat, out_feat)]
+    if normalize:
+        layers.append(nn.BatchNorm1d(out_feat, 0.8))
+    layers.append(nn.LeakyReLU(0.2, inplace=True))
+    return layers
+
+
 class Generator(nn.Module):
     """
     >>> Generator(img_shape=(1, 8, 8))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -47,19 +56,11 @@ class Generator(nn.Module):
     def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
         super().__init__()
         self.img_shape = img_shape
-
-        def block(in_feat, out_feat, normalize=True):
-            layers = [nn.Linear(in_feat, out_feat)]
-            if normalize:
-                layers.append(nn.BatchNorm1d(out_feat, 0.8))
-            layers.append(nn.LeakyReLU(0.2, inplace=True))
-            return layers
-
         self.model = nn.Sequential(
-            *block(latent_dim, 128, normalize=False),
-            *block(128, 256),
-            *block(256, 512),
-            *block(512, 1024),
+            *_block(latent_dim, 128, normalize=False),
+            *_block(128, 256),
+            *_block(256, 512),
+            *_block(512, 1024),
             nn.Linear(1024, int(math.prod(img_shape))),
             nn.Tanh(),
         )
@@ -209,10 +210,18 @@ def main(args: Namespace) -> None:
     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
-    # If use distributed training PyTorch recommends to use DistributedDataParallel.
-    # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
     dm = MNISTDataModule()
-    trainer = Trainer(accelerator="gpu", devices=1)
+
+    if args.use_ddp:
+        # `MultiModelDDPStrategy` is what makes multi-GPU GAN training work here.
+        # With the stock `DDPStrategy`, this example would only run if you either
+        # 1) enabled the `find_unused_parameters` option, or
+        # 2) replaced `self.manual_backward(loss)` with `loss.backward()`.
+        # Neither workaround is desirable.
+        trainer = Trainer(accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
+    else:
+        # On a single GPU, the default strategy is sufficient.
+        trainer = Trainer(accelerator="gpu", devices=1)
 
     # ------------------------
     # 3 START TRAINING
@@ -229,6 +238,8 @@ def main(args: Namespace) -> None:
     parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
     parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
     parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
+    parser.add_argument("--use_ddp", action="store_true", help="use MultiModelDDPStrategy for multi-GPU training")
+
     args = parser.parse_args()
 
     main(args)
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index fd3f66ef42471..d7e9771dad6db 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -214,7 +214,7 @@ def set_world_ranks(self) -> None:
         rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     def _register_ddp_hooks(self) -> None:
-        log.debug(f"{self.__class__.__name__}: registering ddp hooks")
+        log.debug(f"{self.__class__.__name__}: registering DDP hooks")
         # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
         # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
         if self.root_device.type == "cuda":
@@ -419,6 +419,46 @@ def teardown(self) -> None:
         super().teardown()
 
 
+class MultiModelDDPStrategy(DDPStrategy):
+    """Strategy for training several models, each with its own optimizer, e.g. GAN training.
+
+    Each child module of the LightningModule is wrapped in its own
+    :class:`~torch.nn.parallel.distributed.DistributedDataParallel` instance, so that ``manual_backward`` on one
+    model's loss only synchronizes gradients and updates parameters of that model.
+    """
+
+    @override
+    def _setup_model(self, model: Module) -> DistributedDataParallel:
+        device_ids = self.determine_ddp_device_ids()
+        log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
+        # https://pytorch.org/docs/stable/notes/cuda.html#id5
+        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        with ctx:
+            for name, module in model.named_children():
+                if isinstance(module, Module):
+                    ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs)
+                    setattr(model, name, ddp_module)
+        return model
+
+    @override
+    def _register_ddp_hooks(self) -> None:
+        log.debug(f"{self.__class__.__name__}: registering DDP hooks")
+        # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
+        # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
+        if self.root_device.type != "cuda":
+            return
+        assert isinstance(self.model, Module)
+
+        for name, module in self.model.named_children():
+            assert isinstance(module, DistributedDataParallel)
+            _register_ddp_comm_hook(
+                model=module,
+                ddp_comm_state=self._ddp_comm_state,
+                ddp_comm_hook=self._ddp_comm_hook,
+                ddp_comm_wrapper=self._ddp_comm_wrapper,
+            )
+
+
 class _DDPForwardRedirection(_ForwardRedirection):
     @override
     def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
diff --git a/tests/tests_pytorch/strategies/test_multi_model_ddp.py b/tests/tests_pytorch/strategies/test_multi_model_ddp.py
new file mode 100644
index 0000000000000..053f2a3312bf8
--- /dev/null
+++ b/tests/tests_pytorch/strategies/test_multi_model_ddp.py
@@ -0,0 +1,97 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+from unittest.mock import PropertyMock
+
+import torch
+from torch import nn
+
+from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
+
+
+def test_multi_model_ddp_setup_and_register_hooks():
+    class Parent(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.gen = nn.Linear(1, 1)
+            self.dis = nn.Linear(1, 1)
+
+    model = Parent()
+    original_children = [model.gen, model.dis]
+
+    strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])
+
+    wrapped_modules = []
+    wrapped_device_ids = []
+
+    class DummyDDP(nn.Module):
+        def __init__(self, module: nn.Module, device_ids=None, **kwargs):
+            super().__init__()
+            self.module = module
+            wrapped_modules.append(module)
+            wrapped_device_ids.append(device_ids)
+
+    with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
+        returned_model = strategy._setup_model(model)
+        assert returned_model is model
+        assert isinstance(model.gen, DummyDDP)
+        assert isinstance(model.dis, DummyDDP)
+        assert wrapped_modules == original_children
+        assert wrapped_device_ids == [None, None]
+
+    strategy.model = model
+    with (
+        mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook,
+        mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device,
+    ):
+        root_device.return_value = torch.device("cuda", 0)
+        strategy._register_ddp_hooks()
+
+    assert register_hook.call_count == 2
+    register_hook.assert_any_call(
+        model=model.gen,
+        ddp_comm_state=strategy._ddp_comm_state,
+        ddp_comm_hook=strategy._ddp_comm_hook,
+        ddp_comm_wrapper=strategy._ddp_comm_wrapper,
+    )
+    register_hook.assert_any_call(
+        model=model.dis,
+        ddp_comm_state=strategy._ddp_comm_state,
+        ddp_comm_hook=strategy._ddp_comm_hook,
+        ddp_comm_wrapper=strategy._ddp_comm_wrapper,
+    )
+
+
+def test_multi_model_ddp_register_hooks_cpu_noop():
+    class Parent(nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.gen = nn.Linear(1, 1)
+            self.dis = nn.Linear(1, 1)
+
+    model = Parent()
+    strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])
+
+    class DummyDDP(nn.Module):
+        def __init__(self, module: nn.Module, device_ids=None, **kwargs):
+            super().__init__()
+            self.module = module
+
+    with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
+        strategy.model = strategy._setup_model(model)
+
+    with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook:
+        strategy._register_ddp_hooks()
+
+    register_hook.assert_not_called()
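Usage sketch (illustrative only, not part of the patch): a minimal manual-optimization LightningModule with separate generator and discriminator optimizers, driven through `MultiModelDDPStrategy`. `TinyGAN` and its losses are simplified stand-ins for the example script's `GAN` class.

import torch
import torch.nn.functional as F
from torch import nn

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy


class TinyGAN(LightningModule):
    # simplified stand-in for the example's GAN module: two child models, two optimizers, manual optimization
    def __init__(self, latent_dim: int = 100):
        super().__init__()
        self.automatic_optimization = False
        self.latent_dim = latent_dim
        self.generator = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Tanh())
        self.discriminator = nn.Sequential(nn.Linear(28 * 28, 1), nn.Sigmoid())

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        imgs = imgs.view(imgs.size(0), -1)
        opt_g, opt_d = self.optimizers()
        z = torch.randn(imgs.size(0), self.latent_dim, device=self.device)
        valid = torch.ones(imgs.size(0), 1, device=self.device)
        fake = torch.zeros(imgs.size(0), 1, device=self.device)

        # generator update: try to fool the discriminator; only opt_g steps here
        g_loss = F.binary_cross_entropy(self.discriminator(self.generator(z)), valid)
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # discriminator update: separate real and generated samples; only opt_d steps here
        real_loss = F.binary_cross_entropy(self.discriminator(imgs), valid)
        fake_loss = F.binary_cross_entropy(self.discriminator(self.generator(z).detach()), fake)
        opt_d.zero_grad()
        self.manual_backward(real_loss + fake_loss)
        opt_d.step()

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return opt_g, opt_d


if __name__ == "__main__":
    # MultiModelDDPStrategy wraps each child module (generator, discriminator) in its own
    # DistributedDataParallel instance, so the two manual backward passes above run without
    # `find_unused_parameters=True` and without replacing `manual_backward` by a bare `loss.backward()`.
    trainer = Trainer(accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
    trainer.fit(TinyGAN(), MNISTDataModule())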