 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
-from ..utils import AuxStreamType
+from ..utils import AuxStreamType, EventType
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
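The two added imports wire this MoE block into the model's dual-stream execution path: maybe_execute_in_parallel takes two callables, one torch.cuda.Event per branch, and an optional auxiliary stream, and overlaps the two branches when that stream is available. The sketch below illustrates the assumed overlap semantics using plain PyTorch stream/event primitives; run_overlapped is a hypothetical stand-in for illustration, not the actual implementation in ..modules.multi_stream_utils.

    import torch

    def run_overlapped(fn_main, fn_aux, event_main, event_aux, aux_stream):
        # Assumed semantics: without an aux stream, just run sequentially.
        if aux_stream is None:
            return fn_main(), fn_aux()
        event_main.record()                    # publish main-stream progress
        with torch.cuda.stream(aux_stream):
            aux_stream.wait_event(event_main)  # aux branch waits for its inputs
            out_aux = fn_aux()
            event_aux.record()                 # publish aux-branch completion
        out_main = fn_main()                   # main branch runs concurrently
        torch.cuda.current_stream().wait_event(event_aux)  # rejoin before out_aux is used
        return out_main, out_aux

With aux_stream set to None the two branches run back to back on the current stream, so the overlapped path is expected to be numerically equivalent to sequential execution.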
@@ -387,6 +388,7 @@ def __init__(
         self.mapping = model_config.mapping
         self.allreduce = AllReduce(mapping=model_config.mapping,
                                    strategy=model_config.allreduce_strategy)
+        self.aux_stream = aux_stream

         self.gate = Qwen3NextGate(
             hidden_size=self.hidden_dim,
@@ -425,6 +427,11 @@ def __init__(
                                          dtype=config.torch_dtype,
                                          quant_config=None)

+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -450,22 +457,33 @@ def forward(
                                       dim=0,
                                       sizes=all_rank_num_tokens)

-        router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states,
-            router_logits,
-            all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding,
-            do_finalize=do_finalize,
-        )
+        def _compute_routed_output():
+            router_logits = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states,
+                router_logits,
+                all_rank_num_tokens=all_rank_num_tokens,
+                use_dp_padding=use_dp_padding,
+                do_finalize=do_finalize,
+            )
+            return final_hidden_states

+        def _compute_shared_output():
+            shared_expert_output = self.shared_expert(hidden_states)
+            shared_expert_output = F.sigmoid(
+                self.shared_expert_gate(hidden_states)) * shared_expert_output
+            return shared_expert_output
+
+        final_hidden_states, shared_expert_output = maybe_execute_in_parallel(
+            _compute_routed_output,
+            _compute_shared_output,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         if not do_finalize:
             return final_hidden_states

-        shared_expert_output = self.shared_expert(hidden_states)
-        shared_expert_output = F.sigmoid(
-            self.shared_expert_gate(hidden_states)) * shared_expert_output
-
         final_hidden_states = final_hidden_states + shared_expert_output

         if not self.enable_attention_dp and self.mapping.tp_size > 1:
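For reference, when do_finalize is true the two overlapped closures above compute the same result as the lines they replace; a serial equivalent, written against the modules created in this block's __init__ (with F being torch.nn.functional, as elsewhere in the file), is sketched here:

    # Serial reference for the overlapped branches (illustration only).
    router_logits = self.gate(hidden_states)
    final_hidden_states = self.experts(hidden_states,
                                       router_logits,
                                       all_rank_num_tokens=all_rank_num_tokens,
                                       use_dp_padding=use_dp_padding,
                                       do_finalize=True)
    shared_expert_output = F.sigmoid(
        self.shared_expert_gate(hidden_states)) * self.shared_expert(hidden_states)
    final_hidden_states = final_hidden_states + shared_expert_output

Putting the small shared-expert MLP and its sigmoid gate on the auxiliary stream lets that work overlap with the routed-expert dispatch, which is typically the longer of the two branches.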
@@ -1038,8 +1056,6 @@ def forward(
                 self.head_v_dim,
             )
         else:
-            query, key, value, z, b, a = self.fix_query_key_value_ordering(
-                projected_states_qkvz, projected_states_ba)
             query, key, value = map(lambda x: x.reshape(x.shape[0], -1),
                                     (query, key, value))
             mixed_qkv = torch.cat((query, key, value), dim=-1)
@@ -1061,16 +1077,11 @@ def forward(
             "num_decode": num_decodes,
         }

-        new_implementation = True
-        if new_implementation:
-            if num_prefills > 0:
-                attn_out = self.forward_extend(conv_states, ssm_states,
-                                               **kwargs)
-            else:
-                attn_out = self.forward_decode(conv_states, ssm_states,
-                                               num_decodes,
-                                               mamba_metadata.cu_seqlens,
-                                               **kwargs)
+        if num_prefills > 0:
+            attn_out = self.forward_extend(conv_states, ssm_states, **kwargs)
+        else:
+            attn_out = self.forward_decode(conv_states, ssm_states, num_decodes,
+                                           mamba_metadata.cu_seqlens, **kwargs)

         z_shape_og = z.shape
         # reshape input data into 2D tensor
@@ -1125,7 +1136,7 @@ def __init__(
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0"
         self.enable_fusion &= not self.enable_attention_dp

-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()

         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
@@ -1284,7 +1295,7 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
         self.enable_fusion &= not self.enable_attention_dp

-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()

         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp