3333 FSDP2_SUPPORTED = False
3434
3535try :
36+ import torch .distributed .checkpoint as dcp
3637 from torch .distributed .checkpoint .state_dict import (
3738 StateDictOptions ,
39+ get_model_state_dict ,
3840 set_model_state_dict ,
3941 )
4042
@@ -163,8 +165,29 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
163165 )
164166 fsdp_mode = gpc .config .parallel .fsdp .get ("mode" , "v1" )
165167 fsdp_init_method = gpc .config .parallel .fsdp .get ("init_method" , "cuda" )
168+ if gpc .is_using_parallel_mode (ParallelMode .EXPERT ):
169+ assert gpc .get_world_size (ParallelMode .EXPERT_DATA ) * gpc .get_world_size (ParallelMode .EXPERT ) == gpc .get_world_size (ParallelMode .GLOBAL )
166170
167171 if fsdp_mode == "v1" :
172+ ignored_mod = []
173+ if gpc .is_using_parallel_mode (ParallelMode .EXPERT ):
174+ for layer_id , layer in enumerate (model .model .layers ):
175+ if layer_id >= gpc .config .model .first_k_dense_replace :
176+ # Should follow this modeling pattern if EP is enabled.
177+ # Change the expert module name if needed.
178+ # TODO: Make this part hard-coded or config-driven?
179+ layer .feed_forward .moe_layer .experts = FSDP (
180+ layer .feed_forward .moe_layer .experts ,
181+ process_group = gpc .get_group (ParallelMode .EXPERT_DATA ),
182+ sharding_strategy = ShardingStrategy .FULL_SHARD ,
183+ sync_module_states = fsdp_init_method != "cuda" , # sync model parameters
184+ forward_prefetch = True ,
185+ backward_prefetch = BackwardPrefetch .BACKWARD_PRE ,
186+ limit_all_gathers = True ,
187+ use_orig_params = True ,
188+ device_id = None if fsdp_init_method == "cuda" else get_current_device (), # needed for sync_module_states
189+ )
190+ ignored_mod .append (layer .feed_forward .moe_layer .experts )
168191 model = FSDP (
169192 module = model ,
170193 process_group = gpc .get_group (ParallelMode .GLOBAL ),
@@ -176,6 +199,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
176199 limit_all_gathers = True ,
177200 use_orig_params = True ,
178201 device_id = None if fsdp_init_method == "cuda" else get_current_device (), # needed for sync_module_states
202+ ignored_modules = ignored_mod ,
179203 )
180204 # For FSDP v1, to get ckpt resuming work normally, we do dummy forward.
181205 # This hack is needed due to FSDP v1 lazy initialization in model construction.
@@ -196,7 +220,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
196220 else :
197221 raise ValueError (f"Unsupported FSDP mode: { fsdp_mode } " )
198222
199- if is_using_hf () and not gpc .config .ckpt .get ("auto_resume" , False ):
223+ if not gpc .config .ckpt .get ("auto_resume" , False ):
200224 load_ckpt_info = gpc .config .ckpt .load_ckpt_info
201225 load_ckpt_path = load_ckpt_info .get ("path" , None )
202226 load_ckpt_content = load_ckpt_info .get ("content" , [])
@@ -205,16 +229,22 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
205229 "model" ,
206230 ), "If auto_resume=False and checkpoint path is given, only model can be loaded"
207231 if DCP_SUPPORTED :
208- hf = gpc .config .hf
209- mod = LazyObject (hf .mod , hf .mod_cls )
210- mod = mod .build ()
211- state_dict = mod .from_pretrained (
212- pretrained_model_name_or_path = load_ckpt_path , use_safetensors = True
213- ).state_dict ()
214- state_dict = {f"model.{ key } " : state_dict [key ].clone ().detach () for key in state_dict }
215- set_model_state_dict (
216- model = model , model_state_dict = state_dict , options = StateDictOptions (full_state_dict = True )
217- )
232+ if is_using_hf ():
233+ hf = gpc .config .hf
234+ mod = LazyObject (hf .mod , hf .mod_cls )
235+ mod = mod .build ()
236+ state_dict = mod .from_pretrained (
237+ pretrained_model_name_or_path = load_ckpt_path , use_safetensors = True
238+ ).state_dict ()
239+ state_dict = {f"model.{ key } " : state_dict [key ].clone ().detach () for key in state_dict }
240+ set_model_state_dict (
241+ model = model , model_state_dict = state_dict , options = StateDictOptions (full_state_dict = True )
242+ )
243+ else :
244+ state_dict = get_model_state_dict (model = model )
245+ state_dict = {key : state_dict [key ].clone ().detach () for key in state_dict }
246+ dcp .load (state_dict = state_dict , checkpoint_id = load_ckpt_path )
247+ set_model_state_dict (model = model , model_state_dict = state_dict )
218248 del state_dict
219249 internlm_accelerator .empty_cache ()
220250 else :
0 commit comments