From 49c23e9aacf29178807e97e96ff8b61c94083d5f Mon Sep 17 00:00:00 2001
From: Kinjal Patel
Date: Tue, 16 Sep 2025 01:37:36 +0000
Subject: [PATCH] Added support for qwen3-next quantization and export

Signed-off-by: Kinjal Patel
---
 modelopt/torch/export/layer_utils.py             | 16 ++++++++++++++--
 .../torch/quantization/plugins/huggingface.py    | 10 ++++++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py
index c35491283..e35ee070f 100755
--- a/modelopt/torch/export/layer_utils.py
+++ b/modelopt/torch/export/layer_utils.py
@@ -90,7 +90,12 @@ def get_experts_list(module: torch.nn.Module, model_type: str):
         linear_names = ["w1", "w2", "w3"]
     elif any(
         qwen_variant in model_type
-        for qwen_variant in ["qwenmoeforcausallm", "qwen2moeforcausallm", "qwen3moeforcausallm"]
+        for qwen_variant in [
+            "qwenmoeforcausallm",
+            "qwen2moeforcausallm",
+            "qwen3moeforcausallm",
+            "qwen3nextforcausallm",
+        ]
     ):
         linear_names = ["gate_proj", "down_proj", "up_proj"]
     else:
@@ -333,6 +338,7 @@ def is_moe(module: nn.Module) -> bool:
             "DeepseekMoE".lower(),
             "Qwen2MoeSparseMoeBlock".lower(),
             "Qwen3MoeSparseMoeBlock".lower(),
+            "Qwen3NextSparseMoeBlock".lower(),
         ]
     )
 
@@ -987,7 +993,13 @@ def module_match_name_list(module, name_list):
         return any(name.lower() in type(module).__name__.lower() for name in name_list)
 
     if module_match_name_list(
-        module, ["Qwen2MoeSparseMoeBlock", "Qwen3MoeSparseMoeBlock", "DeepseekMoE"]
+        module,
+        [
+            "Qwen2MoeSparseMoeBlock",
+            "Qwen3MoeSparseMoeBlock",
+            "Qwen3NextSparseMoeBlock",
+            "DeepseekMoE",
+        ],
     ):
         return ["gate_proj", "down_proj", "up_proj"]
     elif module_match_name_list(module, ["MixtralMoeSparseMoeBlock"]):
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index da2a18c08..061e71dba 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -559,6 +559,16 @@ def top_k(self, value):
 except ImportError:
     pass
 
+try:
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+    if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
+        QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
+            _QuantMoeSparseMoe
+        )
+except ImportError:
+    pass
+
 
 class _QuantGptOssExperts(_QuantFunctionalMixin):
     """Quantized wrapper for `transformers.GptOssExperts`.