@@ -90,7 +90,12 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
9090 linear_names = ["w1" , "w2" , "w3" ]
9191 elif any (
9292 qwen_variant in model_type
93- for qwen_variant in ["qwenmoeforcausallm" , "qwen2moeforcausallm" , "qwen3moeforcausallm" ]
93+ for qwen_variant in [
94+ "qwenmoeforcausallm" ,
95+ "qwen2moeforcausallm" ,
96+ "qwen3moeforcausallm" ,
97+ "qwen3nextforcausallm" ,
98+ ]
9499 ):
95100 linear_names = ["gate_proj" , "down_proj" , "up_proj" ]
96101 else :
@@ -333,6 +338,7 @@ def is_moe(module: nn.Module) -> bool:
333338 "DeepseekMoE" .lower (),
334339 "Qwen2MoeSparseMoeBlock" .lower (),
335340 "Qwen3MoeSparseMoeBlock" .lower (),
341+ "Qwen3NextSparseMoeBlock" .lower (),
336342 ]
337343 )
338344
@@ -987,7 +993,13 @@ def module_match_name_list(module, name_list):
987993 return any (name .lower () in type (module ).__name__ .lower () for name in name_list )
988994
989995 if module_match_name_list (
990- module , ["Qwen2MoeSparseMoeBlock" , "Qwen3MoeSparseMoeBlock" , "DeepseekMoE" ]
996+ module ,
997+ [
998+ "Qwen2MoeSparseMoeBlock" ,
999+ "Qwen3MoeSparseMoeBlock" ,
1000+ "Qwen3NextSparseMoeBlock" ,
1001+ "DeepseekMoE" ,
1002+ ],
9911003 ):
9921004 return ["gate_proj" , "down_proj" , "up_proj" ]
9931005 elif module_match_name_list (module , ["MixtralMoeSparseMoeBlock" ]):
0 commit comments