20
20
from lightning .pytorch import Trainer
21
21
from lightning .pytorch .strategies import MultiModelDDPStrategy
22
22
from lightning .pytorch .trainer import seed_everything
23
- from tests_pytorch .helpers .runif import RunIf
24
23
from tests_pytorch .helpers .advanced_models import BasicGAN
24
+ from tests_pytorch .helpers .runif import RunIf
25
25
26
26
27
27
@RunIf (min_cuda_gpus = 2 , standalone = True , sklearn = True )
28
28
def test_multi_gpu_with_multi_model_ddp_fit_only (tmp_path ):
29
29
dm = BasicGAN .train_dataloader ()
30
30
model = BasicGAN ()
31
- trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ())
31
+ trainer = Trainer (
32
+ default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ()
33
+ )
32
34
trainer .fit (model , datamodule = dm )
33
35
34
36
35
37
@RunIf (min_cuda_gpus = 2 , standalone = True , sklearn = True )
36
38
def test_multi_gpu_with_multi_model_ddp_predict_only (tmp_path ):
37
39
dm = BasicGAN .train_dataloader ()
38
40
model = BasicGAN ()
39
- trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ())
41
+ trainer = Trainer (
42
+ default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ()
43
+ )
40
44
trainer .predict (model , datamodule = dm )
41
45
42
46
@@ -45,7 +49,9 @@ def test_multi_gpu_multi_model_ddp_fit_predict(tmp_path):
45
49
seed_everything (4321 )
46
50
dm = BasicGAN .train_dataloader ()
47
51
model = BasicGAN ()
48
- trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ())
52
+ trainer = Trainer (
53
+ default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ()
54
+ )
49
55
trainer .fit (model , datamodule = dm )
50
56
trainer .predict (model , datamodule = dm )
51
57
@@ -73,7 +79,9 @@ def test_find_unused_parameters_ddp_spawn_raises():
73
79
max_steps = 2 ,
74
80
logger = False ,
75
81
)
76
- with pytest .raises (ProcessRaisedException , match = "It looks like your LightningModule has parameters that were not used in" ):
82
+ with pytest .raises (
83
+ ProcessRaisedException , match = "It looks like your LightningModule has parameters that were not used in"
84
+ ):
77
85
trainer .fit (UnusedParametersBasicGAN ())
78
86
79
87
0 commit comments