+import os
 from typing import Dict, Optional
 
 import torch
 from transformers import Qwen3MoeConfig
 
 from ..attention_backend import AttentionMetadata
+from ..distributed import AllReduce, AllReduceFusionOp, AllReduceParams
 from ..model_config import ModelConfig
+from ..models.modeling_utils import MissingLayer
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import FusedMoE, RenormalizeMoeRoutingMethod
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.rms_norm import RMSNorm
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
-                             duplicate_kv_weight, register_auto_model)
+                             EagerFusionConfig, duplicate_kv_weight,
+                             register_auto_model)
 
 
 class Qwen3MoE(nn.Module):
@@ -33,6 +37,8 @@ def __init__(
         self.num_experts = config.num_experts
         self.top_k = config.num_experts_per_tok
         self.enable_attention_dp = model_config.mapping.enable_attention_dp
+        self.mapping = model_config.mapping
+        self.allreduce = AllReduce(self.mapping)
 
         # moe gate (linear layer) only runs in half/full precision for now
         self.gate = Linear(self.hidden_dim,
@@ -41,23 +47,22 @@ def __init__(
                            dtype=config.torch_dtype,
                            quant_config=None)
 
-        reduce_results = True
-
         self.experts = FusedMoE(
             num_experts=self.num_experts,
             routing_method=RenormalizeMoeRoutingMethod(top_k=self.top_k),
             hidden_size=self.hidden_dim,
             intermediate_size=self.moe_intermediate_size,
             aux_stream=aux_stream,
             dtype=config.torch_dtype,
-            reduce_results=reduce_results,
+            reduce_results=False,
             model_config=model_config,
         )
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        all_reduce_params: Optional[AllReduceParams] = None,
     ) -> torch.Tensor:
         assert hidden_states.shape[-1] == self.hidden_dim
         orig_shape = hidden_states.shape
@@ -75,6 +80,10 @@ def forward(
             router_logits,
             all_rank_num_tokens=all_rank_num_tokens)
 
+        if not self.enable_attention_dp and self.mapping.tp_size > 1:
+            final_hidden_states = self.allreduce(
+                final_hidden_states, all_reduce_params=all_reduce_params)
+
         return final_hidden_states.view(orig_shape)
 
 
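With reduce_results=False, FusedMoE now returns each tensor-parallel rank's partial expert output and Qwen3MoE.forward issues the all-reduce itself; a caller can suppress it through AllReduceParams when a later fused all-reduce performs the reduction instead. A minimal usage sketch (illustrative, not part of this commit; moe stands for a Qwen3MoE instance and post_moe_fusion mirrors the decoder layer's flag):

# Sketch: skip the MoE all-reduce because a later fused all-reduce reduces.
post_moe_fusion = True
hidden_states = moe(
    hidden_states,
    attn_metadata,
    all_reduce_params=AllReduceParams(enable_allreduce=not post_moe_fusion))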
@@ -88,6 +97,8 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
             model_config,
             layer_idx=layer_idx,
         )
+        self.mapping = model_config.mapping
+        self.enable_attention_dp = self.mapping.enable_attention_dp
 
         self.mlp = Qwen3MoE(model_config, aux_stream)
 
@@ -100,6 +111,23 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
                                                 dtype=config.torch_dtype)
         self.layer_idx = layer_idx
 
+        self.allreduce = AllReduce(self.mapping)
+        self.next_layer_layernorm: RMSNorm = None
+
+        self.fusion_config = EagerFusionConfig()
+        self.enable_fusion = os.environ.get(
+            "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
+        self.enable_fusion &= not self.enable_attention_dp
+
+        has_tp = self.mapping.has_tp()
+        has_pp = self.mapping.has_pp()
+
+        self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
+        self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp
+        self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+
     def forward(
         self,
         position_ids: torch.LongTensor,
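The flags above gate two eager fusions: PRE_MOE_FUSION needs tensor parallelism and no attention data parallelism, and POST_MOE_FUSION additionally needs no pipeline parallelism. When PRE_MOE_FUSION is active, disable_attn_allreduce also turns off the attention module's own tensor-parallel all-reduce, since the fused residual+RMSNorm all-reduce that follows performs the reduction. The same gating restated as a standalone sketch (illustrative, not part of this commit; mapping is any object exposing the attributes used in __init__):

import os

def eager_fusion_enabled(mapping) -> bool:
    # Enabled by default; TRTLLM_QWEN3_EAGER_FUSION_DISABLED=1 turns it off,
    # and attention data parallelism always disables it.
    env_enabled = os.environ.get("TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
    return env_enabled and not mapping.enable_attention_dp

def pre_moe_fusion(mapping) -> bool:
    return eager_fusion_enabled(mapping) and mapping.has_tp()

def post_moe_fusion(mapping) -> bool:
    return pre_moe_fusion(mapping) and not mapping.has_pp()

def disable_attn_allreduce(mapping) -> bool:
    return (pre_moe_fusion(mapping) or mapping.tp_size == 1
            or mapping.enable_attention_dp)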
@@ -111,22 +139,51 @@ def forward(
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
 
         # Self Attention
         hidden_states = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )
 
-        # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states, attn_metadata)
+        if self.fusion_config.PRE_MOE_FUSION:
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=residual,
+                    norm_weight=self.post_attention_layernorm.weight,
+                    eps=self.post_attention_layernorm.variance_epsilon,
+                ))
+        else:
+            # No fusion
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
+
+        hidden_states = self.mlp(
+            hidden_states,
+            attn_metadata,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not (self.fusion_config.POST_MOE_FUSION
+                                      or self.mapping.tp_size == 1)))
+
+        if self.fusion_config.POST_MOE_FUSION:
+            hidden_states, residual = self.allreduce(
+                hidden_states,
+                all_reduce_params=AllReduceParams(
+                    fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+                    residual=residual,
+                    norm_weight=self.next_layer_layernorm.weight,
+                    eps=self.next_layer_layernorm.variance_epsilon,
+                ))
+        else:
+            if self.next_layer_layernorm is not None:
+                hidden_states, residual = self.next_layer_layernorm(
+                    hidden_states, residual)
         return hidden_states, residual
 
 
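Both fused branches above use AllReduceFusionOp.RESIDUAL_RMS_NORM, which folds the tensor-parallel all-reduce, the residual add, and an RMSNorm into one kernel and returns the normalized activations together with the updated residual. A rough unfused reference for those semantics (illustrative sketch, not the fused kernel; assumes an initialized torch.distributed process group):

import torch
import torch.distributed as dist

def allreduce_add_rmsnorm(x: torch.Tensor, residual: torch.Tensor,
                          norm_weight: torch.Tensor, eps: float):
    dist.all_reduce(x)       # sum the tensor-parallel partial results
    residual = x + residual  # update the residual stream before normalizing
    h = residual.to(torch.float32)
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return (norm_weight * h).to(x.dtype), residual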
@@ -192,8 +249,6 @@ def forward(
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual)
-
-        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -280,3 +335,10 @@ def filter_weights(prefix, weights: Dict):
                     for n, p in module._parameters.items():
                         if p is not None:
                             p.data.copy_(module_weights[n][:])
+        for idx, layer in enumerate(
+                self.model.layers[:self.config.num_hidden_layers]):
+            if idx == self.config.num_hidden_layers - 1:
+                layer.next_layer_layernorm = self.model.norm
+            elif not isinstance(self.model.layers[idx + 1], MissingLayer):
+                layer.next_layer_layernorm = self.model.layers[
+                    idx + 1].input_layernorm
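This wiring pass runs after the weights are loaded and points every decoder layer at the RMSNorm it should apply to its own output: the next layer's input_layernorm, or the model-level final norm for the last layer (a layer followed by a MissingLayer on this pipeline rank is left unwired). That is why the per-layer else branch on input_layernorm and the explicit self.norm(hidden_states, residual) call were removed above: only the first layer normalizes its own input, and every later normalization happens in the previous layer's epilogue, fused into the post-MoE all-reduce when POST_MOE_FUSION is on. A schematic sketch of the resulting flow (illustrative callables, not the real modules):

def run_decoder_stack(layers, position_ids, hidden_states, attn_metadata):
    residual = None
    for layer in layers:
        # Layer 0 applies its own input_layernorm (residual is None); every
        # later layer receives hidden_states already normalized by the
        # previous layer's epilogue via next_layer_layernorm.
        hidden_states, residual = layer(position_ids=position_ids,
                                        hidden_states=hidden_states,
                                        attn_metadata=attn_metadata,
                                        residual=residual)
    # No separate final-norm call is needed: the last layer's
    # next_layer_layernorm is the model's final RMSNorm.
    return hidden_states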