
Commit 57d412b

[None][chore] Polish qwen3-next modeling code.

Signed-off-by: nv-guomingz <[email protected]>
1 parent 9ec6a6b

2 files changed: +43 -34 lines changed

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 43 additions & 32 deletions
@@ -50,9 +50,10 @@
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
-from ..utils import AuxStreamType
+from ..utils import AuxStreamType, EventType
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
@@ -387,6 +388,7 @@ def __init__(
         self.mapping = model_config.mapping
         self.allreduce = AllReduce(mapping=model_config.mapping,
                                    strategy=model_config.allreduce_strategy)
+        self.aux_stream = aux_stream

         self.gate = Qwen3NextGate(
             hidden_size=self.hidden_dim,
@@ -425,6 +427,11 @@ def __init__(
             dtype=config.torch_dtype,
             quant_config=None)

+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
     def forward(
         self,
         hidden_states: torch.Tensor,
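
A note on the new event_dict: the two CUDA events are created once per module rather than once per forward call, so every step reuses the same pair for cross-stream ordering. A minimal sketch of the idea, with a hypothetical stand-in for the EventType enum from ..utils:

    import torch
    from enum import Enum

    class EventType(Enum):  # hypothetical stand-in for the enum in ..utils
        Main = 0
        MoeShared = 1

    # One reusable event per stream role; recording/waiting on an event is
    # cheap, and reusing the pair avoids per-forward-pass allocation.
    event_dict = {
        key: torch.cuda.Event()
        for key in [EventType.Main, EventType.MoeShared]
    }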
@@ -450,22 +457,33 @@ def forward(
                 dim=0,
                 sizes=all_rank_num_tokens)

-        router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states,
-            router_logits,
-            all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding,
-            do_finalize=do_finalize,
-        )
+        def _compute_routed_output():
+            router_logits = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states,
+                router_logits,
+                all_rank_num_tokens=all_rank_num_tokens,
+                use_dp_padding=use_dp_padding,
+                do_finalize=do_finalize,
+            )
+            return final_hidden_states

+        def _compute_shared_output():
+            shared_expert_output = self.shared_expert(hidden_states)
+            shared_expert_output = F.sigmoid(
+                self.shared_expert_gate(hidden_states)) * shared_expert_output
+            return shared_expert_output
+
+        final_hidden_states, shared_expert_output = maybe_execute_in_parallel(
+            _compute_routed_output,
+            _compute_shared_output,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         if not do_finalize:
             return final_hidden_states

-        shared_expert_output = self.shared_expert(hidden_states)
-        shared_expert_output = F.sigmoid(
-            self.shared_expert_gate(hidden_states)) * shared_expert_output
-
         final_hidden_states = final_hidden_states + shared_expert_output

         if not self.enable_attention_dp and self.mapping.tp_size > 1:
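
The routed experts and the shared expert both read the same hidden_states but are otherwise independent, which is what makes the two closures safe to overlap. A minimal sketch of the dual-stream pattern behind maybe_execute_in_parallel, assuming (not quoting) the helper in ..modules.multi_stream_utils:

    import torch

    def maybe_execute_in_parallel(fn_a, fn_b, event_a, event_b, aux_stream=None):
        # Sequential fallback when no auxiliary stream is configured.
        if aux_stream is None:
            return fn_a(), fn_b()
        # Fence: the aux stream must see everything the main stream has
        # issued so far (e.g. the producer of hidden_states).
        event_a.record()
        with torch.cuda.stream(aux_stream):
            aux_stream.wait_event(event_a)
            out_b = fn_b()  # shared expert runs on the aux stream
            event_b.record()
        out_a = fn_a()      # routed experts run on the main stream, overlapped
        # The main stream must not consume out_b before the aux stream is done.
        torch.cuda.current_stream().wait_event(event_b)
        return out_a, out_b

One behavioral nuance of the reordering: both closures are launched before the do_finalize check, so the shared-expert branch now runs even on the early do_finalize=False return path, where its output is discarded.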
@@ -1038,8 +1056,6 @@ def forward(
                 self.head_v_dim,
             )
         else:
-            query, key, value, z, b, a = self.fix_query_key_value_ordering(
-                projected_states_qkvz, projected_states_ba)
             query, key, value = map(lambda x: x.reshape(x.shape[0], -1),
                                     (query, key, value))
         mixed_qkv = torch.cat((query, key, value), dim=-1)
@@ -1061,25 +1077,20 @@ def forward(
             "num_decode": num_decodes,
         }

-        new_implementation = True
-        if new_implementation:
-            if num_prefills > 0:
-                attn_out = self.forward_extend(conv_states, ssm_states,
-                                               **kwargs)
-            else:
-                attn_out = self.forward_decode(conv_states, ssm_states,
-                                               num_decodes,
-                                               mamba_metadata.cu_seqlens,
-                                               **kwargs)
+        if num_prefills > 0:
+            attn_out = self.forward_extend(conv_states, ssm_states, **kwargs)
+        else:
+            attn_out = self.forward_decode(conv_states, ssm_states, num_decodes,
+                                           mamba_metadata.cu_seqlens, **kwargs)

-        z_shape_og = z.shape
-        # reshape input data into 2D tensor
+        # reshape input data into 2D tensor for norm
+        batch_size = z.shape[0]
+        final_dim = z.shape[1] * z.shape[2]
         attn_out = attn_out.reshape(-1, attn_out.shape[-1])
         z = z.reshape(-1, z.shape[-1])
         attn_out = self.norm(attn_out, z)
-        attn_out = attn_out.reshape(z_shape_og)
-        attn_out = attn_out.reshape(*attn_out.shape[:-2], -1)
-
+        # directly reshape to final output shape [batch, num_heads_v * head_v]
+        attn_out = attn_out.reshape(batch_size, final_dim)
         output = self.out_proj(attn_out, all_reduce_params=all_reduce_params)
         return output
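
The reshape rewrite is behavior-preserving: flattening to 2D for the gated norm and then reshaping once to [batch, num_heads_v * head_v] yields the same layout as the old restore-then-merge pair of reshapes. A quick standalone check with hypothetical shapes:

    import torch

    # hypothetical shapes: z and attn_out as [batch, num_heads_v, head_v_dim]
    batch, heads_v, head_v = 4, 8, 16
    z = torch.randn(batch, heads_v, head_v)
    attn_out = torch.randn(batch, heads_v, head_v)

    flat = attn_out.reshape(-1, attn_out.shape[-1])  # 2D view fed to the norm
    old = flat.reshape(z.shape).reshape(batch, -1)   # old: restore, then merge
    new = flat.reshape(batch, heads_v * head_v)      # new: one direct reshape
    assert torch.equal(old, new)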

@@ -1125,7 +1136,7 @@ def __init__(
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0"
         self.enable_fusion &= not self.enable_attention_dp

-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()

         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
@@ -1284,7 +1295,7 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
         self.enable_fusion &= not self.enable_attention_dp

-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()

         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp

tensorrt_llm/_torch/modules/fla/chunk.py

Lines changed: 0 additions & 2 deletions
@@ -90,8 +90,6 @@ def forward(
         cu_seqlens: Optional[torch.LongTensor] = None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
-        pass
-
         if use_qk_l2norm_in_kernel:
             q = l2norm_fwd(q)
             k = l2norm_fwd(k)
