Commit 8121337

add: set base test case and __init__.py for MultiModelDDPStrategy
1 parent 3891102 commit 8121337

File tree

2 files changed: +153 -83 lines changed

src/lightning/pytorch/strategies/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,7 @@
 
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.utilities.registry import _register_classes
-from lightning.pytorch.strategies.ddp import DDPStrategy
+from lightning.pytorch.strategies.ddp import DDPStrategy, MultiModelDDPStrategy
 from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
 from lightning.pytorch.strategies.fsdp import FSDPStrategy
 from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
@@ -30,6 +30,7 @@
 
 __all__ = [
     "DDPStrategy",
+    "MultiModelDDPStrategy",
     "DeepSpeedStrategy",
     "FSDPStrategy",
     "ModelParallelStrategy",

tests/tests_pytorch/strategies/test_multi_model_ddp.py

Lines changed: 151 additions & 82 deletions
@@ -11,87 +11,156 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest import mock
-from unittest.mock import PropertyMock
+import os
 
+import pytest
 import torch
-from torch import nn
-
-from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
-
-
-def test_multi_model_ddp_setup_and_register_hooks():
-    class Parent(nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.gen = nn.Linear(1, 1)
-            self.dis = nn.Linear(1, 1)
-
-    model = Parent()
-    original_children = [model.gen, model.dis]
-
-    strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])
-
-    wrapped_modules = []
-    wrapped_device_ids = []
-
-    class DummyDDP(nn.Module):
-        def __init__(self, module: nn.Module, device_ids=None, **kwargs):
-            super().__init__()
-            self.module = module
-            wrapped_modules.append(module)
-            wrapped_device_ids.append(device_ids)
-
-    with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
-        returned_model = strategy._setup_model(model)
-    assert returned_model is model
-    assert isinstance(model.gen, DummyDDP)
-    assert isinstance(model.dis, DummyDDP)
-    assert wrapped_modules == original_children
-    assert wrapped_device_ids == [None, None]
-
-    strategy.model = model
-    with (
-        mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook,
-        mock.patch.object(MultiModelDDPStrategy, "root_device", new_callable=PropertyMock) as root_device,
-    ):
-        root_device.return_value = torch.device("cuda", 0)
-        strategy._register_ddp_hooks()
-
-    assert register_hook.call_count == 2
-    register_hook.assert_any_call(
-        model=model.gen,
-        ddp_comm_state=strategy._ddp_comm_state,
-        ddp_comm_hook=strategy._ddp_comm_hook,
-        ddp_comm_wrapper=strategy._ddp_comm_wrapper,
-    )
-    register_hook.assert_any_call(
-        model=model.dis,
-        ddp_comm_state=strategy._ddp_comm_state,
-        ddp_comm_hook=strategy._ddp_comm_hook,
-        ddp_comm_wrapper=strategy._ddp_comm_wrapper,
-    )
-
-
-def test_multi_model_ddp_register_hooks_cpu_noop():
-    class Parent(nn.Module):
-        def __init__(self) -> None:
-            super().__init__()
-            self.gen = nn.Linear(1, 1)
-            self.dis = nn.Linear(1, 1)
-
-    model = Parent()
-    strategy = MultiModelDDPStrategy(parallel_devices=[torch.device("cpu")])
-
-    class DummyDDP(nn.Module):
-        def __init__(self, module: nn.Module, device_ids=None, **kwargs):
-            super().__init__()
-            self.module = module
-
-    with mock.patch("lightning.pytorch.strategies.ddp.DistributedDataParallel", DummyDDP):
-        strategy.model = strategy._setup_model(model)
-
-    with mock.patch("lightning.pytorch.strategies.ddp._register_ddp_comm_hook") as register_hook:
-        strategy._register_ddp_hooks()
-
-    register_hook.assert_not_called()
+from torch.multiprocessing import ProcessRaisedException
+
+from lightning.pytorch import Trainer
+from lightning.pytorch.strategies import MultiModelDDPStrategy
+from lightning.pytorch.trainer import seed_everything
+from tests_pytorch.helpers.advanced_models import BasicGAN
+from tests_pytorch.helpers.runif import RunIf
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True)
+def test_multi_gpu_with_multi_model_ddp_fit_only(tmp_path):
+    model = BasicGAN()
+    dataloader = model.train_dataloader()
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
+    trainer.fit(model, train_dataloaders=dataloader)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True)
+def test_multi_gpu_with_multi_model_ddp_predict_only(tmp_path):
+    model = BasicGAN()
+    dataloader = model.train_dataloader()
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
+    trainer.predict(model, dataloaders=dataloader)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True)
+def test_multi_gpu_multi_model_ddp_fit_predict(tmp_path):
+    seed_everything(4321)
+    model = BasicGAN()
+    dataloader = model.train_dataloader()
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
+    trainer.fit(model, train_dataloaders=dataloader)
+    trainer.predict(model, dataloaders=dataloader)
+
+
+class UnusedParametersBasicGAN(BasicGAN):
+    def __init__(self):
+        super().__init__()
+        mnist_shape = (1, 28, 28)
+        self.intermediate_layer = torch.nn.Linear(mnist_shape[-1], mnist_shape[-1])
+
+    def training_step(self, batch, batch_idx):
+        with torch.no_grad():
+            img = self.intermediate_layer(batch[0])
+        batch[0] = img  # modify the batch to use the intermediate layer's output
+        return super().training_step(batch, batch_idx)
+
+
+@RunIf(standalone=True)
+def test_find_unused_parameters_ddp_spawn_raises():
+    """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users."""
+    trainer = Trainer(
+        accelerator="cpu",
+        devices=1,
+        strategy=MultiModelDDPStrategy(start_method="spawn"),
+        max_steps=2,
+        logger=False,
+    )
+    with pytest.raises(ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in"):
+        trainer.fit(UnusedParametersBasicGAN())
+
+
+@RunIf(standalone=True)
+def test_find_unused_parameters_ddp_exception():
+    """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users."""
+    trainer = Trainer(
+        accelerator="cpu",
+        devices=1,
+        strategy=MultiModelDDPStrategy(),
+        max_steps=2,
+        logger=False,
+    )
+    with pytest.raises(RuntimeError, match="It looks like your LightningModule has parameters that were not used in"):
+        trainer.fit(UnusedParametersBasicGAN())
+
+
+class CheckOptimizerDeviceModel(BasicGAN):
+    def configure_optimizers(self):
+        assert all(param.device.type == "cuda" for param in self.parameters())
+        return super().configure_optimizers()
+
+
+@RunIf(min_cuda_gpus=1)
+def test_model_parameters_on_device_for_optimizer():
+    """Test that the strategy has moved the parameters to the device by the time the optimizer gets created."""
+    model = CheckOptimizerDeviceModel()
+    trainer = Trainer(
+        default_root_dir=os.getcwd(),
+        fast_dev_run=1,
+        accelerator="gpu",
+        devices=1,
+        strategy=MultiModelDDPStrategy(),
+    )
+    trainer.fit(model)
+
+
+class BasicGANCPU(BasicGAN):
+    def on_train_start(self) -> None:
+        # make sure that the model is on CPU when training
+        assert self.device == torch.device("cpu")
+
+
+@RunIf(skip_windows=True)
+def test_multi_model_ddp_with_cpu():
+    """Tests if the device is set correctly when training with MultiModelDDPStrategy on CPU."""
+    trainer = Trainer(
+        accelerator="cpu",
+        devices=-1,
+        strategy=MultiModelDDPStrategy(),
+        fast_dev_run=True,
+    )
+    # assert strategy attributes for device setting
+    assert isinstance(trainer.strategy, MultiModelDDPStrategy)
+    assert trainer.strategy.root_device == torch.device("cpu")
+    model = BasicGANCPU()
+    trainer.fit(model)
+
+
+class BasicGANGPU(BasicGAN):
+    def on_train_start(self) -> None:
+        # make sure that the model is on GPU when training
+        assert self.device == torch.device(f"cuda:{self.trainer.strategy.local_rank}")
+        self.start_cuda_memory = torch.cuda.memory_allocated()
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+def test_multi_model_ddp_with_gpus():
+    """Tests if the device is set correctly when training and after teardown with MultiModelDDPStrategy on GPUs."""
+    trainer = Trainer(
+        accelerator="gpu",
+        devices=-1,
+        strategy=MultiModelDDPStrategy(),
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+    # assert strategy attributes for device setting
+    assert isinstance(trainer.strategy, MultiModelDDPStrategy)
+    local_rank = trainer.strategy.local_rank
+    assert trainer.strategy.root_device == torch.device(f"cuda:{local_rank}")
+
+    model = BasicGANGPU()
+
+    trainer.fit(model)
+
+    # after training, the model should be moved back to CPU and GPU memory deallocated
+    assert model.device == torch.device("cpu")
+    cuda_memory = torch.cuda.memory_allocated()
+    assert cuda_memory < model.start_cuda_memory
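
For orientation: the deleted mock-based tests above pinned down the contract that this standalone suite now exercises end to end. _setup_model must return the parent module itself (not a DDP wrapper) while replacing each direct child with its own DistributedDataParallel instance, and _register_ddp_hooks must register one comm hook per wrapped child on CUDA devices and none on CPU. A sketch of that contract, inferred from the deleted assertions rather than copied from the committed implementation:

from torch import nn
from torch.nn.parallel import DistributedDataParallel


def wrap_children_in_ddp(model: nn.Module, device_ids=None) -> nn.Module:
    # Hypothetical helper mirroring what the deleted tests assert about
    # MultiModelDDPStrategy._setup_model: wrap model.gen, model.dis, ...
    # individually, and leave the parent module unwrapped.
    for name, child in list(model.named_children()):
        setattr(model, name, DistributedDataParallel(child, device_ids=device_ids))
    return model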

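As background for UnusedParametersBasicGAN in the diff above: intermediate_layer only ever runs under torch.no_grad(), so its parameters never receive gradients, which is exactly the condition DDP's unused-parameter detection reports. A plain-PyTorch illustration of that effect (assumed shapes, no Lightning involved):

import torch
from torch import nn

unused = nn.Linear(28, 28)  # plays the role of intermediate_layer
used = nn.Linear(28, 1)

x = torch.randn(4, 28)
with torch.no_grad():
    x = unused(x)  # runs outside autograd: contributes nothing to the graph
used(x).sum().backward()

assert all(p.grad is None for p in unused.parameters())  # what DDP flags
assert all(p.grad is not None for p in used.parameters())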