Skip to content

Commit 5503d3a

Browse files
authored
Merge pull request #1 from samsara-ku/codex/add-tests-for-multimodelddp-implementation
test: cover MultiModelDDPStrategy
2 parents e6b061a + aa9b027 commit 5503d3a

File tree

3 files changed

+119
-23
lines changed

3 files changed

+119
-23
lines changed

examples/pytorch/domain_templates/generative_adversarial_net.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@
3636
import torchvision
3737

3838

39+
def _block(in_feat: int, out_feat: int, normalize: bool = True):
40+
layers = [nn.Linear(in_feat, out_feat)]
41+
if normalize:
42+
layers.append(nn.BatchNorm1d(out_feat, 0.8))
43+
layers.append(nn.LeakyReLU(0.2, inplace=True))
44+
return layers
45+
46+
3947
class Generator(nn.Module):
4048
"""
4149
>>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -47,19 +55,11 @@ class Generator(nn.Module):
4755
def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
4856
super().__init__()
4957
self.img_shape = img_shape
50-
51-
def block(in_feat, out_feat, normalize=True):
52-
layers = [nn.Linear(in_feat, out_feat)]
53-
if normalize:
54-
layers.append(nn.BatchNorm1d(out_feat, 0.8))
55-
layers.append(nn.LeakyReLU(0.2, inplace=True))
56-
return layers
57-
5858
self.model = nn.Sequential(
59-
*block(latent_dim, 128, normalize=False),
60-
*block(128, 256),
61-
*block(256, 512),
62-
*block(512, 1024),
59+
*_block(latent_dim, 128, normalize=False),
60+
*_block(128, 256),
61+
*_block(256, 512),
62+
*_block(512, 1024),
6363
nn.Linear(1024, int(math.prod(img_shape))),
6464
nn.Tanh(),
6565
)

examples/pytorch/domain_templates/generative_adversarial_net_ddp.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@
4444
import torchvision
4545

4646

47+
def _block(in_feat: int, out_feat: int, normalize: bool = True):
48+
layers = [nn.Linear(in_feat, out_feat)]
49+
if normalize:
50+
layers.append(nn.BatchNorm1d(out_feat, 0.8))
51+
layers.append(nn.LeakyReLU(0.2, inplace=True))
52+
return layers
53+
54+
4755
class Generator(nn.Module):
4856
"""
4957
>>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -56,18 +64,11 @@ def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
5664
super().__init__()
5765
self.img_shape = img_shape
5866

59-
def block(in_feat, out_feat, normalize=True):
60-
layers = [nn.Linear(in_feat, out_feat)]
61-
if normalize:
62-
layers.append(nn.BatchNorm1d(out_feat, 0.8))
63-
layers.append(nn.LeakyReLU(0.2, inplace=True))
64-
return layers
65-
6667
self.model = nn.Sequential(
67-
*block(latent_dim, 128, normalize=False),
68-
*block(128, 256),
69-
*block(256, 512),
70-
*block(512, 1024),
68+
*_block(latent_dim, 128, normalize=False),
69+
*_block(128, 256),
70+
*_block(256, 512),
71+
*_block(512, 1024),
7172
nn.Linear(1024, int(math.prod(img_shape))),
7273
nn.Tanh(),
7374
)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright The Lightning AI team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from unittest import mock
15+
from unittest.mock import PropertyMock
16+
17+
import torch
18+
from torch import nn
19+
20+
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
21+
22+
23+
def test_multi_model_ddp_setup_and_register_hooks():
24+
class Parent(nn.Module):
25+
def __init__(self):
26+
super().__init__()
27+
self.gen = nn.Linear(1, 1)
28+
self.dis = nn.Linear(1, 1)
29+
30+
model = Parent()
31+
original_children = [model.gen, model.dis]
32+
33+
strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])
34+
35+
wrapped_modules = []
36+
wrapped_device_ids = []
37+
38+
class DummyDDP(nn.Module):
39+
def __init__(self, module: nn.Module, device_ids=None, **kwargs):
40+
super().__init__()
41+
self.module = module
42+
wrapped_modules.append(module)
43+
wrapped_device_ids.append(device_ids)
44+
45+
with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
46+
returned_model = strategy._setup_model(model)
47+
assert returned_model is model
48+
assert isinstance(model.gen, DummyDDP)
49+
assert isinstance(model.dis, DummyDDP)
50+
assert wrapped_modules == original_children
51+
assert wrapped_device_ids == [None, None]
52+
53+
strategy.model = model
54+
with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook:
55+
with mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device:
56+
root_device.return_value = torch.device("cuda", 0)
57+
strategy._register_ddp_hooks()
58+
59+
assert register_hook.call_count == 2
60+
register_hook.assert_any_call(
61+
model=model.gen,
62+
ddp_comm_state=strategy._ddp_comm_state,
63+
ddp_comm_hook=strategy._ddp_comm_hook,
64+
ddp_comm_wrapper=strategy._ddp_comm_wrapper,
65+
)
66+
register_hook.assert_any_call(
67+
model=model.dis,
68+
ddp_comm_state=strategy._ddp_comm_state,
69+
ddp_comm_hook=strategy._ddp_comm_hook,
70+
ddp_comm_wrapper=strategy._ddp_comm_wrapper,
71+
)
72+
73+
74+
def test_multi_model_ddp_register_hooks_cpu_noop():
75+
class Parent(nn.Module):
76+
def __init__(self) -> None:
77+
super().__init__()
78+
self.gen = nn.Linear(1, 1)
79+
self.dis = nn.Linear(1, 1)
80+
81+
model = Parent()
82+
strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])
83+
84+
class DummyDDP(nn.Module):
85+
def __init__(self, module: nn.Module, device_ids=None, **kwargs):
86+
super().__init__()
87+
self.module = module
88+
89+
with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
90+
strategy.model = strategy._setup_model(model)
91+
92+
with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook:
93+
strategy._register_ddp_hooks()
94+
95+
register_hook.assert_not_called()

0 commit comments

Comments
 (0)