Fix wrong behavior of DDPStrategy option with simple GAN training using DDP (#20936)
File: lightning/pytorch/strategies/ddp.py
@@ -214,7 +214,7 @@ def set_world_ranks(self) -> None:
         rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

     def _register_ddp_hooks(self) -> None:
-        log.debug(f"{self.__class__.__name__}: registering ddp hooks")
+        log.debug(f"{self.__class__.__name__}: registering DDP hooks")
         # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
         # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
         if self.root_device.type == "cuda":
@@ -419,6 +419,46 @@ def teardown(self) -> None:
         super().teardown()


Review comment on the class added below: Should this strategy be available as a string, or are users expected to just provide an instance of the class themselves?

+class MultiModelDDPStrategy(DDPStrategy):
"""Specific strategy for training on multiple models with multiple optimizers (e.g. GAN training). | ||
|
||
This strategy wraps each individual child module in :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module. | ||
Ensures manual backward only updates parameters of the targeted child module, preventing cross-references between modules' parameters. | ||
|
||
""" | ||
|
||
@override | ||
def _setup_model(self, model: Module) -> DistributedDataParallel: | ||
device_ids = self.determine_ddp_device_ids() | ||
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}") | ||
# https://pytorch.org/docs/stable/notes/cuda.html#id5 | ||
ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext() | ||
with ctx: | ||
for name, module in model.named_children(): | ||
if isinstance(module, Module): | ||
ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs) | ||
setattr(model, name, ddp_module) | ||
return model | ||
|
||
@override | ||
def _register_ddp_hooks(self) -> None: | ||
log.debug(f"{self.__class__.__name__}: registering DDP hooks") | ||
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode | ||
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 | ||
if self.root_device.type != "cuda": | ||
return | ||
assert isinstance(self.model, Module) | ||
|
||
for name, module in self.model.named_children(): | ||
assert isinstance(module, DistributedDataParallel) | ||
_register_ddp_comm_hook( | ||
model=module, | ||
ddp_comm_state=self._ddp_comm_state, | ||
ddp_comm_hook=self._ddp_comm_hook, | ||
ddp_comm_wrapper=self._ddp_comm_wrapper, | ||
) | ||
|
||
|
||
class _DDPForwardRedirection(_ForwardRedirection): | ||
@override | ||
def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None: | ||
|
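For readers unfamiliar with the intended workflow, here is a minimal usage sketch (not part of this PR) of how the strategy added above could drive GAN-style training with manual optimization. Everything specific to the sketch is an assumption for illustration: the TinyGAN module and its layer sizes, the random 64-sample dataset, and the devices=2 launch configuration. The import path follows the new test file, and the strategy is passed as an instance because, per the review comment above, registration under a string name is still an open question.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy


class TinyGAN(pl.LightningModule):
    """Illustrative GAN module: `gen` and `dis` are the child modules the strategy wraps separately."""

    def __init__(self, latent_dim: int = 8):
        super().__init__()
        self.automatic_optimization = False  # manual optimization, one optimizer per child module
        self.latent_dim = latent_dim
        self.gen = nn.Sequential(nn.Linear(latent_dim, 16), nn.ReLU(), nn.Linear(16, 4))
        self.dis = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.gen.parameters(), lr=1e-3)
        opt_d = torch.optim.Adam(self.dis.parameters(), lr=1e-3)
        return opt_g, opt_d

    def training_step(self, batch, batch_idx):
        (real,) = batch
        opt_g, opt_d = self.optimizers()
        z = torch.randn(real.size(0), self.latent_dim, device=self.device)
        bce = nn.functional.binary_cross_entropy_with_logits

        # Discriminator step: one forward over real and detached fake samples, then step opt_d only.
        opt_d.zero_grad()
        d_in = torch.cat([real, self.gen(z).detach()])
        d_target = torch.cat([torch.ones(real.size(0), 1), torch.zeros(real.size(0), 1)]).to(self.device)
        self.manual_backward(bce(self.dis(d_in), d_target))
        opt_d.step()

        # Generator step: only opt_g is stepped here, so only the generator's parameters are updated.
        opt_g.zero_grad()
        g_target = torch.ones(real.size(0), 1, device=self.device)
        self.manual_backward(bce(self.dis(self.gen(z)), g_target))
        opt_g.step()


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 4)), batch_size=8)
    trainer = pl.Trainer(
        max_epochs=1,
        accelerator="gpu",  # assumption: two CUDA devices are available
        devices=2,
        strategy=MultiModelDDPStrategy(),  # passed as an instance; no string alias is added in this diff
    )
    trainer.fit(TinyGAN(), data)

The design point the PR relies on is that each child module receives its own DistributedDataParallel wrapper, so the two optimizers can step their respective modules independently instead of sharing a single wrapper around the whole LightningModule.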
New test file for MultiModelDDPStrategy:
@@ -0,0 +1,97 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import PropertyMock

import torch
from torch import nn

from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy


Review comment on the test below: please add a one-line docstring describing what the test does.

def test_multi_model_ddp_setup_and_register_hooks():
    class Parent(nn.Module):
        def __init__(self):
            super().__init__()
            self.gen = nn.Linear(1, 1)
            self.dis = nn.Linear(1, 1)

    model = Parent()
    original_children = [model.gen, model.dis]

    strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])

    wrapped_modules = []
    wrapped_device_ids = []

    class DummyDDP(nn.Module):
        def __init__(self, module: nn.Module, device_ids=None, **kwargs):
            super().__init__()
            self.module = module
            wrapped_modules.append(module)
            wrapped_device_ids.append(device_ids)

    with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
        returned_model = strategy._setup_model(model)

    assert returned_model is model
    assert isinstance(model.gen, DummyDDP)
    assert isinstance(model.dis, DummyDDP)
    assert wrapped_modules == original_children
    assert wrapped_device_ids == [None, None]

    strategy.model = model
    with (
        mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook,
        mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device,
    ):
        root_device.return_value = torch.device("cuda", 0)
        strategy._register_ddp_hooks()

    assert register_hook.call_count == 2
    register_hook.assert_any_call(
        model=model.gen,
        ddp_comm_state=strategy._ddp_comm_state,
        ddp_comm_hook=strategy._ddp_comm_hook,
        ddp_comm_wrapper=strategy._ddp_comm_wrapper,
    )
    register_hook.assert_any_call(
        model=model.dis,
        ddp_comm_state=strategy._ddp_comm_state,
        ddp_comm_hook=strategy._ddp_comm_hook,
        ddp_comm_wrapper=strategy._ddp_comm_wrapper,
    )

Review comment on the test below: please add a one-line docstring describing what the test does.

def test_multi_model_ddp_register_hooks_cpu_noop():
    class Parent(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.gen = nn.Linear(1, 1)
            self.dis = nn.Linear(1, 1)

    model = Parent()
    strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])

    class DummyDDP(nn.Module):
        def __init__(self, module: nn.Module, device_ids=None, **kwargs):
            super().__init__()
            self.module = module

    with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
        strategy.model = strategy._setup_model(model)

    with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook:
        strategy._register_ddp_hooks()

    register_hook.assert_not_called()