Commit f0ca60a

Add allreduce and rmsnorm fusion for qwen3 (#4304)
Signed-off-by: Zongfei Jing <[email protected]>
1 parent 14bfb5e commit f0ca60a

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 75 additions & 13 deletions
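
The change applies TensorRT-LLM's eager fusion pattern to Qwen3-MoE: after both the attention block and the MoE block, the tensor-parallel all-reduce, the residual add, and the RMSNorm that follows are collapsed into one fused kernel (AllReduceFusionOp.RESIDUAL_RMS_NORM) that returns the normalized hidden states together with the updated residual. As a minimal single-rank reference of what that fused step computes, for orientation before the diff (the helper below is illustrative, not the library's kernel):

import torch

def residual_rms_norm_reference(partial: torch.Tensor, residual: torch.Tensor,
                                norm_weight: torch.Tensor, eps: float):
    # 1) tensor-parallel all-reduce of the block's partial output
    #    (single-rank stand-in; across ranks this would be an NCCL all-reduce)
    reduced = partial
    # 2) residual add; the sum is carried forward as the next residual
    new_residual = reduced + residual
    # 3) RMSNorm of the sum, computed in float32 and cast back
    variance = new_residual.float().pow(2).mean(-1, keepdim=True)
    hidden = new_residual.float() * torch.rsqrt(variance + eps)
    return norm_weight * hidden.to(partial.dtype), new_residual
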
@@ -1,3 +1,4 @@
+import os
 from typing import Dict, Optional
 
 import torch
@@ -6,15 +7,18 @@
 from transformers import Qwen3MoeConfig
 
 from ..attention_backend import AttentionMetadata
+from ..distributed import AllReduce, AllReduceFusionOp, AllReduceParams
 from ..model_config import ModelConfig
+from ..models.modeling_utils import MissingLayer
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import FusedMoE, RenormalizeMoeRoutingMethod
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.rms_norm import RMSNorm
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             duplicate_kv_weight, register_auto_model)
+                             EagerFusionConfig, duplicate_kv_weight,
+                             register_auto_model)
 
 
 class Qwen3MoE(nn.Module):
@@ -33,6 +37,8 @@ def __init__(
         self.num_experts = config.num_experts
         self.top_k = config.num_experts_per_tok
         self.enable_attention_dp = model_config.mapping.enable_attention_dp
+        self.mapping = model_config.mapping
+        self.allreduce = AllReduce(self.mapping)
 
         # moe gate (linear layer) only runs in half/full precision for now
         self.gate = Linear(self.hidden_dim,
@@ -41,23 +47,22 @@ def __init__(
                            dtype=config.torch_dtype,
                            quant_config=None)
 
-        reduce_results = True
-
         self.experts = FusedMoE(
             num_experts=self.num_experts,
             routing_method=RenormalizeMoeRoutingMethod(top_k=self.top_k),
             hidden_size=self.hidden_dim,
             intermediate_size=self.moe_intermediate_size,
             aux_stream=aux_stream,
             dtype=config.torch_dtype,
-            reduce_results=reduce_results,
+            reduce_results=False,
             model_config=model_config,
         )
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        all_reduce_params: Optional[AllReduceParams] = None,
     ) -> torch.Tensor:
         assert hidden_states.shape[-1] == self.hidden_dim
         orig_shape = hidden_states.shape
@@ -75,6 +80,10 @@ def forward(
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens)
 
+        if not self.enable_attention_dp and self.mapping.tp_size > 1:
+            final_hidden_states = self.allreduce(
+                final_hidden_states, all_reduce_params=all_reduce_params)
+
         return final_hidden_states.view(orig_shape)
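
With reduce_results=False, the FusedMoE output above is each rank's partial sum, and the reduction now happens here, gated by the AllReduceParams handed in by the decoder layer. The assumption this relies on, inferred from the call sites further down and sketched with a hypothetical helper: an AllReduceParams with enable_allreduce=False turns the call into a pass-through, so the partial sum survives to be reduced later, fused with the next RMSNorm.

def maybe_allreduce(x, allreduce, all_reduce_params):
    # Hedged sketch of the assumed AllReduce behaviour, not the library code.
    if all_reduce_params is not None and not all_reduce_params.enable_allreduce:
        return x  # fusion path: reduction is deferred to the decoder layer
    return allreduce(x, all_reduce_params=all_reduce_params)
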
@@ -88,6 +97,8 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
             model_config,
             layer_idx=layer_idx,
         )
+        self.mapping = model_config.mapping
+        self.enable_attention_dp = self.mapping.enable_attention_dp
 
         self.mlp = Qwen3MoE(model_config, aux_stream)
 
@@ -100,6 +111,23 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
             dtype=config.torch_dtype)
         self.layer_idx = layer_idx
 
+        self.allreduce = AllReduce(self.mapping)
+        self.next_layer_layernorm: RMSNorm = None
+
+        self.fusion_config = EagerFusionConfig()
+        self.enable_fusion = os.environ.get(
+            "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
+        self.enable_fusion &= not self.enable_attention_dp
+
+        has_tp = self.mapping.has_tp()
+        has_pp = self.mapping.has_pp()
+
+        self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
+        self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp
+        self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+
     def forward(
         self,
         position_ids: torch.LongTensor,
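
Eager fusion defaults to on and is controlled by the new TRTLLM_QWEN3_EAGER_FUSION_DISABLED environment variable. It is switched off under attention data parallelism, PRE_MOE fusion additionally requires tensor parallelism, and POST_MOE fusion requires no pipeline parallelism (presumably because at a pipeline boundary the next layer's input_layernorm is not available locally). How the flags resolve, with illustrative mapping values rather than a real Mapping object:

import os

# To opt out of the eager allreduce+RMSNorm fusion, set this before model construction:
#   os.environ["TRTLLM_QWEN3_EAGER_FUSION_DISABLED"] = "1"

enable_attention_dp, has_tp, has_pp, tp_size = False, True, False, 2  # example: TP=2, no PP
enable_fusion = os.environ.get("TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
enable_fusion &= not enable_attention_dp
PRE_MOE_FUSION = enable_fusion and has_tp        # fold post_attention_layernorm into the attn all-reduce
POST_MOE_FUSION = PRE_MOE_FUSION and not has_pp  # fold the next layer's input_layernorm into the MoE all-reduce
disable_attn_allreduce = PRE_MOE_FUSION or tp_size == 1 or enable_attention_dp
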
@@ -111,22 +139,51 @@ def forward(
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
 
         # Self Attention
         hidden_states = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )
 
-        # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states, attn_metadata)
+        if self.fusion_config.PRE_MOE_FUSION:
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=residual,
+                    norm_weight=self.post_attention_layernorm.weight,
+                    eps=self.post_attention_layernorm.variance_epsilon,
+                ))
+        else:
+            # No fusion
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
+
+        hidden_states = self.mlp(
+            hidden_states,
+            attn_metadata,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
+                                      or self.mapping.tp_size == 1)))
+
+        if self.fusion_config.POST_MOE_FUSION:
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                ))
+        else:
+            if self.next_layer_layernorm is not None:
+                hidden_states, residual = self.next_layer_layernorm(
+                    hidden_states, residual)
         return hidden_states, residual
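
Condensed view of the rewritten forward with both fusions active (real signatures and plumbing omitted; fused_allreduce is a hypothetical single-rank stand-in for self.allreduce with RESIDUAL_RMS_NORM): each fused all-reduce absorbs the RMSNorm that used to follow it, and the post-MoE step already applies the next layer's input_layernorm, so the following layer starts from normalized input.

def fused_allreduce(x, residual, norm):
    # Single-rank stand-in: without an actual reduction, the fused op reduces to
    # the RMSNorm module's own residual-add + norm path.
    return norm(x, residual)

def decoder_layer_forward_fused(layer, hidden_states, residual):
    # Attention's internal all-reduce is disabled (disable_attn_allreduce=True),
    # so its output is still a per-rank partial sum.
    hidden_states = layer.self_attn(hidden_states)  # real call also takes position_ids, attn_metadata
    # PRE_MOE_FUSION: all-reduce + residual add + post_attention_layernorm in one kernel.
    hidden_states, residual = fused_allreduce(hidden_states, residual,
                                              layer.post_attention_layernorm)
    # MoE runs with reduce_results=False and enable_allreduce=False; its output stays partial.
    hidden_states = layer.mlp(hidden_states)  # real call also takes attn_metadata
    # POST_MOE_FUSION: all-reduce + residual add + the *next* layer's input_layernorm.
    hidden_states, residual = fused_allreduce(hidden_states, residual,
                                              layer.next_layer_layernorm)
    return hidden_states, residual
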
@@ -192,8 +249,6 @@ def forward(
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual)
-
-        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -280,3 +335,10 @@ def filter_weights(prefix, weights: Dict):
                         for n, p in module._parameters.items():
                             if p is not None:
                                 p.data.copy_(module_weights[n][:])
+        for idx, layer in enumerate(
+                self.model.layers[:self.config.num_hidden_layers]):
+            if idx == self.config.num_hidden_layers - 1:
+                layer.next_layer_layernorm = self.model.norm
+            elif not isinstance(self.model.layers[idx + 1], MissingLayer):
+                layer.next_layer_layernorm = self.model.layers[
+                    idx + 1].input_layernorm
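
Together with the removal of the explicit self.norm call in the model's forward (the @@ -192,8 hunk above), this last hunk moves the final normalization into the last decoder layer: each layer's next_layer_layernorm points at the following layer's input_layernorm, and the last layer's points at the model-level norm. A sketch of the wiring that load_weights now sets up (the MissingLayer guard for pipeline parallelism is omitted):

def link_next_layer_layernorms(layers, final_norm, num_hidden_layers):
    # Illustrative version of the loop added to load_weights above.
    for idx, layer in enumerate(layers[:num_hidden_layers]):
        if idx == num_hidden_layers - 1:
            layer.next_layer_layernorm = final_norm  # last layer applies the model's final norm
        else:
            layer.next_layer_layernorm = layers[idx + 1].input_layernorm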
