@@ -104,6 +104,7 @@ def prepare_scattermoe(
     rank: int = None,
     world_size: int = None,
     ep_degree: int = 1,
+    disable_distributed: bool = False,
     key_rep: str = KEY_REPLICATE,
     key_ep: str = KEY_EXPERT_PARALLEL,
     device_type: str = "cuda",
@@ -116,13 +117,10 @@ def prepare_scattermoe(
     # pylint: disable=import-outside-toplevel
     from .scattermoe import ScatterMoE
 
-    ep_disabled = False
-    if ep_degree == 0:
-        ep_disabled = True
-        # flow of code when EP not enabled is mostly same as
-        # with ep_degree set to 1. Therefore, we explicitly set
-        # ep_degree to 1 however handle it along with ep_disabled var
-        ep_degree = 1
+    if disable_distributed and ep_degree > 1:
+        raise ValueError(
+            "expert sharding can not be deferred to top level sharding protocol (e.g. FSDP) when ep_degree > 1"
+        )
 
     assert world_size % ep_degree == 0, (
         f"world size ({world_size}) " f"not divisible by ep_size ({ep_degree})."
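The new flag makes the two sharding paths mutually exclusive: `disable_distributed=True` hands the converted experts over to the top-level sharding protocol (e.g. FSDP), so it is only valid when expert parallelism is off (`ep_degree == 1`). Below is a minimal standalone sketch of that constraint; `check_moe_parallel_config` is a made-up helper for illustration, not part of the library.

```python
# Standalone sketch of the constraint introduced above (illustrative only).
def check_moe_parallel_config(
    world_size: int, ep_degree: int = 1, disable_distributed: bool = False
) -> None:
    if disable_distributed and ep_degree > 1:
        # EP shards the experts itself, so sharding cannot also be deferred
        # to a top-level protocol such as FSDP.
        raise ValueError(
            "expert sharding can not be deferred to top level sharding "
            "protocol (e.g. FSDP) when ep_degree > 1"
        )
    assert world_size % ep_degree == 0, (
        f"world size ({world_size}) not divisible by ep_size ({ep_degree})."
    )

check_moe_parallel_config(world_size=8, ep_degree=1, disable_distributed=True)   # ok: defer to FSDP
check_moe_parallel_config(world_size=8, ep_degree=4, disable_distributed=False)  # ok: plain EP
# check_moe_parallel_config(world_size=8, ep_degree=4, disable_distributed=True) # -> ValueError
```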
@@ -137,11 +135,7 @@ def prepare_scattermoe(
     # current rank of the device
     device = torch.device(f"{device_type}:{rank}")
 
-    if ep_disabled:
-        # Larger models result in OOM especially when loading
-        # all experts to the same GPU device (when EP disabled).
-        # For cases like FSDP + EP disabled, its memory efficient to
-        # load the model to CPU and hand it over to the FSDP.
+    if ep_degree == 1 and disable_distributed and is_fsdp_enabled() and rank == 0:
         device = torch.device("cpu")
 
     # get the scattermoe conversion spec
@@ -158,7 +152,7 @@ def prepare_scattermoe(
 
     rep_size = world_size // ep_degree
 
-    if ep_degree == 1 and (rep_size == 1 or ep_disabled):
+    if ep_degree == 1:
        # in this case no need for sharding
        device_mesh = None
    elif rep_size == 1:
@@ -281,10 +275,10 @@ def prepare_scattermoe(
             )
 
         if device_mesh is None:
-            if is_fsdp_enabled() and rank > 0:
-                _init_scattermoe_context = init_empty_weights
-            else:
+            if not is_fsdp_enabled() or is_local_dist_rank_0():
                 _init_scattermoe_context = nullcontext
+            else:
+                _init_scattermoe_context = init_empty_weights
         else:
             # in this case we need to distribute parameters, so just initialize
             # the scattermoe module swap with empty weights,
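With sharding deferred, only one process per node needs real tensors at module-swap time: local rank 0 builds the ScatterMoE swap normally (`nullcontext`), while every other rank builds it on the meta device via `init_empty_weights` so no memory is allocated. Here is a self-contained sketch of that pattern, assuming torchrun-style environment variables; the `is_fsdp_enabled()` / `is_local_dist_rank_0()` helpers in the diff are assumed to behave roughly like the checks below.

```python
# Self-contained sketch of the init-context pattern used above.
import os
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights  # builds modules on the meta device

def _is_local_rank_0() -> bool:
    return int(os.environ.get("LOCAL_RANK", "0")) == 0

def _fsdp_enabled() -> bool:
    # assumption: mirrors the env-var check behind is_fsdp_enabled()
    return os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"

# Only one process per node materializes real weights; the rest allocate
# nothing (meta tensors) and receive parameters later, e.g. through FSDP's
# sync_module_states=True broadcast from rank 0.
init_ctx = nullcontext if (not _fsdp_enabled() or _is_local_rank_0()) else init_empty_weights

with init_ctx():
    layer = torch.nn.Linear(1024, 1024)  # stand-in for the ScatterMoE swap

print(layer.weight.device)  # "meta" on non-zero local ranks under FSDP, real otherwise
```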
@@ -337,7 +331,7 @@ def prepare_scattermoe(
         if device_mesh is None:
             # - if not on meta, just load the state dict
             # - and then put on the device
-            if rank == 0 or not is_fsdp_enabled():
+            if not is_fsdp_enabled() or is_local_dist_rank_0():
                 moe.load_state_dict(sd)
                 moe = moe.to(device)
             else:
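On the deferred path, local rank 0 loads the real state dict (onto CPU on rank 0, per the earlier hunk) while the other ranks keep their meta-initialized copies. A common way those copies get materialized afterwards is FSDP with `sync_module_states=True`, which allocates empty storage on each rank and then broadcasts rank 0's weights; the sketch below shows that consuming side as the generic PyTorch pattern, not necessarily the exact wrapping used by this library's callers.

```python
# Sketch of the FSDP side that fills in meta-initialized parameters; this is
# the generic PyTorch pattern, not code from this repository.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_with_fsdp(module: torch.nn.Module) -> FSDP:
    local_rank = dist.get_rank() % torch.cuda.device_count()
    device = torch.device("cuda", local_rank)
    return FSDP(
        module,
        device_id=device,
        # broadcast rank 0's real (CPU-loaded) weights to all other ranks
        sync_module_states=True,
        # non-zero ranks hold meta tensors: allocate storage (without data)
        # so the broadcast above has somewhere to land
        param_init_fn=lambda m: m.to_empty(device=device),
    )
```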