2020from lightning .pytorch import Trainer
2121from lightning .pytorch .strategies import MultiModelDDPStrategy
2222from lightning .pytorch .trainer import seed_everything
23- from tests_pytorch .helpers .runif import RunIf
2423from tests_pytorch .helpers .advanced_models import BasicGAN
24+ from tests_pytorch .helpers .runif import RunIf
2525
2626
2727@RunIf (min_cuda_gpus = 2 , standalone = True , sklearn = True )
2828def test_multi_gpu_with_multi_model_ddp_fit_only (tmp_path ):
2929 dm = BasicGAN .train_dataloader ()
3030 model = BasicGAN ()
31- trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ())
31+ trainer = Trainer (
32+ default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ()
33+ )
3234 trainer .fit (model , datamodule = dm )
3335
3436
3537@RunIf (min_cuda_gpus = 2 , standalone = True , sklearn = True )
3638def test_multi_gpu_with_multi_model_ddp_predict_only (tmp_path ):
3739 dm = BasicGAN .train_dataloader ()
3840 model = BasicGAN ()
39- trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ())
41+ trainer = Trainer (
42+ default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ()
43+ )
4044 trainer .predict (model , datamodule = dm )
4145
4246
@@ -45,7 +49,9 @@ def test_multi_gpu_multi_model_ddp_fit_predict(tmp_path):
4549 seed_everything (4321 )
4650 dm = BasicGAN .train_dataloader ()
4751 model = BasicGAN ()
48- trainer = Trainer (default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ())
52+ trainer = Trainer (
53+ default_root_dir = tmp_path , max_epochs = 1 , accelerator = "gpu" , devices = - 1 , strategy = MultiModelDDPStrategy ()
54+ )
4955 trainer .fit (model , datamodule = dm )
5056 trainer .predict (model , datamodule = dm )
5157
@@ -73,7 +79,9 @@ def test_find_unused_parameters_ddp_spawn_raises():
7379 max_steps = 2 ,
7480 logger = False ,
7581 )
76- with pytest .raises (ProcessRaisedException , match = "It looks like your LightningModule has parameters that were not used in" ):
82+ with pytest .raises (
83+ ProcessRaisedException , match = "It looks like your LightningModule has parameters that were not used in"
84+ ):
7785 trainer .fit (UnusedParametersBasicGAN ())
7886
7987
0 commit comments