Skip to content

Commit d406aa1

Browse files
Added support for qwen3-next quantization and export (#323)
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 76e8ce2 commit d406aa1

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

modelopt/torch/export/layer_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
9090
linear_names = ["w1", "w2", "w3"]
9191
elif any(
9292
qwen_variant in model_type
93-
for qwen_variant in ["qwenmoeforcausallm", "qwen2moeforcausallm", "qwen3moeforcausallm"]
93+
for qwen_variant in [
94+
"qwenmoeforcausallm",
95+
"qwen2moeforcausallm",
96+
"qwen3moeforcausallm",
97+
"qwen3nextforcausallm",
98+
]
9499
):
95100
linear_names = ["gate_proj", "down_proj", "up_proj"]
96101
else:
@@ -333,6 +338,7 @@ def is_moe(module: nn.Module) -> bool:
333338
"DeepseekMoE".lower(),
334339
"Qwen2MoeSparseMoeBlock".lower(),
335340
"Qwen3MoeSparseMoeBlock".lower(),
341+
"Qwen3NextSparseMoeBlock".lower(),
336342
]
337343
)
338344

@@ -987,7 +993,13 @@ def module_match_name_list(module, name_list):
987993
return any(name.lower() in type(module).__name__.lower() for name in name_list)
988994

989995
if module_match_name_list(
990-
module, ["Qwen2MoeSparseMoeBlock", "Qwen3MoeSparseMoeBlock", "DeepseekMoE"]
996+
module,
997+
[
998+
"Qwen2MoeSparseMoeBlock",
999+
"Qwen3MoeSparseMoeBlock",
1000+
"Qwen3NextSparseMoeBlock",
1001+
"DeepseekMoE",
1002+
],
9911003
):
9921004
return ["gate_proj", "down_proj", "up_proj"]
9931005
elif module_match_name_list(module, ["MixtralMoeSparseMoeBlock"]):

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,16 @@ def top_k(self, value):
559559
except ImportError:
560560
pass
561561

562+
try:
563+
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
564+
565+
if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
566+
QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
567+
_QuantMoeSparseMoe
568+
)
569+
except ImportError:
570+
pass
571+
562572

563573
class _QuantGptOssExperts(_QuantFunctionalMixin):
564574
"""Quantized wrapper for `transformers.GptOssExperts`.

0 commit comments

Comments
 (0)