
Commit 9409412

Fix AutoQuantize config for MoE (#641)
## What does this PR do?

**Type of change:** Bug fix

[aq_Qwen3-235B-A22B-Thinking-2507_scores.html](https://github.com/user-attachments/files/23915284/aq_Qwen3-235B-A22B-Thinking-2507_scores.html)

**Overview:** For Qwen3 MoE, sensitivity should be measured at `layer.x.mlp`, not `layer.x.mlp.experts`: the forward of the `layer.x.mlp.experts` module is never called, so sensitivity was not estimated correctly before. After this PR, Qwen3 MoE sensitivity is estimated correctly (see the attached scores file above).

## Usage

A sketch of a typical invocation is included below the commit metadata.

## Testing

`test_autoquantize_huggingface` is now parametrized over a tiny Qwen3 MoE model in addition to the tiny Llama (see the test diff below).

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---------

Signed-off-by: realAsma <[email protected]>
1 parent 9e280f4 commit 9409412
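For the Usage section above, here is a minimal sketch of how AutoQuantize might be driven after this change, roughly mirroring `examples/llm_ptq/hf_ptq.py`. The model, dataloader, constraint value, and format list are illustrative assumptions (not part of this PR), and the keyword names should be checked against the `mtq.auto_quantize` API of the installed ModelOpt version:

```python
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg

# Assumed setup: any HF Qwen3-MoE causal LM plus a small calibration dataloader.
model = ...             # e.g. AutoModelForCausalLM.from_pretrained(<qwen3-moe checkpoint>)
calib_dataloader = ...  # yields dicts of model inputs (including labels)


def forward_step(model, batch):
    return model(**batch)


model, search_state = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 4.8},                            # illustrative target
    quantization_formats=["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"],  # illustrative candidates
    data_loader=calib_dataloader,
    forward_step=forward_step,
    loss_func=lambda output, batch: output.loss,  # assumes batches carry labels
    num_score_steps=min(len(calib_dataloader), 128),
    verbose=True,
    # The fix: disable every pattern ModelOpt disables by default
    # (lm_head, mlp.gate, router, ...), not just "*lm_head*".
    disabled_layers=list(_default_disabled_quantizer_cfg.keys()),
)
```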

File tree

4 files changed: +33 −9 lines changed

examples/llm_ptq/hf_ptq.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -50,7 +50,7 @@
     get_model_type,
 )
 from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
-from modelopt.torch.quantization.config import need_calibration
+from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
 from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
 from modelopt.torch.quantization.utils import is_quantized
 from modelopt.torch.utils.dataset_utils import (
@@ -155,7 +155,8 @@ def forward_step(model, batch):
         # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration.
         num_score_steps=min(len(calib_dataloader), max(auto_quantize_score_size // batch_size, 1)),
         verbose=True,
-        disabled_layers=["*lm_head*"],
+        # Disable all default disabled layers such as lm_head, mlp.gate, router etc.
+        disabled_layers=list(_default_disabled_quantizer_cfg.keys()),
         method=auto_quantize_method,
         checkpoint=auto_quantize_checkpoint,
     )
```
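The broadened `disabled_layers` now comes straight from ModelOpt's default disabled-quantizer config. A quick way to inspect what that covers (the exact patterns depend on the installed ModelOpt version; the diff comment above cites lm_head, mlp.gate, and router):

```python
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg

# Wildcard patterns for modules that ModelOpt never quantizes by default;
# hf_ptq.py now forwards these keys verbatim as AutoQuantize's disabled_layers.
print(list(_default_disabled_quantizer_cfg.keys()))
```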

modelopt/torch/quantization/algorithms.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -755,7 +755,7 @@ class AutoQuantizeGradientSearcher(_AutoQuantizeBaseSearcher):
 
     score_module_rules = [
         # Use MLP layer output for gate_proj, up_proj, down_proj for Qwen3 like MoE models (local and shared experts)
-        r"^(.*?\.mlp\.experts)\.\d+\.(gate_proj|up_proj|down_proj)$",
+        r"^(.*?\.mlp)\.experts\.\d+\.(gate_proj|up_proj|down_proj)$",
         r"^(.*?)\.(\d+\.(w1|w2|w3))$",  # mixtral experts
         r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$",  # dbrx experts
     ]
```
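To make the regex change concrete, here is how the captured score module shifts for a Hugging Face Qwen3-MoE projection name. The example name is illustrative, following HF naming; per the comment on `score_module_rules`, group 1 names the module whose output is used for scoring:

```python
import re

name = "model.layers.0.mlp.experts.3.gate_proj"  # typical HF Qwen3-MoE expert projection

old_rule = r"^(.*?\.mlp\.experts)\.\d+\.(gate_proj|up_proj|down_proj)$"
new_rule = r"^(.*?\.mlp)\.experts\.\d+\.(gate_proj|up_proj|down_proj)$"

print(re.match(old_rule, name).group(1))  # model.layers.0.mlp.experts -> ModuleList, its forward never runs
print(re.match(new_rule, name).group(1))  # model.layers.0.mlp -> sparse MoE block, its forward does run
```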

tests/_test_utils/torch/transformers_models.py

Lines changed: 24 additions & 0 deletions

```diff
@@ -30,6 +30,8 @@
     LlamaForCausalLM,
     Qwen3Config,
     Qwen3ForCausalLM,
+    Qwen3MoeConfig,
+    Qwen3MoeForCausalLM,
     T5Config,
     T5ForConditionalGeneration,
     T5Tokenizer,
@@ -61,6 +63,28 @@ def get_tiny_qwen3(**config_kwargs) -> "Qwen3ForCausalLM":
     return tiny_qwen3
 
 
+def get_tiny_qwen3_moe(**config_kwargs) -> "Qwen3MoeForCausalLM":
+    set_seed(SEED)
+
+    kwargs = {
+        "hidden_size": 32,
+        "intermediate_size": 32,
+        "moe_intermediate_size": 32,
+        "num_hidden_layers": 2,
+        "num_attention_heads": 16,
+        "num_key_value_heads": 2,
+        "max_position_embeddings": 32,
+        "vocab_size": 32,
+        "num_experts": 4,
+        "num_experts_per_tok": 2,
+        "decoder_sparse_step": 1,
+    }
+    kwargs.update(**config_kwargs)
+    tiny_qwen3_moe = Qwen3MoeForCausalLM(Qwen3MoeConfig(**kwargs))
+
+    return tiny_qwen3_moe
+
+
 def get_tiny_llama(**config_kwargs) -> LlamaForCausalLM:
     set_seed(SEED)
     kwargs = {
```
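The new tiny fixture also makes the underlying bug easy to reproduce by hand: hook every `...mlp` and `...mlp.experts` module and see which forwards actually fire. This is a sketch, assuming a transformers version that ships `Qwen3MoeForCausalLM`; it mirrors the fixture's config rather than importing the test helper:

```python
import torch
from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM

# Same tiny configuration as get_tiny_qwen3_moe above.
model = Qwen3MoeForCausalLM(
    Qwen3MoeConfig(
        hidden_size=32, intermediate_size=32, moe_intermediate_size=32,
        num_hidden_layers=2, num_attention_heads=16, num_key_value_heads=2,
        max_position_embeddings=32, vocab_size=32, num_experts=4,
        num_experts_per_tok=2, decoder_sparse_step=1,
    )
)

called = set()
for name, module in model.named_modules():
    if name.endswith(".mlp") or name.endswith(".mlp.experts"):
        module.register_forward_hook(lambda mod, args, out, n=name: called.add(n))

model(torch.randint(0, 32, (1, 8)))
print(sorted(called))
# Only the "...mlp" names appear: "...mlp.experts" is an nn.ModuleList whose own
# forward is never invoked, which is why sensitivity has to be scored at "...mlp".
```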

tests/unit/torch/quantization/plugins/test_huggingface.py

Lines changed: 5 additions & 6 deletions

```diff
@@ -24,6 +24,7 @@
 from _test_utils.torch.transformers_models import (
     create_tiny_llama_dir,
     get_tiny_llama,
+    get_tiny_qwen3_moe,
     tf_modelopt_state_and_output_tester,
 )
 
@@ -137,12 +138,10 @@ def test_dbrx():
     assert torch.allclose(out_1[0], out_2[0])
 
 
-@pytest.mark.parametrize(
-    "method",
-    ["gradient", "kl_div"],
-)
-def test_autoquantize_huggingface(method):
-    model = get_tiny_llama()
+@pytest.mark.parametrize("method", ["gradient", "kl_div"])
+@pytest.mark.parametrize("model_provider", [get_tiny_llama, get_tiny_qwen3_moe])
+def test_autoquantize_huggingface(model_provider, method):
+    model = model_provider()
     input_ids = model.dummy_inputs["input_ids"]
 
     def forward_step(model, batch):
```
