@@ -90,7 +90,12 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
90
90
linear_names = ["w1" , "w2" , "w3" ]
91
91
elif any (
92
92
qwen_variant in model_type
93
- for qwen_variant in ["qwenmoeforcausallm" , "qwen2moeforcausallm" , "qwen3moeforcausallm" ]
93
+ for qwen_variant in [
94
+ "qwenmoeforcausallm" ,
95
+ "qwen2moeforcausallm" ,
96
+ "qwen3moeforcausallm" ,
97
+ "qwen3nextforcausallm" ,
98
+ ]
94
99
):
95
100
linear_names = ["gate_proj" , "down_proj" , "up_proj" ]
96
101
else :
@@ -333,6 +338,7 @@ def is_moe(module: nn.Module) -> bool:
333
338
"DeepseekMoE" .lower (),
334
339
"Qwen2MoeSparseMoeBlock" .lower (),
335
340
"Qwen3MoeSparseMoeBlock" .lower (),
341
+ "Qwen3NextSparseMoeBlock" .lower (),
336
342
]
337
343
)
338
344
@@ -987,7 +993,13 @@ def module_match_name_list(module, name_list):
987
993
return any (name .lower () in type (module ).__name__ .lower () for name in name_list )
988
994
989
995
if module_match_name_list (
990
- module , ["Qwen2MoeSparseMoeBlock" , "Qwen3MoeSparseMoeBlock" , "DeepseekMoE" ]
996
+ module ,
997
+ [
998
+ "Qwen2MoeSparseMoeBlock" ,
999
+ "Qwen3MoeSparseMoeBlock" ,
1000
+ "Qwen3NextSparseMoeBlock" ,
1001
+ "DeepseekMoE" ,
1002
+ ],
991
1003
):
992
1004
return ["gate_proj" , "down_proj" , "up_proj" ]
993
1005
elif module_match_name_list (module , ["MixtralMoeSparseMoeBlock" ]):
0 commit comments