
Commit 00727e8

add: MultiModelDDPStrategy and its execution codes
1 parent 39438bf commit 00727e8

2 files changed: +297 -13 lines changed

generative_adversarial_net.py

Lines changed: 260 additions & 0 deletions
@@ -0,0 +1,260 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""To run this template just do: python generative_adversarial_net.py.

After a few epochs, launch TensorBoard to see the images being generated at every batch:

    tensorboard --logdir default

"""
import math
from argparse import ArgumentParser, Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F

# ! TESTING
import os
import sys

sys.path.append(os.path.join(os.getcwd(), "src"))  # noqa: E402
# ! TESTING

from lightning.pytorch import cli_lightning_logo
from lightning.pytorch.core import LightningModule
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from lightning.pytorch.strategies.ddp import DDPStrategy, MultiModelDDPStrategy

if _TORCHVISION_AVAILABLE:
    import torchvision


class Generator(nn.Module):
    """
    >>> Generator(img_shape=(1, 8, 8))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    Generator(
      (model): Sequential(...)
    )
    """

    def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(math.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """
    >>> Discriminator(img_shape=(1, 28, 28))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    Discriminator(
      (model): Sequential(...)
    )
    """

    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(math.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)


class GAN(LightningModule):
    """
    >>> GAN(img_shape=(1, 8, 8))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    GAN(
      (generator): Generator(
        (model): Sequential(...)
      )
      (discriminator): Discriminator(
        (model): Sequential(...)
      )
    )
    """

    def __init__(
        self,
        img_shape: tuple = (1, 28, 28),
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        latent_dim: int = 100,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        # networks
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape)
        self.discriminator = Discriminator(img_shape=img_shape)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

        # ! TESTING
        self.save_path = "pl_test_multi_gpu"
        os.makedirs(self.save_path, exist_ok=True)

    def forward(self, z):
        return self.generator(z)

    @staticmethod
    def adversarial_loss(y_hat, y):
        return F.binary_cross_entropy_with_logits(y_hat, y)

    def training_step(self, batch):
        imgs, _ = batch

        opt_g, opt_d = self.optimizers()

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # Train generator
        # ground truth result (ie: all fake)
        # put on GPU because we created this tensor inside training_loop
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        self.toggle_optimizer(opt_g)
        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()
        self.untoggle_optimizer(opt_g)

        # Train discriminator
        # Measure discriminator's ability to classify real from generated samples
        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        self.toggle_optimizer(opt_d)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)

        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2

        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()
        self.untoggle_optimizer(opt_d)

        self.log_dict({"d_loss": d_loss, "g_loss": g_loss})

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return opt_g, opt_d

    # ! TESTING
    def on_train_epoch_start(self):
        if self.trainer.is_global_zero:
            print("GEN: ", self.generator.module.model[0].bias[:10])
            print("DISC: ", self.discriminator.module.model[0].bias[:10])

    # ! TESTING
    def validation_step(self, batch, batch_idx):
        pass

    # ! TESTING
    @torch.no_grad()
    def on_validation_epoch_end(self):
        # save a sample grid every 5 epochs
        if self.current_epoch % 5 == 0:
            self.generator.eval(), self.discriminator.eval()

            z = self.validation_z.type_as(self.generator.module.model[0].weight)
            sample_imgs = self(z)

            if self.trainer.is_global_zero:
                grid = torchvision.utils.make_grid(sample_imgs)
                torchvision.utils.save_image(grid, os.path.join(self.save_path, f"epoch_{self.current_epoch}.png"))

            self.generator.train(), self.discriminator.train()


def main(args: Namespace) -> None:
    model = GAN(lr=args.lr, b1=args.b1, b2=args.b2, latent_dim=args.latent_dim)

    # ! `MultiModelDDPStrategy` is critical for multi-GPU training here: the plain `DDPStrategy`
    # ! does not work when the LightningModule holds multiple separately optimized models.
    # ! With the previous `DDPStrategy` there are two workarounds:
    # ! 1) enable `find_unused_parameters=True`, or 2) call loss.backward() instead of self.manual_backward(loss).
    # ! Neither is desirable.
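    # For illustration only (a hedged sketch, not executed): workaround (1) above with the plain
    # `DDPStrategy` would look roughly like the Trainer call below, at the cost of scanning for
    # unused parameters on every backward pass:
    #
    #     trainer = Trainer(
    #         accelerator="auto",
    #         devices=[0, 1, 2, 3],
    #         strategy=DDPStrategy(find_unused_parameters=True),
    #         max_epochs=100,
    #     )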
    dm = MNISTDataModule()
    trainer = Trainer(
        accelerator="auto",
        devices=[0, 1, 2, 3],
        strategy=MultiModelDDPStrategy(),
        max_epochs=100,
    )

    trainer.fit(model, dm)


if __name__ == "__main__":
    cli_lightning_logo()
    parser = ArgumentParser()

    # Hyperparameters
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    args = parser.parse_args()

    main(args)

src/lightning/pytorch/strategies/ddp.py

Lines changed: 37 additions & 13 deletions
@@ -107,9 +107,7 @@ def __init__(
     @property
     def is_distributed(self) -> bool:  # pragma: no-cover
         """Legacy property kept for backwards compatibility."""
-        rank_zero_deprecation(
-            f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6
-        )
+        rank_zero_deprecation(f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6)
         return True

     @property
@@ -229,9 +227,7 @@ def _register_ddp_hooks(self) -> None:
     def _enable_model_averaging(self) -> None:
         log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
         if self._model_averaging_period is None:
-            raise ValueError(
-                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
-            )
+            raise ValueError("Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy.")
         from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer

         for optimizer in self.optimizers:
@@ -240,10 +236,7 @@ def _enable_model_averaging(self) -> None:

             is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
             if isinstance(optimizer, (ZeroRedundancyOptimizer, PostLocalSGDOptimizer)) or is_distributed_optimizer:
-                raise ValueError(
-                    f"Currently model averaging cannot work with a distributed optimizer of type "
-                    f"{optimizer.__class__.__name__}."
-                )
+                raise ValueError(f"Currently model averaging cannot work with a distributed optimizer of type " f"{optimizer.__class__.__name__}.")

             assert self._ddp_comm_state is not None
             self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
@@ -323,9 +316,7 @@ def model_to_device(self) -> None:
         self.model.to(self.root_device)

     @override
-    def reduce(
-        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
-    ) -> Tensor:
+    def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.

         Args:
@@ -419,6 +410,39 @@ def teardown(self) -> None:
         super().teardown()


+class MultiModelDDPStrategy(DDPStrategy):
+    @override
+    def _setup_model(self, model: Module) -> Module:
+        device_ids = self.determine_ddp_device_ids()
+        log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
+        # https://pytorch.org/docs/stable/notes/cuda.html#id5
+        ctx = torch.cuda.stream(torch.cuda.Stream()) if device_ids is not None else nullcontext()
+        with ctx:
+            for name, module in model.named_children():
+                if isinstance(module, Module):
+                    ddp_module = DistributedDataParallel(module, device_ids=device_ids, **self._ddp_kwargs)
+                    setattr(model, name, ddp_module)
+
+        return model
+
+    @override
+    def _register_ddp_hooks(self) -> None:
+        log.debug(f"{self.__class__.__name__}: registering ddp hooks")
+        # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
+        # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
+        if self.root_device.type == "cuda":
+            assert isinstance(self.model, Module)
+
+            for name, module in self.model.named_children():
+                assert isinstance(module, DistributedDataParallel)
+                _register_ddp_comm_hook(
+                    model=module,
+                    ddp_comm_state=self._ddp_comm_state,
+                    ddp_comm_hook=self._ddp_comm_hook,
+                    ddp_comm_wrapper=self._ddp_comm_wrapper,
+                )
+
+
 class _DDPForwardRedirection(_ForwardRedirection):
     @override
     def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None:
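
A rough picture of what the new strategy does to a multi-model module, as a minimal single-process sketch: each direct child is wrapped in its own DistributedDataParallel instance (rather than wrapping the parent module once), which is also why the example above can reach into `self.generator.module`. The toy TwoModelNet, the "gloo" backend, and world_size=1 are assumptions for illustration only, not part of the commit.

# Minimal single-process sketch of the per-child wrapping performed by
# MultiModelDDPStrategy._setup_model. Names below (TwoModelNet, port 29501) are illustrative.
import os

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)


class TwoModelNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(100, 784)
        self.discriminator = nn.Linear(784, 1)


model = TwoModelNet()
for name, child in list(model.named_children()):
    # each sub-network becomes its own DDP unit with its own reducer, so stepping only one
    # of them does not trip DDP's unused-parameter checks on the other
    setattr(model, name, DistributedDataParallel(child))

print(type(model.generator).__name__)         # DistributedDataParallel
print(type(model.generator.module).__name__)  # Linear: the underlying module behind `.module`

dist.destroy_process_group()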
