Commits (22); the diff below shows changes from 12 commits.
00727e8
add: `MultiModelDDPStrategy` and its execution codes
samsara-ku Jun 25, 2025
e6b061a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 25, 2025
aa9b027
refactor: extract block helper in GAN example
samsara-ku Aug 5, 2025
5503d3a
Merge pull request #1 from samsara-ku/codex/add-tests-for-multimodeld…
samsara-ku Aug 5, 2025
dc128b4
Merge branch 'master' into bugfix/gan-ddp-training
Borda Aug 8, 2025
1fb4027
with
Borda Aug 8, 2025
ec62397
Apply suggestions from code review
Borda Aug 8, 2025
ece7d38
formating
Borda Aug 8, 2025
8b1fe23
misc: resolve some review comments for product consistency
samsara-ku Aug 11, 2025
4b22284
misc: merge gan training example, add docstring of MultiModelDDPStrategy
samsara-ku Aug 12, 2025
97dabf8
misc: add docstring of MultiModelDDPStrategy
samsara-ku Aug 12, 2025
033e8e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2025
57e864a
update
Borda Aug 13, 2025
f157f59
Merge branch 'master' into bugfix/gan-ddp-training
SkafteNicki Aug 13, 2025
24872e7
long line
Borda Aug 13, 2025
3891102
Merge branch 'master' into bugfix/gan-ddp-training
SkafteNicki Aug 14, 2025
8121337
add: set base test case and __init__py for MultiModelDDPStrategy
samsara-ku Aug 15, 2025
c442fc3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2025
ab9b2dd
Merge branch 'master' into bugfix/gan-ddp-training
Borda Aug 18, 2025
a58039e
Merge branch 'master' into bugfix/gan-ddp-training
Borda Sep 2, 2025
2ae8072
Merge branch 'master' into bugfix/gan-ddp-training
Borda Sep 4, 2025
77a81b4
Merge branch 'master' into bugfix/gan-ddp-training
Borda Sep 25, 2025
41 changes: 26 additions & 15 deletions examples/pytorch/domain_templates/generative_adversarial_net.py
@@ -29,13 +29,22 @@
from lightning.pytorch import cli_lightning_logo
from lightning.pytorch.core import LightningModule
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision


def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list:
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers


class Generator(nn.Module):
"""
>>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -47,19 +56,11 @@ class Generator(nn.Module):
def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
super().__init__()
self.img_shape = img_shape

def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers

self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
*_block(latent_dim, 128, normalize=False),
*_block(128, 256),
*_block(256, 512),
*_block(512, 1024),
nn.Linear(1024, int(math.prod(img_shape))),
nn.Tanh(),
)
@@ -209,10 +210,18 @@ def main(args: Namespace) -> None:
# ------------------------
# 2 INIT TRAINER
# ------------------------
# If using distributed training, PyTorch recommends DistributedDataParallel.
# See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
dm = MNISTDataModule()
trainer = Trainer(accelerator="gpu", devices=1)

if args.use_ddp:
# `MultiModelDDPStrategy` is critical for multi-GPU GAN training.
# With the plain `DDPStrategy`, there are only two ways to make this training run:
# 1) enable the `find_unused_parameters` option
# 2) replace `self.manual_backward(loss)` with `loss.backward()`
# Neither of them is desirable.
trainer = Trainer(accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
else:
# If you want to run on a single GPU, you can use the default strategy.
trainer = Trainer(accelerator="gpu", devices=1)

# ------------------------
# 3 START TRAINING
@@ -229,6 +238,8 @@ def main(args: Namespace) -> None:
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--use_ddp", action="store_true", help="distributed strategy to use")

args = parser.parse_args()

main(args)
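To make the comment block above concrete, here is a rough sketch of the manual-optimization pattern the new strategy targets (illustrative only; `TinyGAN` and everything in it are hypothetical, not code from this PR): two sub-models as direct children, two optimizers, and alternating `manual_backward` calls, so each backward pass only touches a subset of the parameters a single DDP wrapper would be tracking.

import torch
import torch.nn.functional as F
from torch import nn

from lightning.pytorch import LightningModule


class TinyGAN(LightningModule):
    # Hypothetical module: two direct children (generator, discriminator), manual optimization.
    def __init__(self, latent_dim: int = 16):
        super().__init__()
        self.automatic_optimization = False  # manual optimization, as in the example above
        self.latent_dim = latent_dim
        self.generator = nn.Sequential(nn.Linear(latent_dim, 32), nn.ReLU(), nn.Linear(32, 4))
        self.discriminator = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1))

    def training_step(self, batch, batch_idx):
        real, _ = batch  # assumes (features, labels) batches with 4-dim features
        opt_g, opt_d = self.optimizers()
        z = torch.randn(real.size(0), self.latent_dim, device=self.device)

        # Generator step: the loss flows through the discriminator, but only opt_g is stepped.
        fake = self.generator(z)
        g_logits = self.discriminator(fake)
        g_loss = F.binary_cross_entropy_with_logits(g_logits, torch.ones_like(g_logits))
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # Discriminator step: the generator's parameters receive no gradients here (fake is detached).
        real_logits = self.discriminator(real)
        fake_logits = self.discriminator(fake.detach())
        d_loss = 0.5 * (
            F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
            + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
        )
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

    def configure_optimizers(self):
        return (
            torch.optim.Adam(self.generator.parameters(), lr=2e-4),
            torch.optim.Adam(self.discriminator.parameters(), lr=2e-4),
        )

Because the discriminator step never produces gradients for the generator, a single DDP wrapper around the whole module would either need `find_unused_parameters=True` or a plain `loss.backward()`, which is exactly the trade-off the comments above describe; wrapping `generator` and `discriminator` separately sidesteps it.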
41 changes: 40 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
@@ -214,7 +214,7 @@ def set_world_ranks(self) -> None:
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

def _register_ddp_hooks(self) -> None:
log.debug(f"{self.__class__.__name__}: registering ddp hooks")
log.debug(f"{self.__class__.__name__}: registering DDP hooks")
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if self.root_device.type == "cuda":
@@ -419,6 +419,45 @@ def teardown(self) -> None:
super().teardown()


class MultiModelDDPStrategy(DDPStrategy):
"""Specific strategy for training on multiple models with multiple optimizers (e.g. GAN training).

This strategy allows to wrap multiple models with DDP, rather than just one which is about just normal DDPStrategy.

"""

@override
def _setup_model(self, model: Module) -> DistributedDataParallel:
device_ids = self.determine_ddp_device_ids()
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
# https://pytorch.org/docs/stable/notes/cuda.html#id5
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
with ctx:
for name, module in model.named_children():
if isinstance(module, Module):
ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs)
setattr(model, name, ddp_module)
return model

@override
def _register_ddp_hooks(self) -> None:
log.debug(f"{self.__class__.__name__}: registering DDP hooks")
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if self.root_device.type != "cuda":
return
assert isinstance(self.model, Module)

for name, module in self.model.named_children():
assert isinstance(module, DistributedDataParallel)
_register_ddp_comm_hook(
model=module,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)


class _DDPForwardRedirection(_ForwardRedirection):
@override
def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
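For comparison with the parent class: `DDPStrategy._setup_model` returns a single `DistributedDataParallel` wrapper around the whole module, whereas the override above replaces each direct child with its own wrapper and leaves the parent unwrapped. The single-process sketch below is illustrative only (it is not part of this PR and assumes a local gloo process group can be created); it mirrors the wrapping loop in `_setup_model` without going through the strategy itself.

import os

import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

# Single-process "distributed" group, just to satisfy DDP's requirements on CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

parent = nn.Module()
parent.gen = nn.Linear(1, 1)
parent.dis = nn.Linear(1, 1)

# Mirror `_setup_model`: wrap every direct child, keep the parent module as-is.
for name, child in list(parent.named_children()):
    setattr(parent, name, DistributedDataParallel(child))

assert isinstance(parent.gen, DistributedDataParallel)
assert isinstance(parent.dis, DistributedDataParallel)
assert not isinstance(parent, DistributedDataParallel)

dist.destroy_process_group()

The idea is that each child then has its own reducer, so a backward pass that only touches one sub-model never waits on gradients for the other, which is what removes the need for `find_unused_parameters` in the GAN example.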
97 changes: 97 additions & 0 deletions tests/tests_pytorch/strategies/test_multi_model_ddp.py
@@ -0,0 +1,97 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import PropertyMock

import torch
from torch import nn

from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy


def test_multi_model_ddp_setup_and_register_hooks():
class Parent(nn.Module):
def __init__(self):
super().__init__()
self.gen = nn.Linear(1, 1)
self.dis = nn.Linear(1, 1)

model = Parent()
original_children = [model.gen, model.dis]

strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])

wrapped_modules = []
wrapped_device_ids = []

class DummyDDP(nn.Module):
def __init__(self, module: nn.Module, device_ids=None, **kwargs):
super().__init__()
self.module = module
wrapped_modules.append(module)
wrapped_device_ids.append(device_ids)

with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
returned_model = strategy._setup_model(model)
assert returned_model is model
assert isinstance(model.gen, DummyDDP)
assert isinstance(model.dis, DummyDDP)
assert wrapped_modules == original_children
assert wrapped_device_ids == [None, None]

strategy.model = model
with (
mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook,
mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device,
):
root_device.return_value = torch.device("cuda", 0)
strategy._register_ddp_hooks()

assert register_hook.call_count == 2
register_hook.assert_any_call(
model=model.gen,
ddp_comm_state=strategy._ddp_comm_state,
ddp_comm_hook=strategy._ddp_comm_hook,
ddp_comm_wrapper=strategy._ddp_comm_wrapper,
)
register_hook.assert_any_call(
model=model.dis,
ddp_comm_state=strategy._ddp_comm_state,
ddp_comm_hook=strategy._ddp_comm_hook,
ddp_comm_wrapper=strategy._ddp_comm_wrapper,
)


def test_multi_model_ddp_register_hooks_cpu_noop():
class Parent(nn.Module):
def __init__(self) -> None:
super().__init__()
self.gen = nn.Linear(1, 1)
self.dis = nn.Linear(1, 1)

model = Parent()
strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])

class DummyDDP(nn.Module):
def __init__(self, module: nn.Module, device_ids=None, **kwargs):
super().__init__()
self.module = module

with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
strategy.model = strategy._setup_model(model)

with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook:
strategy._register_ddp_hooks()

register_hook.assert_not_called()