@@ -414,3 +414,148 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init):
     zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY)
     assert init_mock.call_count == int(not empty_init)
     assert model.layer.weight.dtype == torch.bfloat16
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
+def test_deepspeed_multigpu_stage_3_MiCS_support():
+    """Test that DeepSpeed ZeRO Stage 3 with MiCS sharding works with a multi-GPU model."""
+    strategy = DeepSpeedStrategy(stage=3)
+    strategy.config["zero_optimization"]["stage"] = 3
+    strategy.config["zero_optimization"]["mics_shard_size"] = 1
+    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
+
+    fabric = Fabric(
+        strategy=strategy,
+        accelerator="cuda",
+        devices=2,
+        precision="16-mixed",
+    )
+    fabric.launch()
+
+    def _make_block():
+        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
+
+    with fabric.init_module():
+        model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    model, optimizer = fabric.setup(model, optimizer)
+
+    x = torch.rand(2, 32, device=fabric.device)
+    y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
+    # The model runs in fp16 under "16-mixed"; cast the output to float32 for the loss and pass the raw
+    # outputs to cross_entropy, which applies log-softmax internally.
+    logits = model(x).float()
+    loss = F.cross_entropy(logits, y)
+    fabric.backward(loss)
+    optimizer.step()
+    optimizer.zero_grad()
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
+def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support():
+    """Test that ZeRO Stage 3 MiCS works together with parameter offloading to CPU."""
+    strategy = DeepSpeedStrategy(stage=3, offload_parameters=True, offload_params_device="cpu")
+    strategy.config["zero_optimization"]["stage"] = 3
+    strategy.config["zero_optimization"]["mics_shard_size"] = 1
+    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
+
+    fabric = Fabric(
+        strategy=strategy,
+        accelerator="cuda",
+        devices=2,
+        precision="16-mixed",
+    )
+    fabric.launch()
+
+    def _make_block():
+        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
+
+    with fabric.init_module():
+        model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    model, optimizer = fabric.setup(model, optimizer)
+
+    x = torch.rand(2, 32, device=fabric.device)
+    y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
+    # Cast the fp16 output to float32 and let cross_entropy apply log-softmax internally.
+    logits = model(x).float()
+    loss = F.cross_entropy(logits, y)
+    fabric.backward(loss)
+    optimizer.step()
+    optimizer.zero_grad()
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
+def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support():
+    """Test that ZeRO Stage 3 MiCS works together with parameter and optimizer offloading to CPU."""
+    strategy = DeepSpeedStrategy(
+        stage=3, offload_parameters=True, offload_params_device="cpu",
+        offload_optimizer=True, offload_optimizer_device="cpu",
+    )
+    strategy.config["zero_optimization"]["stage"] = 3
+    strategy.config["zero_optimization"]["mics_shard_size"] = 1
+    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
+
+    fabric = Fabric(
+        strategy=strategy,
+        accelerator="cuda",
+        devices=2,
+        precision="16-mixed",
+    )
+    fabric.launch()
+
+    def _make_block():
+        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
+
+    with fabric.init_module():
+        model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    model, optimizer = fabric.setup(model, optimizer)
+
+    x = torch.rand(2, 32, device=fabric.device)
+    y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
+    # Cast the fp16 output to float32 and let cross_entropy apply log-softmax internally.
+    logits = model(x).float()
+    loss = F.cross_entropy(logits, y)
+    fabric.backward(loss)
+    optimizer.step()
+    optimizer.zero_grad()
+
+
+@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True)
+def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support():
+    """Test ZeRO Stage 3 MiCS with hierarchical parameter gathering (mics_hierarchical_params_gather=True)."""
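+    # With four ranks and mics_shard_size=2, model states are sharded within two groups of two ranks and
+    # replicated across the groups. Hierarchical gathering is expected to all-gather parameters inside each
+    # group first and then exchange across groups, which only makes sense when the shard size is smaller
+    # than the world size.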
+    strategy = DeepSpeedStrategy(stage=3)
+    strategy.config["zero_optimization"]["stage"] = 3
+    strategy.config["zero_optimization"]["mics_shard_size"] = 2
+    strategy.config["zero_optimization"]["offload_param"] = {}
+    strategy.config["zero_optimization"]["offload_optimizer"] = {}
+    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True
+
+    fabric = Fabric(
+        strategy=strategy,
+        accelerator="cuda",
+        devices=4,  # use all four GPUs required by RunIf so that mics_shard_size=2 yields two shard groups
+        precision="16-mixed",
+    )
+    fabric.launch()
+
+    def _make_block():
+        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
+
+    with fabric.init_module():
+        model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    model, optimizer = fabric.setup(model, optimizer)
+
+    x = torch.rand(2, 32, device=fabric.device)
+    y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
+    # Cast the fp16 output to float32 and let cross_entropy apply log-softmax internally.
+    logits = model(x).float()
+    loss = F.cross_entropy(logits, y)
+    fabric.backward(loss)
+    optimizer.step()
+    optimizer.zero_grad()