
Commit cdc8996

Qwen3 TRT-LLM/HF export (#188)
Co-authored-by: Michael Feil <[email protected]>
1 parent 7af33d2 commit cdc8996

2 files changed (+113 -14 lines changed)


modelopt/torch/export/layer_utils.py

Lines changed: 66 additions & 13 deletions
@@ -324,6 +324,7 @@ def is_moe(module: nn.Module) -> bool:
         "PhimoeSparseMoeBlock".lower(),
         "DeepseekMoE".lower(),
         "Qwen2MoeSparseMoeBlock".lower(),
+        "Qwen3MoeSparseMoeBlock".lower(),
     ]
@@ -969,26 +970,75 @@ def get_stacked_scaling_factors(experts, get_function, module_name):
     return config
 
 
-@contextmanager
-def set_zero_amax_for_uncalibrated_experts(experts: nn.Module):
-    """For experts that does not have valid amax value of input quantizer, we set them to 0."""
+def get_expert_linear_names(module: nn.Module) -> list[str]:
+    """Get the list of linear names for the experts."""
+    if type(module).__name__.lower() in [
+        "Qwen2MoeSparseMoeBlock".lower(),
+        "Qwen3MoeSparseMoeBlock".lower(),
+        "DeepseekMoE".lower(),
+    ]:
+        return ["gate_proj", "down_proj", "up_proj"]
+    elif type(module).__name__.lower() in "MixtralMoeSparseMoeBlock".lower():
+        return ["linear_fc1", "linear_fc2"]
+    elif type(module).__name__.lower() in "DBRXMoeSparseMoeBlock".lower():
+        return ["w1_linear", "w2_linear", "v1_linear"]
+    else:
+        # assuing w1, w2, w3 by default
+        return ["w1", "w2", "w3"]
+
+
+def set_amax_for_uncalibrated_experts(experts: nn.Module, set_amax_value: float | None = None):
+    """Set amax of uncalibrated experts to a given value or the max of existing amax value from other experts.
+
+    Args:
+        experts: a list of experts
+        set_amax_value: set amax value to the given value.
+            If None, set amax value to the max of existing amax value from other experts.
+
+    Returns:
+        uncalibrated_experts: a list of uncalibrated experts
+    """
     uncalibrated_experts = []
+    # get the max amax value from all experts
+    if set_amax_value is None:
+        amax_values = [
+            module.input_quantizer.amax
+            for module in experts
+            if (
+                hasattr(module, "input_quantizer")
+                and module.input_quantizer is not None
+                and module.input_quantizer.is_enabled
+            )
+            and module.input_quantizer.amax is not None
+        ]
+        if len(amax_values) == 0:
+            return uncalibrated_experts
+        set_amax_value = torch.max(torch.stack(amax_values))
+
     for module in experts:
         if (
             hasattr(module, "input_quantizer")
             and module.input_quantizer is not None
             and module.input_quantizer.is_enabled
         ) and module.input_quantizer.amax is None:
             warn(
-                f"Missing amax value for {module} input_quantizer. Setting it to 0 for checkpoint export. "
+                f"Missing amax value for {module} input_quantizer. Setting it to {set_amax_value} for export. "
                 f"This typically occurs in MoE models when certain experts are not activated during calibration. "
                 f"Consider increasing your calibration dataset size to ensure all experts are exercised."
             )
             # Use float32 dtype explicitly to ensure we create a floating point tensor
             module.input_quantizer.amax = torch.tensor(
-                0.0, dtype=torch.float32, device=module.weight_quantizer.amax.device
+                set_amax_value, dtype=torch.float32, device=module.weight_quantizer.amax.device
             )
             uncalibrated_experts.append(module)
+    return uncalibrated_experts
+
+
+@contextmanager
+def set_amax_for_uncalibrated_experts_context(
+    experts: nn.Module, set_amax_value: float | None = None
+):
+    """Set amax for uncalibrated experts in a context manager."""
+    uncalibrated_experts = set_amax_for_uncalibrated_experts(experts, set_amax_value)
     yield
     if uncalibrated_experts:
         for module in uncalibrated_experts:
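
A usage sketch (not part of the commit) of how the new set_amax_for_uncalibrated_experts helper behaves on a flat list of expert projections. The import path is taken from the file path in this diff; the _FakeQuantizer and _FakeExpertLinear stand-ins are hypothetical and only expose the attributes the helper reads (input_quantizer.amax, input_quantizer.is_enabled, weight_quantizer.amax.device). In the real export path these are modelopt quantized linear modules.

# Minimal sketch, assuming modelopt is installed; stand-in classes are hypothetical.
import torch
import torch.nn as nn

from modelopt.torch.export.layer_utils import set_amax_for_uncalibrated_experts


class _FakeQuantizer:
    # Just enough surface for the helper: amax + is_enabled.
    def __init__(self, amax):
        self.amax = None if amax is None else torch.tensor(amax)
        self.is_enabled = True


class _FakeExpertLinear(nn.Module):
    # Expert projection with input/weight quantizers attached, as the helper expects.
    def __init__(self, input_amax):
        super().__init__()
        self.input_quantizer = _FakeQuantizer(input_amax)
        self.weight_quantizer = _FakeQuantizer(1.0)  # device of the filled-in amax is read from here


# Experts 0 and 1 were calibrated; expert 2 never received tokens, so its amax is None.
experts = [_FakeExpertLinear(0.5), _FakeExpertLinear(2.0), _FakeExpertLinear(None)]

# With set_amax_value=None the helper warns about the uncalibrated expert and fills
# the gap with the max of the existing amax values (here 2.0).
set_amax_for_uncalibrated_experts(experts)
print([e.input_quantizer.amax for e in experts])  # last entry is now tensor(2.)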
@@ -1022,12 +1072,13 @@ def build_stacked_experts(
     )
 
     # Set amax to 0 for uncalibrated experts
-    with set_zero_amax_for_uncalibrated_experts(
+    with set_amax_for_uncalibrated_experts_context(
         [
             expert_getter(experts, i, module_name)
             for module_name in linear_names
             for i in range(num_experts)
-        ]
+        ],
+        0,  # set amax to 0 for uncalibrated experts as we will calculate max across all experts later
     ):
         # Pre-fuse W1 and W3
         if len(linear_names) == 3:
@@ -1121,12 +1172,14 @@ def build_moe_config(module: nn.Module, decoder_type) -> MOEConfig:
         )
     elif decoder_type == "qwen":
         config.router = build_linear_config(module.gate, LINEAR_ROW)
-        preprocess_linear_fusion([module.shared_expert.gate_proj, module.shared_expert.up_proj])
-        config.shared_expert = build_mlp_config(
-            module.shared_expert, decoder_type, merge_gate_fc=True
-        )
-        config.shared_expert_gate = build_linear_config(module.shared_expert_gate, LINEAR_ROW)
-        config.shared_expert_gate.tp = False
+        # Qwen3 doesn't have shared expert
+        if hasattr(module, "shared_expert"):
+            preprocess_linear_fusion([module.shared_expert.gate_proj, module.shared_expert.up_proj])
+            config.shared_expert = build_mlp_config(
+                module.shared_expert, decoder_type, merge_gate_fc=True
+            )
+            config.shared_expert_gate = build_linear_config(module.shared_expert_gate, LINEAR_ROW)
+            config.shared_expert_gate.tp = False
     else:
         raise NotImplementedError(f"{decoder_type} not supported")
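For context (not from the commit): the hasattr() guard matters because Qwen2-style MoE blocks carry a shared_expert and shared_expert_gate while Qwen3-style blocks do not. The toy modules below are illustrative stand-ins, not the Hugging Face classes; the attribute layout is an assumption based on the comment in the diff.

# Illustrative stand-ins only.
from torch import nn


class Qwen2StyleMoeBlock(nn.Module):
    def __init__(self, hidden=8, num_experts=4):
        super().__init__()
        self.gate = nn.Linear(hidden, num_experts, bias=False)  # router
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))
        self.shared_expert = nn.Linear(hidden, hidden)           # always-on expert
        self.shared_expert_gate = nn.Linear(hidden, 1, bias=False)


class Qwen3StyleMoeBlock(nn.Module):
    def __init__(self, hidden=8, num_experts=4):
        super().__init__()
        self.gate = nn.Linear(hidden, num_experts, bias=False)  # router only
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))


for block in (Qwen2StyleMoeBlock(), Qwen3StyleMoeBlock()):
    # build_moe_config() now only builds the shared-expert configs when the
    # attribute exists, so Qwen3-style blocks no longer hit an AttributeError here.
    print(type(block).__name__, "has shared expert:", hasattr(block, "shared_expert"))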
modelopt/torch/export/unified_export_hf.py

Lines changed: 47 additions & 1 deletion
@@ -15,6 +15,7 @@
 
 """Code that export quantized Hugging Face models for deployment."""
 
+import collections.abc
 import json
 import tempfile
 import warnings
@@ -29,7 +30,14 @@
 from modelopt.torch.quantization.nn import SequentialQuantizer
 
 from .convert_hf_config import convert_hf_quant_config_format
-from .layer_utils import get_experts_list, is_layernorm, is_moe, is_quantlinear
+from .layer_utils import (
+    get_expert_linear_names,
+    get_experts_list,
+    is_layernorm,
+    is_moe,
+    is_quantlinear,
+    set_amax_for_uncalibrated_experts,
+)
 from .model_config import (
     KV_CACHE_FP8,
     KV_CACHE_NVFP4,
@@ -184,6 +192,44 @@ def _export_hf_checkpoint(
     root = getattr(root, "layers", root)
     layer_pool = {f"model.layers.{name}": sub_module for name, sub_module in root.named_modules()}
 
+    # Handle input quantizers of experts that are not calibrated
+    for name, sub_module in model.named_modules():
+        if is_moe(sub_module) and hasattr(sub_module, "experts"):
+            expert_linear_names = get_expert_linear_names(sub_module)
+            for linear_name in expert_linear_names:
+                # Handle DBRX experts specifically
+                if "QuantDbrxExperts" in type(sub_module.experts).__name__:
+                    # For DBRX, experts are in sub_module.experts.mlp and linear layers are ModuleLists
+                    experts_mlp = sub_module.experts.mlp
+                    if hasattr(experts_mlp, linear_name):
+                        linear_modulelist = getattr(experts_mlp, linear_name)
+                        if hasattr(linear_modulelist, "__iter__"):
+                            set_amax_for_uncalibrated_experts(list(linear_modulelist))
+                elif isinstance(sub_module.experts, collections.abc.Iterable):
+                    # For other MoE models (like Mixtral) with iterable experts
+                    try:
+                        set_amax_for_uncalibrated_experts(
+                            [getattr(expert, linear_name) for expert in sub_module.experts]
+                        )
+                    except AttributeError as e:
+                        # Provide more helpful debugging information
+                        expert_types = [type(expert).__name__ for expert in sub_module.experts]
+                        raise AttributeError(
+                            f"Failed to access attribute '{linear_name}' on experts. "
+                            f"MoE module type: {type(sub_module).__name__}, "
+                            f"Expert types: {expert_types}, "
+                            f"Expected linear names: {expert_linear_names}. "
+                            f"This suggests the get_expert_linear_names function may need "
+                            f"to be updated for this model architecture. "
+                            f"Original error: {e}"
+                        ) from e
+                else:
+                    # Unsupported MoE model structure
+                    raise NotImplementedError(
+                        f"MoE model with experts type '{type(sub_module.experts).__name__}' is not supported in export."
+                        f"Please file an issue or add support for this model architecture."
+                    )
+
     # NOTE: Speculative decoding models have extra modules that may be quantized
     # Need to add these modules to the layer_pool
     for key in SPECULATIVE_DECODING_MODULE_NAMES: