@@ -419,6 +419,55 @@ def test_deepspeed_fp32_works(tmpdir):
419419 trainer .fit (model )
420420
421421
@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_stage_3_save_warning(tmpdir):
    """Test to ensure that DeepSpeed Stage 3 gives a warning when saving on rank zero.

    Every rank calls ``trainer.save_checkpoint``, but only the global-zero rank is required
    to see the ``UserWarning`` about each worker saving a shard of the checkpoint.
    """
    # Local import keeps this test self-contained regardless of the file-level imports.
    import contextlib

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
    )
    trainer.fit(model)
    checkpoint_path = os.path.join(tmpdir, "model.pt")

    # ``pytest.warns`` fails when no matching warning is emitted, so it may only wrap the call
    # on the rank that is expected to warn. Matching on the message (instead of asserting an
    # exact warning count) keeps the test robust against unrelated ``UserWarning``s.
    if trainer.is_global_zero:
        context_manager = pytest.warns(
            UserWarning, match="each worker will save a shard of the checkpoint within a directory."
        )
    else:
        # ``suppress()`` with no exception types acts as a no-op context manager.
        context_manager = contextlib.suppress()

    with context_manager:
        # both ranks need to call save checkpoint
        trainer.save_checkpoint(checkpoint_path)
438+
439+
@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_multigpu_single_file(tmpdir):
    """Check that DeepSpeed Stage 3 can load a checkpoint stored as a single file.

    A checkpoint is first written by a ``Trainer`` created without any plugins. Testing from
    that file with ``DeepSpeedPlugin(stage=3)`` is expected to raise a
    ``MisconfigurationException`` unless the plugin is created with ``load_full_weights=True``.
    """
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    model = BoringModel()

    # Step 1: produce the single-file checkpoint with a Trainer that has no plugins passed.
    baseline_trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    baseline_trainer.fit(model)
    baseline_trainer.save_checkpoint(checkpoint_path)

    # Keyword arguments shared by both DeepSpeed trainers below.
    deepspeed_trainer_kwargs = {
        "default_root_dir": tmpdir,
        "gpus": 1,
        "fast_dev_run": True,
        "precision": 16,
    }

    # Step 2: with the plugin's default settings, loading the single file must fail.
    default_trainer = Trainer(plugins=[DeepSpeedPlugin(stage=3)], **deepspeed_trainer_kwargs)
    default_plugin = default_trainer.training_type_plugin
    assert isinstance(default_plugin, DeepSpeedPlugin)
    assert not default_plugin.load_full_weights
    with pytest.raises(MisconfigurationException, match="DeepSpeed was unable to load the checkpoint."):
        default_trainer.test(model, ckpt_path=checkpoint_path)

    # Step 3: with ``load_full_weights=True`` the same checkpoint is expected to load.
    full_weights_trainer = Trainer(
        plugins=[DeepSpeedPlugin(stage=3, load_full_weights=True)],
        **deepspeed_trainer_kwargs,
    )
    full_weights_plugin = full_weights_trainer.training_type_plugin
    assert isinstance(full_weights_plugin, DeepSpeedPlugin)
    assert full_weights_plugin.load_full_weights
    full_weights_trainer.test(model, ckpt_path=checkpoint_path)
469+
470+
422471class ModelParallelClassificationModel (LightningModule ):
423472 def __init__ (self , lr : float = 0.01 , num_blocks : int = 5 ):
424473 super ().__init__ ()
0 commit comments