@@ -643,17 +643,66 @@ def on_train_batch_start(
 @RunIf(min_gpus=2, deepspeed=True, special=True)
-def test_deepspeed_multigpu_test(tmpdir, deepspeed_config):
-    """
-    Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3.
-    """
+def test_deepspeed_multigpu_test(tmpdir):
+    """Test to ensure we can use DeepSpeed with just ``trainer.test`` using ZeRO Stage 3."""
     model = ModelParallelBoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
     )
     trainer.test(model)


+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_multigpu_partial_partition_parameters(tmpdir):
+    """Test to ensure that a module that defines layers in both ``__init__`` and ``configure_sharded_model``
+    correctly converts all parameters to float16 when ``precision=16`` and runs successfully."""
+
+    class TestModel(ModelParallelBoringModel):
+        def __init__(self):
+            super().__init__()
+            self.layer_2 = torch.nn.Linear(32, 32)
+
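+        # ``layer_2`` above is defined in ``__init__``; ``layer`` below is deferred to
+        # ``configure_sharded_model``, so only part of the module is defined there.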
+        def configure_sharded_model(self) -> None:
+            self.layer = torch.nn.Linear(32, 2)
+
+        def forward(self, x):
+            x = self.layer_2(x)
+            return self.layer(x)
+
+        def on_train_epoch_start(self) -> None:
+            assert all([x.dtype == torch.float16 for x in self.parameters()])
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=1, fast_dev_run=True, precision=16
+    )
+    trainer.fit(model)
+
+
+@RunIf(min_gpus=1, deepspeed=True, special=True)
+def test_deepspeed_multigpu_test_rnn(tmpdir):
+    """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when
+    training with certain layers that would crash with explicit partitioning."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.rnn = torch.nn.GRU(32, 32)
+
+        def on_train_epoch_start(self) -> None:
+            assert all([x.dtype == torch.float16 for x in self.parameters()])
+
+    model = TestModel()
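+    # ``partition_module=False`` below disables explicit partitioning of the whole module,
+    # which would otherwise crash with layers such as the GRU above.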
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        plugins=[DeepSpeedPlugin(stage=3, partition_module=False)],
+        gpus=1,
+        fast_dev_run=True,
+        precision=16,
+    )
+    trainer.fit(model)
+
+
 @RunIf(deepspeed=True)
 @mock.patch("deepspeed.init_distributed", autospec=True)
 @pytest.mark.parametrize("platform", ["Linux", "Windows"])