@@ -5,7 +5,7 @@
 import types
 
 try:
-    from collections import abs as collections_abc  # type: ignore[attr-defined]
+    from collections import abc as collections_abc  # type: ignore[attr-defined]
 except ImportError:
     import collections as collections_abc  # type: ignore[no-redef]
 
@@ -28,6 +28,7 @@
 from atorch.distributed.distributed import (
     get_data_partition_rank_and_size,
     local_rank,
+    parallel_group,
     parallel_group_and_ranks,
     parallel_group_size,
     rank,
@@ -396,6 +397,9 @@ def create_optim(self):
             src = ranks[0]
             torch.distributed._broadcast_coalesced(process_group, module_states, int(250 * 1024 * 1024), src)
 
+        if "fsdp2" in self.pre_wrappers and parallel_group("expert") is not None:
+            self.optim_args["foreach"] = False
+
         if not self.check_pipe_model():
             if not self.optim_param_func:
                 optim = self.optim_func(self.model.parameters(), **self.optim_args)
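The new branch above disables the multi-tensor ("foreach") optimizer path whenever the fsdp2 pre-wrapper is combined with an expert parallel group, presumably because the resulting parameter set mixes DTensors from different meshes, which the foreach kernels do not handle. A minimal sketch of what the flag controls, using torch.optim.AdamW as a stand-in for the configured optim_func:

import torch

model = torch.nn.Linear(8, 8)
optim_args = {"lr": 1e-3}
optim_args["foreach"] = False  # fall back to the per-parameter loop implementation
optim = torch.optim.AdamW(model.parameters(), **optim_args)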
@@ -416,6 +420,7 @@ def create_optim(self):
             and "ds_zero" not in self.post_wrappers
             and "zero2" not in self.post_wrappers
             and "fsdp" not in self.pre_wrappers
+            and "fsdp2" not in self.pre_wrappers
             and "ds_3d_parallel" not in self.post_wrappers
         ):
             is_cuda = next(self.model.parameters()).is_cuda
@@ -497,6 +502,8 @@ def adjust_wrappers(self):
                 self.pre_wrappers.pop("zero2")
             if "fsdp" in self.pre_wrappers:
                 self.pre_wrappers.pop("fsdp")
+            if "fsdp2" in self.pre_wrappers:
+                self.pre_wrappers.pop("fsdp2")
 
             # DDP is supported and handled internally by PiPPy.
             if "ddp" in self.post_wrappers:
@@ -572,13 +579,18 @@ def adjust_wrappers(self):
         ds_3d_parallel_wrapper_exist = "ds_3d_parallel" in self.post_wrappers
         fairscale_zero2_wrapper_exist = "zero2" in self.post_wrappers
         fsdp_wrapper_exist = "fsdp" in self.pre_wrappers or "zero2" in self.pre_wrappers
+        fsdp2_wrapper_exist = "fsdp2" in self.pre_wrappers
         tensor_parallel_wrapper_exist = "tp" in self.pre_wrappers
         ckpt_wrapper_exist = "checkpoint" in self.post_wrappers
         native_dynamo_wrapper_exist = "native_dynamo" in self.pre_wrappers
 
         # remove ddp wrapper when using zero2
         if ddp_wrapper_exist and (
-            fairscale_zero2_wrapper_exist or fsdp_wrapper_exist or ds_zero_wrapper_exist or ds_3d_parallel_wrapper_exist
+            fairscale_zero2_wrapper_exist
+            or fsdp_wrapper_exist
+            or ds_zero_wrapper_exist
+            or ds_3d_parallel_wrapper_exist
+            or fsdp2_wrapper_exist
         ):
             logger.info("Found Zero, ds_3d_parallel, or pipe wrapper, remove ddp wrapper.")
             self.post_wrappers.pop("ddp")
@@ -587,21 +599,28 @@ def adjust_wrappers(self):
             logger.info("Found fsdp and amp_native wrapper, turn on mixed_precision in FSDP")
             _, amp_native_config = self.post_wrappers["amp_native"]
             fp16_dtype = amp_native_config.get("dtype", torch.float16)
-            mixed_precision_param = (
-                MixedPrecision(param_dtype=fp16_dtype, reduce_dtype=fp16_dtype, buffer_dtype=fp16_dtype)
-                if MixedPrecision
-                else True
-            )
+            mixed_precision_param = {"param_dtype": fp16_dtype, "reduce_dtype": fp16_dtype, "buffer_dtype": fp16_dtype}
             config = self.pre_wrappers["fsdp"][1] or {}
             config["mixed_precision"] = mixed_precision_param
             self.pre_wrappers["fsdp"] = (
                 self.pre_wrappers["fsdp"][0],
                 config,
             )
+        elif fsdp2_wrapper_exist and "amp_native" in self.post_wrappers:
+            logger.info("Found fsdp2 and amp_native wrapper, turn on mixed_precision in FSDP")
+            _, amp_native_config = self.post_wrappers["amp_native"]
+            fp16_dtype = amp_native_config.get("dtype", torch.float16)
+            mixed_precision_param = {"param_dtype": fp16_dtype, "reduce_dtype": fp16_dtype, "buffer_dtype": fp16_dtype}
+            config = self.pre_wrappers["fsdp2"][1] or {}
+            config["mixed_precision"] = mixed_precision_param
+            self.pre_wrappers["fsdp2"] = (
+                self.pre_wrappers["fsdp2"][0],
+                config,
+            )
 
         # move dynamo_native wrapper behind ddp or fsdp (fsdp will adjusted later)
         # Note that dynamo_native wrapper and fsdp wrapper are pre-wrappers while ddp wrapper is a post-wrapper.
-        if native_dynamo_wrapper_exist and ddp_wrapper_exist and not fsdp_wrapper_exist:
+        if native_dynamo_wrapper_exist and ddp_wrapper_exist and not fsdp_wrapper_exist and not fsdp2_wrapper_exist:
             # ddp wrapper is a post-wrapper. Popping dynamo_native wrapper from pre-wrappers
             # then insert it after ddp wrapper.
             post_wrappers_list = []
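Instead of constructing an FSDP1 MixedPrecision object at this point, both branches now store a plain dtype dict and let the fsdp/fsdp2 wrappers build their own policy objects, likely because the FSDP2 (fully_shard) MixedPrecisionPolicy has no buffer_dtype field. A rough sketch of how such a dict maps onto the two torch policy types (the actual atorch wrapper code may differ):

import torch
from torch.distributed.fsdp import MixedPrecision  # FSDP1 policy object

mp = {"param_dtype": torch.bfloat16, "reduce_dtype": torch.bfloat16, "buffer_dtype": torch.bfloat16}

fsdp1_policy = MixedPrecision(**mp)  # FSDP1 accepts all three dtypes

# FSDP2's policy type (PyTorch >= 2.4); it has no buffer_dtype, so that key is dropped.
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

fsdp2_policy = MixedPrecisionPolicy(param_dtype=mp["param_dtype"], reduce_dtype=mp["reduce_dtype"])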
@@ -616,8 +635,13 @@ def adjust_wrappers(self):
 
         if tensor_parallel_wrapper_exist:
             wrap_cls = None
+            fsdp_wrapper = None
             if fsdp_wrapper_exist and torch_version() >= (1, 12, 0):
                 fsdp_wrapper = self.pre_wrappers["fsdp"]
+            elif fsdp2_wrapper_exist and torch_version() >= (1, 12, 0):
+                fsdp_wrapper = self.pre_wrappers["fsdp2"]
+
+            if fsdp_wrapper is not None:
                 fsdp_wrapper = list(fsdp_wrapper)
                 if fsdp_wrapper[1] is None:
                     fsdp_wrapper[1] = dict()
@@ -644,15 +668,19 @@ def adjust_wrappers(self):
             leaf_modules = _propose_leaf_modules(wrap_cls)
             auto_wrap_cls = _propose_wrap_cls(leaf_modules)
 
-            if fsdp_wrapper_exist and torch_version() >= (1, 12, 0):
+            if (fsdp_wrapper_exist or fsdp2_wrapper_exist) and torch_version() >= (1, 12, 0):
                 if "atorch_wrap_cls" in fsdp_config:
                     if auto_wrap_cls is not None:
                         fsdp_config["atorch_wrap_cls"] = auto_wrap_cls
                     else:
                         fsdp_config.pop("atorch_wrap_cls")
 
                 fsdp_wrapper[1] = fsdp_config
-                self.pre_wrappers["fsdp"] = tuple(fsdp_wrapper)
+
+                if fsdp_wrapper_exist:
+                    self.pre_wrappers["fsdp"] = tuple(fsdp_wrapper)
+                elif fsdp2_wrapper_exist:
+                    self.pre_wrappers["fsdp2"] = tuple(fsdp_wrapper)
 
             if ckpt_wrapper_exist:
                 if auto_wrap_cls is not None:
@@ -671,7 +699,7 @@ def adjust_wrappers(self):
                 tensor_parallel_wrapper_item = list(tensor_parallel_wrapper_item)
                 tensor_parallel_wrapper_item[1] = list(tensor_parallel_wrapper_item[1])
                 tensor_parallel_wrapper_item[1][1]["leaf_modules"] = leaf_modules
-                if fsdp_wrapper_exist or pipe_wrapper_exist:
+                if fsdp_wrapper_exist or fsdp2_wrapper_exist or pipe_wrapper_exist:
                     tensor_parallel_wrapper_item[1][1]["defer_init"] = True
                 tensor_parallel_wrapper_item[1] = tuple(tensor_parallel_wrapper_item[1])
                 tensor_parallel_wrapper_item = tuple(tensor_parallel_wrapper_item)
@@ -687,7 +715,7 @@ def adjust_wrappers(self):
                 _insert_amp_config_for_tp_ckpt(amp_config)
 
         # adjust pre_wrapper order
-        order_wrapper_name = ["half", "module_replace", "sequence_parallel", "fp8", "fsdp", "native_dynamo"]
+        order_wrapper_name = ["half", "module_replace", "sequence_parallel", "fp8", "fsdp", "fsdp2", "native_dynamo"]
         match_names = []
         for name in self.pre_wrappers:
             if name in order_wrapper_name:
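The last hunk places "fsdp2" between "fsdp" and "native_dynamo" in the pre-wrapper ordering list. As an illustration only (not the atorch implementation) of how such an ordering list can be applied to the configured pre-wrapper names:

order_wrapper_name = ["half", "module_replace", "sequence_parallel", "fp8", "fsdp", "fsdp2", "native_dynamo"]
pre_wrappers = {"native_dynamo": None, "fsdp2": None, "module_replace": None}  # toy insertion order

# Re-order the wrapper names that appear in order_wrapper_name according to that list.
ordered = sorted(pre_wrappers, key=order_wrapper_name.index)
print(ordered)  # ['module_replace', 'fsdp2', 'native_dynamo']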