Skip to content

Commit 35663af

Browse files
authored
compat qwen3_vl zero3 (#6080)
1 parent e450967 commit 35663af

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

examples/models/qwen3_vl/zero2.sh renamed to examples/models/qwen3_vl/zero3.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# zero2: 70GiB
1+
# 2 * 42GiB
22
IMAGE_MAX_TOKEN_NUM=1024 \
33
NPROC_PER_NODE=2 \
44
CUDA_VISIBLE_DEVICES=0,1 \
@@ -30,7 +30,7 @@ swift sft \
3030
--max_length 2048 \
3131
--output_dir output \
3232
--warmup_ratio 0.05 \
33-
--deepspeed zero2 \
33+
--deepspeed zero3 \
3434
--use_liger_kernel true \
3535
--dataset_num_proc 4 \
3636
--dataloader_num_workers 4

swift/llm/model/register.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,45 @@ def _new_process_model_before_weight_loading(self, model, *args, **kwargs):
208208
pass
209209

210210

211+
# architecture name -> (module path, MoE block class name).
# Kept as data so supporting a new MoE architecture is a one-line addition;
# the actual import stays lazy so older transformers versions without a
# given model do not fail at file-import time.
_Z3_LEAF_MODULE_MAP = {
    'Qwen3VLMoeForConditionalGeneration':
        ('transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe', 'Qwen3VLMoeTextSparseMoeBlock'),
    'Qwen3OmniMoeForConditionalGeneration':
        ('transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe', 'Qwen3OmniMoeThinkerTextSparseMoeBlock'),
    'Qwen2MoeForCausalLM':
        ('transformers.models.qwen2_moe.modeling_qwen2_moe', 'Qwen2MoeSparseMoeBlock'),
    'Qwen3MoeForCausalLM':
        ('transformers.models.qwen3_moe.modeling_qwen3_moe', 'Qwen3MoeSparseMoeBlock'),
    'Glm4MoeForCausalLM':
        ('transformers.models.glm4_moe.modeling_glm4_moe', 'Glm4MoeMoE'),
    'Glm4vMoeForConditionalGeneration':
        ('transformers.models.glm4v_moe.modeling_glm4v_moe', 'Glm4vMoeTextMoE'),
    'GptOssForCausalLM':
        ('transformers.models.gpt_oss.modeling_gpt_oss', 'GptOssMLP'),
    'Llama4ForCausalLM':
        ('transformers.models.llama4.modeling_llama4', 'Llama4TextMoe'),
}


def deepspeed_set_z3_leaf_modules(model) -> None:
    """Register known sparse-MoE blocks as DeepSpeed ZeRO-3 leaf modules.

    Under ZeRO-3, parameters are partitioned and gathered module-by-module.
    Marking the MoE block as a "leaf" makes DeepSpeed treat it as a single
    gather unit, which is the recommended way to run hook-based gathering
    with dynamically-routed expert layers.

    Args:
        model: A loaded HF model; its ``config.architectures[0]`` is used to
            look up the matching MoE block class.

    Returns:
        None. Silently no-ops when ZeRO-3 is not enabled, when the
        architecture cannot be read from the config, or when the
        architecture is not a known MoE model.
    """
    if not is_deepspeed_zero3_enabled():
        return
    try:
        architecture = model.config.architectures[0]
    except Exception:
        # Best-effort: some configs lack `architectures`; nothing to do then.
        return
    entry = _Z3_LEAF_MODULE_MAP.get(architecture)
    if entry is None:
        return
    import importlib
    module_path, class_name = entry
    z3_leaf_modules = [getattr(importlib.import_module(module_path), class_name)]
    from deepspeed.utils import set_z3_leaf_modules
    set_z3_leaf_modules(model, z3_leaf_modules)
    logger.info(f'Setting z3_leaf_modules: {z3_leaf_modules}')
248+
249+
211250
def get_model_tokenizer_from_local(model_dir: str,
212251
model_info: ModelInfo,
213252
model_kwargs: Dict[str, Any],
@@ -329,6 +368,8 @@ def get_model_tokenizer_from_local(model_dir: str,
329368
if model is not None:
330369
# fix seq classification task
331370
HfConfigFactory.set_model_config_attr(model, 'pad_token_id', pad_token)
371+
# deepspeed zero3
372+
deepspeed_set_z3_leaf_modules(model)
332373

333374
return model, tokenizer
334375

0 commit comments

Comments
 (0)