Fix wrong behavior of DDPStrategy option with simple GAN training using DDP #20936

Open
wants to merge 13 commits into base: master
41 changes: 26 additions & 15 deletions examples/pytorch/domain_templates/generative_adversarial_net.py
@@ -29,13 +29,22 @@
from lightning.pytorch import cli_lightning_logo
from lightning.pytorch.core import LightningModule
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision


def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list:
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers


class Generator(nn.Module):
"""
>>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -47,19 +56,11 @@ class Generator(nn.Module):
def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
super().__init__()
self.img_shape = img_shape

def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
*_block(latent_dim, 128, normalize=False),
*_block(128, 256),
*_block(256, 512),
*_block(512, 1024),
nn.Linear(1024, int(math.prod(img_shape))),
nn.Tanh(),
)
@@ -209,10 +210,18 @@ def main(args: Namespace) -> None:
# ------------------------
# 2 INIT TRAINER
# ------------------------
# When using distributed training, PyTorch recommends DistributedDataParallel.
# See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
dm = MNISTDataModule()
trainer = Trainer(accelerator="gpu", devices=1)

if args.use_ddp:
# `MultiModelDDPStrategy` is critical for multi-GPU GAN training.
# With the existing `DDPStrategy` there are only two workarounds:
# 1) enable the `find_unused_parameters` option, or
# 2) replace `self.manual_backward(loss)` with `loss.backward()`.
# Neither of them is desirable.
trainer = Trainer(accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
else:
# To run on a single GPU, the default strategy is sufficient.
trainer = Trainer(accelerator="gpu", devices=1)

# ------------------------
# 3 START TRAINING
@@ -229,6 +238,8 @@ def main(args: Namespace) -> None:
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--use_ddp", action="store_true", help="distributed strategy to use")

args = parser.parse_args()

main(args)
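
For context, the example trains with Lightning's manual optimization: each optimizer step calls self.manual_backward on a loss that should only touch one child module (generator or discriminator), which is exactly the case the plain DDPStrategy handles poorly without the workarounds listed above. The following is a minimal sketch of that pattern, not taken from this PR; the TinyGAN class, its loss formulation, and the optimizer settings are illustrative assumptions.

import torch
import torch.nn.functional as F

from lightning.pytorch import LightningModule


class TinyGAN(LightningModule):
    """Minimal manual-optimization GAN, for illustration only."""

    def __init__(self, generator: torch.nn.Module, discriminator: torch.nn.Module, latent_dim: int = 100):
        super().__init__()
        self.automatic_optimization = False  # the GAN example relies on manual optimization
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        opt_g, opt_d = self.optimizers()
        z = torch.randn(imgs.size(0), self.latent_dim, device=self.device)

        # generator update: gradients should reach only self.generator
        fake = self.generator(z)
        fake_logits = self.discriminator(fake)
        g_loss = F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # discriminator update: gradients should reach only self.discriminator
        real_logits = self.discriminator(imgs)
        detached_fake_logits = self.discriminator(fake.detach())
        d_loss = 0.5 * (
            F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
            + F.binary_cross_entropy_with_logits(detached_fake_logits, torch.zeros_like(detached_fake_logits))
        )
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return opt_g, opt_d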
42 changes: 41 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
@@ -214,7 +214,7 @@ def set_world_ranks(self) -> None:
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

def _register_ddp_hooks(self) -> None:
log.debug(f"{self.__class__.__name__}: registering ddp hooks")
log.debug(f"{self.__class__.__name__}: registering DDP hooks")
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if self.root_device.type == "cuda":
@@ -419,6 +419,46 @@ def teardown(self) -> None:
super().teardown()


class MultiModelDDPStrategy(DDPStrategy):
Review comment (Collaborator): Should this strategy be available as a string, or are users expected to just provide an instance of the class themselves?
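
For reference, a minimal sketch of how the strategy is selected as the PR currently stands; the "multi_model_ddp" alias in the commented line is hypothetical and is not registered by this PR.

from lightning.pytorch import Trainer
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy

# as implemented in this PR: the user passes an instance explicitly
trainer = Trainer(accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())

# hypothetical alternative if a string alias were registered (not part of this PR):
# trainer = Trainer(accelerator="gpu", devices=-1, strategy="multi_model_ddp")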

"""Specific strategy for training on multiple models with multiple optimizers (e.g. GAN training).

This strategy wraps each individual child module in :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module.
Ensures manual backward only updates parameters of the targeted child module, preventing cross-references between modules' parameters.

"""

@override
def _setup_model(self, model: Module) -> Module:
device_ids = self.determine_ddp_device_ids()
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
with ctx:
for name, module in model.named_children():
if isinstance(module, Module):
ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs)
setattr(model, name, ddp_module)
return model

@override
def _register_ddp_hooks(self) -> None:
log.debug(f"{self.__class__.__name__}: registering DDP hooks")
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if self.root_device.type != "cuda":
return
assert isinstance(self.model, Module)

for name, module in self.model.named_children():
assert isinstance(module, DistributedDataParallel)
_register_ddp_comm_hook(
model=module,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)


class _DDPForwardRedirection(_ForwardRedirection):
@override
def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
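
To make the wrapping difference concrete, here is a small standalone sketch, not part of the PR, that contrasts whole-model wrapping (what DDPStrategy does) with the per-child wrapping performed by MultiModelDDPStrategy._setup_model. It assumes a single-process CPU gloo group purely so the snippet can run on its own.

import os

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(1, 1)
        self.dis = nn.Linear(1, 1)


if __name__ == "__main__":
    # single-process gloo group, only so the wrapping below is executable
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # DDPStrategy-style: one reducer over all parameters of the parent module
    whole = DistributedDataParallel(Parent())

    # MultiModelDDPStrategy-style: one reducer per child module, so a manual
    # backward on a generator loss never involves the discriminator's reducer
    parent = Parent()
    parent.gen = DistributedDataParallel(parent.gen)
    parent.dis = DistributedDataParallel(parent.dis)

    dist.destroy_process_group()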
97 changes: 97 additions & 0 deletions tests/tests_pytorch/strategies/test_multi_model_ddp.py
@@ -0,0 +1,97 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import PropertyMock

import torch
from torch import nn

from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy


def test_multi_model_ddp_setup_and_register_hooks():
Review comment (Collaborator): please add a one-line docstring describing what the test does.

class Parent(nn.Module):
def __init__(self):
super().__init__()
self.gen = nn.Linear(1, 1)
self.dis = nn.Linear(1, 1)

model = Parent()
original_children = [model.gen, model.dis]

strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])

wrapped_modules = []
wrapped_device_ids = []

class DummyDDP(nn.Module):
def __init__(self, module: nn.Module, device_ids=None, **kwargs):
super().__init__()
self.module = module
wrapped_modules.append(module)
wrapped_device_ids.append(device_ids)

with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
returned_model = strategy._setup_model(model)
assert returned_model is model
assert isinstance(model.gen, DummyDDP)
assert isinstance(model.dis, DummyDDP)
assert wrapped_modules == original_children
assert wrapped_device_ids == [None, None]

strategy.model = model
with (
mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook,
mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device,
):
root_device.return_value = torch.device("cuda", 0)
strategy._register_ddp_hooks()

assert register_hook.call_count == 2
register_hook.assert_any_call(
model=model.gen,
ddp_comm_state=strategy._ddp_comm_state,
ddp_comm_hook=strategy._ddp_comm_hook,
ddp_comm_wrapper=strategy._ddp_comm_wrapper,
)
register_hook.assert_any_call(
model=model.dis,
ddp_comm_state=strategy._ddp_comm_state,
ddp_comm_hook=strategy._ddp_comm_hook,
ddp_comm_wrapper=strategy._ddp_comm_wrapper,
)


def test_multi_model_ddp_register_hooks_cpu_noop():
Review comment (Collaborator): please add a one-line docstring describing what the test does.

class Parent(nn.Module):
def __init__(self) -> None:
super().__init__()
self.gen = nn.Linear(1, 1)
self.dis = nn.Linear(1, 1)

model = Parent()
strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])

class DummyDDP(nn.Module):
def __init__(self, module: nn.Module, device_ids=None, **kwargs):
super().__init__()
self.module = module

with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
strategy.model = strategy._setup_model(model)

with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook:
strategy._register_ddp_hooks()

register_hook.assert_not_called()