55import torch
66
77from pytorch_lightning import LightningModule , Trainer
8- from pytorch_lightning .callbacks import Callback
98from pytorch_lightning .plugins import DDPShardedPlugin , DDPSpawnShardedPlugin
109from pytorch_lightning .trainer .states import TrainerFn
1110from pytorch_lightning .utilities import _FAIRSCALE_AVAILABLE
@@ -31,43 +30,23 @@ def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_v
3130
3231
@RunIf(fairscale=True)
@pytest.mark.parametrize(
    # Map each strategy string to the training-type plugin class the Trainer
    # should select for it.
    "strategy,expected", [("ddp_sharded", DDPShardedPlugin), ("ddp_sharded_spawn", DDPSpawnShardedPlugin)]
)
def test_sharded_ddp_choice(tmpdir, strategy, expected):
    """Test to ensure that plugin is correctly chosen.

    Constructing the Trainer is enough: plugin selection happens at
    construction time, so no fit()/callback round-trip is needed.
    """
    trainer = Trainer(fast_dev_run=True, strategy=strategy)
    assert isinstance(trainer.accelerator.training_type_plugin, expected)
5140
5241
@RunIf(min_gpus=1, fairscale=True)
@pytest.mark.parametrize(
    # Map each strategy string to the training-type plugin class the Trainer
    # should select for it.
    "strategy,expected", [("ddp_sharded", DDPShardedPlugin), ("ddp_sharded_spawn", DDPSpawnShardedPlugin)]
)
def test_ddp_choice_sharded_amp(tmpdir, strategy, expected):
    """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.

    Plugin selection happens when the Trainer is constructed, so asserting on
    the built Trainer directly (no fit()/SystemExit callback) suffices.
    Requires a GPU because precision=16 native AMP is exercised.
    """
    trainer = Trainer(fast_dev_run=True, gpus=1, precision=16, strategy=strategy)
    assert isinstance(trainer.accelerator.training_type_plugin, expected)
7150
7251
7352@RunIf (skip_windows = True , fairscale = True )
0 commit comments