Skip to content

Commit c442fc3

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8121337 commit c442fc3

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

tests/tests_pytorch/strategies/test_multi_model_ddp.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,27 @@
2020
from lightning.pytorch import Trainer
2121
from lightning.pytorch.strategies import MultiModelDDPStrategy
2222
from lightning.pytorch.trainer import seed_everything
23-
from tests_pytorch.helpers.runif import RunIf
2423
from tests_pytorch.helpers.advanced_models import BasicGAN
24+
from tests_pytorch.helpers.runif import RunIf
2525

2626

2727
@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True)
2828
def test_multi_gpu_with_multi_model_ddp_fit_only(tmp_path):
2929
dm = BasicGAN.train_dataloader()
3030
model = BasicGAN()
31-
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
31+
trainer = Trainer(
32+
default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy()
33+
)
3234
trainer.fit(model, datamodule=dm)
3335

3436

3537
@RunIf(min_cuda_gpus=2, standalone=True, sklearn=True)
3638
def test_multi_gpu_with_multi_model_ddp_predict_only(tmp_path):
3739
dm = BasicGAN.train_dataloader()
3840
model = BasicGAN()
39-
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
41+
trainer = Trainer(
42+
default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy()
43+
)
4044
trainer.predict(model, datamodule=dm)
4145

4246

@@ -45,7 +49,9 @@ def test_multi_gpu_multi_model_ddp_fit_predict(tmp_path):
4549
seed_everything(4321)
4650
dm = BasicGAN.train_dataloader()
4751
model = BasicGAN()
48-
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
52+
trainer = Trainer(
53+
default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy()
54+
)
4955
trainer.fit(model, datamodule=dm)
5056
trainer.predict(model, datamodule=dm)
5157

@@ -73,7 +79,9 @@ def test_find_unused_parameters_ddp_spawn_raises():
7379
max_steps=2,
7480
logger=False,
7581
)
76-
with pytest.raises(ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in"):
82+
with pytest.raises(
83+
ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in"
84+
):
7785
trainer.fit(UnusedParametersBasicGAN())
7886

7987

0 commit comments

Comments
 (0)