@@ -208,6 +208,54 @@ def _new_process_model_before_weight_loading(self, model, *args, **kwargs):
     pass


+def deepspeed_set_z3_leaf_modules(model):
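+    """Register sparse-MoE blocks as ZeRO-3 leaf modules.
+
+    ZeRO-3 gathers partitioned parameters through per-submodule hooks; in a
+    sparse-MoE block only the routed experts run on a given step, so hooks
+    of idle experts never fire and the collective gathers can hang. Marking
+    the whole block as a leaf makes DeepSpeed fetch its parameters as one unit.
+    """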
+    if not is_deepspeed_zero3_enabled():
+        return
+    try:
+        architecture = model.config.architectures[0]
+    except Exception:
+        return
+    z3_leaf_modules = None
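+    # Map each supported MoE architecture to the module class wrapping its experts.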
+    if architecture == 'Qwen3VLMoeForConditionalGeneration':
+        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock
+        z3_leaf_modules = [Qwen3VLMoeTextSparseMoeBlock]
+    elif architecture == 'Qwen3OmniMoeForConditionalGeneration':
+        from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
+        z3_leaf_modules = [Qwen3OmniMoeThinkerTextSparseMoeBlock]
+    elif architecture == 'Qwen2MoeForCausalLM':
+        from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
+        z3_leaf_modules = [Qwen2MoeSparseMoeBlock]
+    elif architecture == 'Qwen3MoeForCausalLM':
+        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
+        z3_leaf_modules = [Qwen3MoeSparseMoeBlock]
+    elif architecture == 'Glm4MoeForCausalLM':
+        from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE
+        z3_leaf_modules = [Glm4MoeMoE]
+    elif architecture == 'Glm4vMoeForConditionalGeneration':
+        from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextMoE
+        z3_leaf_modules = [Glm4vMoeTextMoE]
+    elif architecture == 'GptOssForCausalLM':
+        from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
+        z3_leaf_modules = [GptOssMLP]
+    elif architecture == 'Llama4ForCausalLM':
+        from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
+        z3_leaf_modules = [Llama4TextMoe]
+
+    if z3_leaf_modules:
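+        # Import lazily so deepspeed remains an optional dependency.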
+        from deepspeed.utils import set_z3_leaf_modules
+        set_z3_leaf_modules(model, z3_leaf_modules)
+        logger.info(f'Setting z3_leaf_modules: {z3_leaf_modules}')
+
+
 def get_model_tokenizer_from_local(model_dir: str,
                                    model_info: ModelInfo,
                                    model_kwargs: Dict[str, Any],
@@ -329,6 +377,8 @@ def get_model_tokenizer_from_local(model_dir: str,
     if model is not None:
         # fix seq classification task
         HfConfigFactory.set_model_config_attr(model, 'pad_token_id', pad_token)
+        # DeepSpeed ZeRO-3: register MoE blocks as leaf modules
+        deepspeed_set_z3_leaf_modules(model)

     return model, tokenizer
