 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
-from ..utils import AuxStreamType
+from ..utils import AuxStreamType, EventType
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
@@ -425,6 +426,11 @@ def __init__(
             dtype=config.torch_dtype,
             quant_config=None)
 
+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -450,22 +456,33 @@ def forward(
                 dim=0,
                 sizes=all_rank_num_tokens)
 
-        router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states,
-            router_logits,
-            all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding,
-            do_finalize=do_finalize,
-        )
+        def _compute_routed_output():
+            router_logits = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states,
+                router_logits,
+                all_rank_num_tokens=all_rank_num_tokens,
+                use_dp_padding=use_dp_padding,
+                do_finalize=do_finalize,
+            )
+            return final_hidden_states
 
+        def _compute_shared_output():
+            shared_expert_output = self.shared_expert(hidden_states)
+            shared_expert_output = F.sigmoid(
+                self.shared_expert_gate(hidden_states)) * shared_expert_output
+            return shared_expert_output
+
+        final_hidden_states, shared_expert_output = maybe_execute_in_parallel(
+            _compute_routed_output,
+            _compute_shared_output,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         if not do_finalize:
             return final_hidden_states
 
-        shared_expert_output = self.shared_expert(hidden_states)
-        shared_expert_output = F.sigmoid(
-            self.shared_expert_gate(hidden_states)) * shared_expert_output
-
         final_hidden_states = final_hidden_states + shared_expert_output
 
         if not self.enable_attention_dp and self.mapping.tp_size > 1:
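For context on the new multi-stream path: `maybe_execute_in_parallel` (imported above from `..modules.multi_stream_utils`) overlaps the routed-expert and shared-expert callables, launching the second one on the auxiliary CUDA stream and synchronizing via the two `torch.cuda.Event` objects created in `__init__`. The sketch below is an illustration of that pattern under an assumed signature, not the actual TensorRT-LLM helper:

import torch

def maybe_execute_in_parallel(fn0, fn1, event0, event1, aux_stream=None):
    """Sketch: overlap fn0 (main stream) with fn1 (aux stream), if one is given."""
    if aux_stream is None:
        # No auxiliary stream configured: fall back to sequential execution.
        return fn0(), fn1()
    event0.record()          # inputs for fn1 are ready on the main stream
    result0 = fn0()          # e.g. routed experts, enqueued on the main stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event0)   # do not start before the inputs are ready
        result1 = fn1()                 # e.g. shared expert, enqueued on aux stream
        event1.record()                 # aux-stream work is complete
    torch.cuda.current_stream().wait_event(event1)  # rejoin before using result1
    return result0, result1

In the MoE block above, `fn0`/`fn1` correspond to `_compute_routed_output`/`_compute_shared_output`, and `event0`/`event1` to `self.event_dict[EventType.Main]`/`self.event_dict[EventType.MoeShared]`.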
@@ -1038,8 +1055,6 @@ def forward(
                 self.head_v_dim,
             )
         else:
-            query, key, value, z, b, a = self.fix_query_key_value_ordering(
-                projected_states_qkvz, projected_states_ba)
             query, key, value = map(lambda x: x.reshape(x.shape[0], -1),
                                     (query, key, value))
             mixed_qkv = torch.cat((query, key, value), dim=-1)
@@ -1061,16 +1076,11 @@
             "num_decode": num_decodes,
         }
 
-        new_implementation = True
-        if new_implementation:
-            if num_prefills > 0:
-                attn_out = self.forward_extend(conv_states, ssm_states,
-                                               **kwargs)
-            else:
-                attn_out = self.forward_decode(conv_states, ssm_states,
-                                               num_decodes,
-                                               mamba_metadata.cu_seqlens,
-                                               **kwargs)
+        if num_prefills > 0:
+            attn_out = self.forward_extend(conv_states, ssm_states, **kwargs)
+        else:
+            attn_out = self.forward_decode(conv_states, ssm_states, num_decodes,
+                                           mamba_metadata.cu_seqlens, **kwargs)
 
         z_shape_og = z.shape
         # reshape input data into 2D tensor
@@ -1125,7 +1135,7 @@ def __init__(
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0"
         self.enable_fusion &= not self.enable_attention_dp
 
-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()
 
         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
@@ -1284,7 +1294,7 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
         self.enable_fusion &= not self.enable_attention_dp
 
-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()
 
         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
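Taken together, the sparse-MoE hunks above make the block compute the routed top-k expert output and the sigmoid-gated shared-expert output concurrently, then sum them. A simplified, self-contained sketch of that combine step (hypothetical standalone helper; it ignores the attention-DP all-gather/padding, `do_finalize`, and the TP all-reduce present in the real forward):

import torch
import torch.nn.functional as F

def combine_moe_outputs(hidden_states: torch.Tensor,
                        routed_output: torch.Tensor,
                        shared_expert: torch.nn.Module,
                        shared_expert_gate: torch.nn.Module) -> torch.Tensor:
    # The dense shared expert runs on every token, scaled by a learned sigmoid gate.
    shared_output = F.sigmoid(
        shared_expert_gate(hidden_states)) * shared_expert(hidden_states)
    # Final MoE output: routed (top-k) experts plus the gated shared expert.
    return routed_output + shared_output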