Commit d59ca04

Support Qwen MOE with the TRT LLM C++ route (#155)
Co-authored-by: michaelfeil <[email protected]>
1 parent: 9cb36cb

6 files changed: +117 −74 lines

modelopt/torch/export/layer_utils.py

Lines changed: 15 additions & 9 deletions

```diff
@@ -330,6 +330,7 @@ def is_moe(module: nn.Module) -> bool:
         "MoELayer".lower(),
         "PhimoeSparseMoeBlock".lower(),
         "DeepseekMoE".lower(),
+        "Qwen2MoeSparseMoeBlock".lower(),
     ]


@@ -594,7 +595,7 @@ def build_fused_linear_config(modules: list[nn.Module], linear_type: str) -> Lin
     config = build_linear_config(modules[0], linear_type=linear_type)
     config.weight = torch.cat([module.weight for module in modules], dim=0)

-    if config.weights_scaling_factor.numel() != 1:
+    if config.weights_scaling_factor is not None and config.weights_scaling_factor.numel() != 1:
         config.weights_scaling_factor = torch.cat(
             [get_weight_scaling_factor(module) for module in modules], dim=0
         )
@@ -710,7 +711,7 @@ def build_mlp_config(
     """Builds the MLP config for the module."""
     assert is_mlp(module)

-    config = MLPConfig()
+    config = MLPConfig(merge_gate_fc=merge_gate_fc)

     def _split_gate_from_fc(decoder_type, module, fc_name, fc_layer):
         if (
@@ -833,11 +834,7 @@ def _split_gate_from_fc(decoder_type, module, fc_name, fc_layer):
             weight_quantizer.amax = torch.cat([amax_chunks[1], amax_chunks[0]], dim=0)

         split_gate = _split_gate_from_fc(decoder_type, module, name, fc_linear)
-        if merge_gate_fc:
-            config.fc = build_fused_linear_config([gate_linear, fc_linear], LINEAR_COLUMN)
-            gate_linear = None
-
-        elif split_gate:
+        if split_gate:
             # We have to split the gate from the fc
             weights = torch.chunk(fc_linear.weight, 2, dim=0)
             weight_scaling_factor = get_weight_scaling_factor(fc_linear)
@@ -1104,7 +1101,7 @@ def build_stacked_experts(
 def build_moe_config(module: nn.Module, decoder_type) -> MOEConfig:
     """Builds the MOE config for the module."""
     assert is_moe(module)
-    assert decoder_type in ["llama", "dbrx", "phi3", "deepseek"]
+    assert decoder_type in ["llama", "dbrx", "phi3", "deepseek", "qwen"]

     config = MOEConfig()

@@ -1128,10 +1125,19 @@ def build_moe_config(module: nn.Module, decoder_type) -> MOEConfig:
         config.shared_expert = build_mlp_config(
             module.shared_experts, decoder_type, merge_gate_fc=True
         )
+    elif decoder_type == "qwen":
+        config.router = build_linear_config(module.gate, LINEAR_ROW)
+        preprocess_linear_fusion([module.shared_expert.gate_proj, module.shared_expert.up_proj])
+        config.shared_expert = build_mlp_config(
+            module.shared_expert, decoder_type, merge_gate_fc=True
+        )
+        config.shared_expert_gate = build_linear_config(module.shared_expert_gate, LINEAR_ROW)
+        config.shared_expert_gate.tp = False
     else:
         raise NotImplementedError(f"{decoder_type} not supported")

     config.router.weight = config.router.weight.type(torch.float)
+    config.router.tp = False

     # Experts
     experts = ExpertConfig()
@@ -1168,7 +1174,7 @@ def build_moe_config(module: nn.Module, decoder_type) -> MOEConfig:
             len(module.experts.mlp.w1_linear),
             _get_dbrx_expert,
         )
-    elif decoder_type == "deepseek":
+    elif decoder_type in ["deepseek", "qwen"]:
         experts.fc, experts.proj = build_stacked_experts(
             module.experts,
             ["gate_proj", "down_proj", "up_proj"],
```

modelopt/torch/export/model_config.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -95,6 +95,9 @@ class LinearConfig:
     prequant_scaling_factor: torch.Tensor = None
     awq_block_size: int = 0

+    # If set to false, we do not split or merge this config during post tp processing.
+    tp: bool = True
+
     def __del__(self):
         del self.weight
         del self.bias
@@ -320,6 +323,7 @@ class MLPConfig:
     gate: LinearConfig = None
     proj: LinearConfig = None
     hidden_act: str = ""
+    merge_gate_fc: bool = False


 @dataclass
@@ -359,6 +363,7 @@ class MOEConfig:
     router: LinearConfig = None
     experts: ExpertConfig = None
     shared_expert: MLPConfig = None  # Deepseek MOE
+    shared_expert_gate: LinearConfig = None  # Qwen MOE
     hidden_act: str = ""

     @property
```
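
What the new merge_gate_fc flag marks, in isolation: the fused MLP layout the exporter targets stacks the gate projection on top of fc along the output dimension, giving a [gate; fc] weight of shape (2 * intermediate, hidden). A toy illustration with plain tensors (not the modelopt classes), mirroring the dim=0 concatenation that merge_gate_fc() performs later in model_config_utils.py:

```python
import torch

hidden, intermediate = 8, 16  # toy sizes
gate_weight = torch.randn(intermediate, hidden)  # gate_proj ("gate")
fc_weight = torch.randn(intermediate, hidden)    # up_proj ("fc")

# merge_gate_fc=True tells the exporter to concatenate gate onto fc along dim 0.
fused = torch.cat([gate_weight, fc_weight], dim=0)
assert fused.shape == (2 * intermediate, hidden)
```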

modelopt/torch/export/model_config_export.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -53,6 +53,7 @@
 )
 from .model_config import QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ, ModelConfig
 from .model_config_utils import (
+    merge_gate_fc,
     merge_qkv,
     model_config_to_dict,
     pack_linear_weights,
@@ -409,6 +410,7 @@ def torch_to_tensorrt_llm_checkpoint(
     assert model_config.rank >= 0, "Invalid model_config, postprocess_model_config fails."

     merge_qkv(model_config)
+    merge_gate_fc(model_config)
     pack_linear_weights(model_config)

     weights = {}
```
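
The ordering matters: gate weights have to be folded into fc before pack_linear_weights() packs quantized tensors into their storage format, so merge_gate_fc() is slotted between merge_qkv() and pack_linear_weights(). A condensed sketch of this stage of torch_to_tensorrt_llm_checkpoint (surrounding logic omitted; the absolute module path is inferred from the relative import above):

```python
from modelopt.torch.export.model_config_utils import (
    merge_gate_fc,
    merge_qkv,
    pack_linear_weights,
)


def postprocess_for_trtllm(model_config):
    """Sketch of the export postprocessing order shown in the diff above."""
    merge_qkv(model_config)            # fuse q/k/v into a single qkv linear
    merge_gate_fc(model_config)        # new: fold gate into fc where merge_gate_fc=True
    pack_linear_weights(model_config)  # pack quantized weights for the checkpoint
```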

modelopt/torch/export/model_config_utils.py

Lines changed: 64 additions & 39 deletions

```diff
@@ -23,10 +23,7 @@
 import torch

 from .model_config import (
-    QUANTIZATION_FP8,
     QUANTIZATION_INT4_AWQ,
-    QUANTIZATION_NVFP4,
-    QUANTIZATION_NVFP4_AWQ,
     QUANTIZATION_W4A8_AWQ,
     DecoderLayerConfig,
     LayernormConfig,
@@ -272,6 +269,38 @@ def merge_qkv(model_config):
     del splitted_qkv


+def merge_gate_fc(model_config):
+    """Postprocess the MLP config for TensorRT-LLM export."""
+    for decoder_config in model_config.layers:
+        mlp = None
+        if isinstance(decoder_config.mlp, MLPConfig):
+            mlp = decoder_config.mlp
+        elif (
+            isinstance(decoder_config.mlp, MOEConfig)
+            and decoder_config.mlp.shared_expert is not None
+        ):
+            mlp = decoder_config.mlp.shared_expert
+
+        if mlp is not None and mlp.merge_gate_fc and mlp.gate is not None and mlp.fc is not None:
+            mlp.fc.weight = torch.cat(
+                [
+                    mlp.gate.weight,
+                    mlp.fc.weight,
+                ],
+                dim=0,
+            )
+
+            if (
+                mlp.fc.weights_scaling_factor is not None
+                and mlp.fc.weights_scaling_factor.numel() > 1
+            ):
+                mlp.fc.weights_scaling_factor = torch.cat(
+                    [mlp.gate.weights_scaling_factor, mlp.fc.weights_scaling_factor], dim=0
+                )
+
+            mlp.gate = None
+
+
 def pack_linear_weights(model_config: ModelConfig):
     """Packs the quantized linear weights in the model_config to the quantized format."""

@@ -314,43 +343,39 @@ def _linear_layer_to_quantized_weight(linear_layers):
     if not model_config.quantization:
         return

-    attention_key_list = ["attention", "self_attention", "cross_attention"]
-    for decoder_config in model_config.layers:
-        linear_layers = []
-        if any([hasattr(decoder_config, attention_key) for attention_key in attention_key_list]):
-            for attention_key in attention_key_list:
-                attention = getattr(decoder_config, attention_key, None)
-                if attention:
-                    linear_layers += [
-                        attention.qkv,
-                        attention.dense,
-                    ]
-        if decoder_config.recurrent:
-            linear_layers = [
-                decoder_config.recurrent.linear_y,
-                decoder_config.recurrent.linear_x,
-                decoder_config.recurrent.linear_out,
-            ]
-
-        if isinstance(decoder_config.mlp, MOEConfig):
-            if model_config.quantization not in [
-                QUANTIZATION_FP8,
-                QUANTIZATION_INT4_AWQ,
-                QUANTIZATION_NVFP4,
-                QUANTIZATION_NVFP4_AWQ,
-            ]:
-                raise NotImplementedError(
-                    f"MOE quantization for {model_config.quantization} is not supported yet."
-                )
-            else:
-                linear_layers.append(decoder_config.mlp.experts.fc)
-                linear_layers.append(decoder_config.mlp.experts.proj)
-        elif decoder_config.mlp is not None:
-            linear_layers.append(decoder_config.mlp.fc)
-            linear_layers.append(decoder_config.mlp.proj)
-            linear_layers.append(decoder_config.mlp.gate)
+    def _find_linear_configs_recursive(model_config):
+        linear_configs = []

-        _linear_layer_to_quantized_weight(linear_layers)
+        # Base case - not a dataclass
+        if not dataclasses.is_dataclass(model_config):
+            return linear_configs
+
+        # Check if current object is a LinearConfig
+        if isinstance(model_config, LinearConfig):
+            linear_configs.append(model_config)
+            return linear_configs
+
+        # Recursively check all fields
+        for field in dataclasses.fields(model_config):
+            value = getattr(model_config, field.name)
+
+            if isinstance(value, list):
+                for item in value:
+                    linear_configs.extend(_find_linear_configs_recursive(item))
+
+            elif isinstance(value, dict):
+                for _, item in value.items():
+                    linear_configs.extend(_find_linear_configs_recursive(item))
+
+            # Handle nested dataclasses
+            elif dataclasses.is_dataclass(value):
+                linear_configs.extend(_find_linear_configs_recursive(value))
+
+        return linear_configs
+
+    linear_layers = _find_linear_configs_recursive(model_config)
+
+    _linear_layer_to_quantized_weight(linear_layers)

     if model_config.medusa_heads is not None:
         linear_layers = []
```
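
The rewritten pack_linear_weights no longer enumerates attention, recurrent, MLP, and MoE fields by hand; it walks the config dataclasses and collects every LinearConfig it can reach, which is what picks up the new shared_expert_gate (and any future field) automatically. A self-contained toy version of that traversal, using stand-in dataclasses rather than the real modelopt classes:

```python
import dataclasses


@dataclasses.dataclass
class LinearCfg:
    name: str


@dataclasses.dataclass
class MoeCfg:
    router: LinearCfg
    experts: list
    shared_expert_gate: LinearCfg = None


def find_linear_cfgs(node):
    """Collect every LinearCfg reachable through dataclass fields, lists, and dicts."""
    found = []
    if isinstance(node, LinearCfg):
        return [node]
    if dataclasses.is_dataclass(node):
        for field in dataclasses.fields(node):
            found.extend(find_linear_cfgs(getattr(node, field.name)))
    elif isinstance(node, list):
        for item in node:
            found.extend(find_linear_cfgs(item))
    elif isinstance(node, dict):
        for item in node.values():
            found.extend(find_linear_cfgs(item))
    return found


moe = MoeCfg(
    router=LinearCfg("router"),
    experts=[LinearCfg("experts.fc"), LinearCfg("experts.proj")],
    shared_expert_gate=LinearCfg("shared_expert_gate"),
)
print([cfg.name for cfg in find_linear_cfgs(moe)])
# ['router', 'experts.fc', 'experts.proj', 'shared_expert_gate']
```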

modelopt/torch/export/postprocess.py

Lines changed: 29 additions & 26 deletions

```diff
@@ -41,7 +41,6 @@
     ExpertConfig,
     LinearConfig,
     ModelConfig,
-    MOEConfig,
     RelativeAttentionTableConfig,
 )
 from .model_config_utils import pad_weights
@@ -99,17 +98,6 @@ def _split_model_config_for_tp(merged_config, split_factor):
         for i, config in enumerate(configs):
             config.weight = weights[i]

-    elif isinstance(merged_config, MOEConfig):
-        split_expert_configs = _split_model_config_for_tp(
-            merged_config.experts,
-            split_factor,
-        )
-        # TP for rounter of MoE is skipped for better performance
-        # See https://github.com/NVIDIA/TensorRT-LLM/pull/1091 for details
-        for i in range(split_factor):
-            configs[i].experts = split_expert_configs[i]
-            configs[i].router = merged_config.router
-
     elif isinstance(merged_config, ExpertConfig):
         assert merged_config.proj.linear_type != LINEAR_COLUMN  # row
         assert merged_config.fc.linear_type == LINEAR_COLUMN  # column
@@ -199,6 +187,10 @@ def _split_model_config_for_tp(merged_config, split_factor):
             "Do not support group linear TP merge or split"
         )

+        # Do not do anything if we don't need to process TP.
+        if not merged_config.tp:
+            return configs
+
         split_axis = 0 if merged_config.linear_type == LINEAR_COLUMN else 1
         if merged_config.linear_type == LINEAR_COLUMN:
             merged_config.weight = pad_weights(merged_config.weight, split_factor)
@@ -342,6 +334,10 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None):

         assert config.linear_type != LINEAR_GROUP, "Do not support group linear TP merge or split"

+        # No merge is needed if tp is disabled.
+        if not config.tp:
+            return
+
         # Handling constants
         for field_name in [
             "activation_scaling_factor",
@@ -758,41 +754,48 @@ def check_weight_shape_valid(config, inference_tensor_parallel=1, training_tenso
     This function is recurisve.
     """

-    def _check_merged_weight(merged_k):
-        assert merged_k % inference_tensor_parallel == 0, (
-            f"Weights cannot be split into {inference_tensor_parallel} ranks."
-        )
+    def _check_merged_weight(merged_k, tp):
+        assert merged_k % tp == 0, f"Weights with shape {merged_k} cannot be split into {tp} ranks."

-    def _check_merged_weight_scaling_factor(merged_k, awq_block_size):
-        if awq_block_size > 0 and (merged_k // inference_tensor_parallel) % awq_block_size != 0:
+    def _check_merged_weight_scaling_factor(merged_k, tp, awq_block_size):
+        if awq_block_size > 0 and (merged_k // tp) % awq_block_size != 0:
             raise NotImplementedError(
-                "Weight shape is not divisible for block size for block quantization."
+                f"Weight shape {merged_k} of each TP tp={tp} "
+                f"is not divisible for block size {awq_block_size} for block quantization."
             )

-    def _check_merged_channel_is_valid(merged_k, awq_block_size):
-        _check_merged_weight(merged_k=merged_k)
-        _check_merged_weight_scaling_factor(merged_k=merged_k, awq_block_size=awq_block_size)
+    def _check_merged_channel_is_valid(merged_k, tp, awq_block_size):
+        _check_merged_weight(merged_k=merged_k, tp=tp)
+        _check_merged_weight_scaling_factor(merged_k=merged_k, tp=tp, awq_block_size=awq_block_size)

     if isinstance(config, LinearConfig):
         # check weight shape
+        if not config.tp:
+            inference_tensor_parallel = 1
         if config.linear_type == LINEAR_COLUMN:
             _, k = config.weight.shape
             merged_k = k * training_tensor_parallel
-            _check_merged_channel_is_valid(merged_k, config.awq_block_size)
+            _check_merged_channel_is_valid(
+                merged_k, tp=inference_tensor_parallel, awq_block_size=config.awq_block_size
+            )
         elif config.linear_type == LINEAR_ROW:
             k, m = config.weight.shape
             merged_k = k * training_tensor_parallel
             merged_m = m * training_tensor_parallel
             # For int4_awq, weight scaling factors will be split as (k, (merged_m // TP) // block_size)
-            _check_merged_weight(merged_k=merged_k)
-            _check_merged_weight_scaling_factor(merged_m, config.awq_block_size)
+            _check_merged_weight(merged_k=merged_k, tp=inference_tensor_parallel)
+            _check_merged_weight_scaling_factor(
+                merged_m, tp=inference_tensor_parallel, awq_block_size=config.awq_block_size
+            )

         return

     if isinstance(config, ExpertConfig):
         _, _, k = config.fc.weight.shape
         merged_k = k * training_tensor_parallel
-        _check_merged_channel_is_valid(merged_k, config.fc.awq_block_size)
+        _check_merged_channel_is_valid(
+            merged_k, tp=inference_tensor_parallel, awq_block_size=config.fc.awq_block_size
+        )
         return

     if is_dataclass(config):
```
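
What the tp flag buys during postprocessing, in toy form: a split helper that leaves tp=False configs alone, so the MoE router and Qwen's shared_expert_gate keep their full weights on every rank instead of being sliced. The helper below is illustrative, not the modelopt API:

```python
import torch


def split_linear_weight_for_tp(weight, linear_type, split_factor, tp=True):
    """Toy split: replicate when tp is disabled, otherwise slice along the TP axis."""
    if not tp:
        # e.g. MOEConfig.router or shared_expert_gate: every rank keeps the full weight.
        return [weight] * split_factor
    dim = 0 if linear_type == "column" else 1
    return list(torch.chunk(weight, split_factor, dim=dim))


ranks = split_linear_weight_for_tp(torch.randn(60, 2048), "row", split_factor=2, tp=False)
assert all(r.shape == (60, 2048) for r in ranks)  # router replicated, not split
```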

modelopt/torch/quantization/config.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -164,6 +164,8 @@
         "*proj_out.*": {"enable": False},  # In Whisper model, lm_head has key name proj_out
         "*block_sparse_moe.gate*": {"enable": False},  # Skip the MOE router
         "*router*": {"enable": False},  # Skip the MOE router
+        "*mlp.gate.*": {"enable": False},  # Skip the MOE router
+        "*mlp.shared_expert_gate.*": {"enable": False},  # Skip the MOE router
         "*output_layer*": {"enable": False},
         "output.*": {"enable": False},
         "default": {"enable": False},
```
