Commit c604fdf

awaelchli authored and lexierule committed
Simplify some ddp-spawn tests #10921
1 parent e4f9656 commit c604fdf

File tree

2 files changed: +16 -41 lines


tests/plugins/test_sharded_plugin.py

Lines changed: 12 additions & 33 deletions
@@ -5,7 +5,6 @@
 import torch
 
 from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
@@ -31,43 +30,23 @@ def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_v
 
 
 @RunIf(fairscale=True)
-@pytest.mark.parametrize(["strategy"], [("ddp_sharded",), ("ddp_sharded_spawn",)])
-def test_sharded_ddp_choice(tmpdir, strategy):
+@pytest.mark.parametrize(
+    "strategy,expected", [("ddp_sharded", DDPShardedPlugin), ("ddp_sharded_spawn", DDPSpawnShardedPlugin)]
+)
+def test_sharded_ddp_choice(tmpdir, strategy, expected):
     """Test to ensure that plugin is correctly chosen."""
-
-    class CB(Callback):
-        def on_fit_start(self, trainer, pl_module):
-            if strategy == "ddp_sharded":
-                assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
-            elif strategy == "ddp_sharded_spawn":
-                assert isinstance(trainer.accelerator.training_type_plugin, DDPSpawnShardedPlugin)
-            raise SystemExit()
-
-    model = BoringModel()
-    trainer = Trainer(fast_dev_run=True, strategy=strategy, callbacks=[CB()])
-
-    with pytest.raises(SystemExit):
-        trainer.fit(model)
+    trainer = Trainer(fast_dev_run=True, strategy=strategy)
+    assert isinstance(trainer.accelerator.training_type_plugin, expected)
 
 
 @RunIf(min_gpus=1, fairscale=True)
-@pytest.mark.parametrize(["strategy"], [("ddp_sharded",), ("ddp_sharded_spawn",)])
-def test_ddp_choice_sharded_amp(tmpdir, strategy):
+@pytest.mark.parametrize(
+    "strategy,expected", [("ddp_sharded", DDPShardedPlugin), ("ddp_sharded_spawn", DDPSpawnShardedPlugin)]
+)
+def test_ddp_choice_sharded_amp(tmpdir, strategy, expected):
     """Test to ensure that plugin native amp plugin is correctly chosen when using sharded."""
-
-    class CB(Callback):
-        def on_fit_start(self, trainer, pl_module):
-            if strategy == "ddp_sharded":
-                assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
-            elif strategy == "ddp_sharded_spawn":
-                assert isinstance(trainer.accelerator.training_type_plugin, DDPSpawnShardedPlugin)
-            raise SystemExit()
-
-    model = BoringModel()
-    trainer = Trainer(fast_dev_run=True, gpus=1, precision=16, strategy=strategy, callbacks=[CB()])
-
-    with pytest.raises(SystemExit):
-        trainer.fit(model)
+    trainer = Trainer(fast_dev_run=True, gpus=1, precision=16, strategy=strategy)
+    assert isinstance(trainer.accelerator.training_type_plugin, expected)
 
 
 @RunIf(skip_windows=True, fairscale=True)
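
The simplification above works because the training-type plugin is resolved while the `Trainer` is being constructed, so the tests can assert on `trainer.accelerator.training_type_plugin` directly instead of registering a `Callback`, running `fit()`, and aborting with `SystemExit`. A minimal standalone sketch of the same check (illustrative only, assuming a PyTorch Lightning release from this era with fairscale installed; the script is hypothetical and not part of the commit):

# plugin_choice_sketch.py -- illustrative sketch, not part of this commit
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin

# Plugin selection happens at Trainer construction time, so no model,
# no fit() call, and no spawned worker processes are needed to verify it.
for strategy, expected in [("ddp_sharded", DDPShardedPlugin), ("ddp_sharded_spawn", DDPSpawnShardedPlugin)]:
    trainer = Trainer(fast_dev_run=True, strategy=strategy)
    assert isinstance(trainer.accelerator.training_type_plugin, expected)
    print(f"{strategy} -> {type(trainer.accelerator.training_type_plugin).__name__}")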

tests/trainer/test_trainer.py

Lines changed: 4 additions & 8 deletions
@@ -1502,14 +1502,10 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, *arg
 def test_spawn_predict_return_predictions(_, __, accelerator):
     """Test that `return_predictions=True` raise a MisconfigurationException with spawn training type plugins."""
     model = BoringModel()
-
-    def run(expected_plugin, **trainer_kwargs):
-        trainer = Trainer(**trainer_kwargs, fast_dev_run=True)
-        assert isinstance(trainer.training_type_plugin, expected_plugin)
-        with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"):
-            trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True)
-
-    run(DDPSpawnPlugin, accelerator=accelerator, strategy="ddp_spawn", devices=2)
+    trainer = Trainer(accelerator=accelerator, strategy="ddp_spawn", devices=2, fast_dev_run=True)
+    assert isinstance(trainer.training_type_plugin, DDPSpawnPlugin)
+    with pytest.raises(MisconfigurationException, match="`return_predictions` should be set to `False`"):
+        trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=True)
 
 
 @pytest.mark.parametrize("return_predictions", [None, False, True])

0 commit comments
