
Commit 53c6b8a

[None][chore] Polish qwen3-next modeling code.

Signed-off-by: nv-guomingz <[email protected]>

1 parent 9ec6a6b

File tree

2 files changed: +37 −29 lines

- tensorrt_llm/_torch/models/modeling_qwen3_next.py
- tensorrt_llm/_torch/modules/fla/chunk.py


tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 37 additions & 27 deletions
@@ -50,9 +50,10 @@
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
-from ..utils import AuxStreamType
+from ..utils import AuxStreamType, EventType
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
@@ -425,6 +426,11 @@ def __init__(
             dtype=config.torch_dtype,
             quant_config=None)

+        self.event_dict = {
+            key: torch.cuda.Event()
+            for key in [EventType.Main, EventType.MoeShared]
+        }
+
     def forward(
         self,
         hidden_states: torch.Tensor,
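
For context: a torch.cuda.Event is a lightweight marker that one stream records and another waits on, and creating the two events once in __init__ lets every forward() call reuse them instead of constructing events per step. A minimal sketch of the record/wait idiom these events support, with illustrative names that do not come from the repo:

import torch

# Minimal record/wait sketch (illustrative only, not TensorRT-LLM code).
ready = torch.cuda.Event()
aux_stream = torch.cuda.Stream()

x = torch.randn(8, 8, device="cuda")
ready.record()                    # mark: x is fully produced on this stream
with torch.cuda.stream(aux_stream):
    aux_stream.wait_event(ready)  # aux work cannot read x before the mark
    y = torch.sigmoid(x)
torch.cuda.synchronize()          # demo only; real code keeps streaming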
@@ -450,22 +456,33 @@ def forward(
                 dim=0,
                 sizes=all_rank_num_tokens)

-        router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states,
-            router_logits,
-            all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding,
-            do_finalize=do_finalize,
-        )
+        def _compute_routed_output():
+            router_logits = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states,
+                router_logits,
+                all_rank_num_tokens=all_rank_num_tokens,
+                use_dp_padding=use_dp_padding,
+                do_finalize=do_finalize,
+            )
+            return final_hidden_states

+        def _compute_shared_output():
+            shared_expert_output = self.shared_expert(hidden_states)
+            shared_expert_output = F.sigmoid(
+                self.shared_expert_gate(hidden_states)) * shared_expert_output
+            return shared_expert_output
+
+        final_hidden_states, shared_expert_output = maybe_execute_in_parallel(
+            _compute_routed_output,
+            _compute_shared_output,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared],
+            self.aux_stream,
+        )
         if not do_finalize:
             return final_hidden_states

-        shared_expert_output = self.shared_expert(hidden_states)
-        shared_expert_output = F.sigmoid(
-            self.shared_expert_gate(hidden_states)) * shared_expert_output
-
         final_hidden_states = final_hidden_states + shared_expert_output

         if not self.enable_attention_dp and self.mapping.tp_size > 1:
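
The point of this hunk: the routed-expert computation and the small shared-expert branch previously ran back to back on one stream; wrapping them in closures lets maybe_execute_in_parallel overlap the shared branch on self.aux_stream. The helper's body is not shown in this diff, so the following is a hypothetical sketch of the contract its call site implies, assuming it falls back to sequential execution when no auxiliary stream is available (hence the "maybe"):

from typing import Callable, Optional, Tuple

import torch


def maybe_execute_in_parallel(
    fn_main: Callable[[], torch.Tensor],
    fn_aux: Callable[[], torch.Tensor],
    event_main: torch.cuda.Event,
    event_aux: torch.cuda.Event,
    aux_stream: Optional[torch.cuda.Stream],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: overlap fn_aux on aux_stream with fn_main on the
    current stream; run sequentially when no auxiliary stream is given."""
    if aux_stream is None:
        return fn_main(), fn_aux()
    event_main.record()                    # inputs ready on the main stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_main)  # aux must not run ahead of inputs
        out_aux = fn_aux()
        event_aux.record()                 # aux result ready
    out_main = fn_main()                   # overlaps with fn_aux
    torch.cuda.current_stream().wait_event(event_aux)
    # A production helper also has to respect PyTorch's per-stream caching
    # allocator (e.g. Tensor.record_stream) so out_aux's memory is not
    # recycled early; that detail is omitted here.
    return out_main, out_aux

Under that reading, forward() keeps its original sequential semantics whenever self.aux_stream is None, and the two events pre-allocated in __init__ are reused on every call.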
@@ -1038,8 +1055,6 @@ def forward(
                 self.head_v_dim,
             )
         else:
-            query, key, value, z, b, a = self.fix_query_key_value_ordering(
-                projected_states_qkvz, projected_states_ba)
             query, key, value = map(lambda x: x.reshape(x.shape[0], -1),
                                     (query, key, value))
             mixed_qkv = torch.cat((query, key, value), dim=-1)
@@ -1061,16 +1076,11 @@ def forward(
             "num_decode": num_decodes,
         }

-        new_implementation = True
-        if new_implementation:
-            if num_prefills > 0:
-                attn_out = self.forward_extend(conv_states, ssm_states,
-                                               **kwargs)
-            else:
-                attn_out = self.forward_decode(conv_states, ssm_states,
-                                               num_decodes,
-                                               mamba_metadata.cu_seqlens,
-                                               **kwargs)
+        if num_prefills > 0:
+            attn_out = self.forward_extend(conv_states, ssm_states, **kwargs)
+        else:
+            attn_out = self.forward_decode(conv_states, ssm_states, num_decodes,
+                                           mamba_metadata.cu_seqlens, **kwargs)

         z_shape_og = z.shape
         # reshape input data into 2D tensor
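
One convention worth spelling out for the decode call above (an assumption about mamba_metadata.cu_seqlens based on the usual varlen-kernel layout, not something this diff states): cu_seqlens holds cumulative sequence lengths, so each request maps to a contiguous slice of the flattened token tensor.

import torch

# The cu_seqlens convention used by varlen kernels: for per-request lengths
# [3, 1, 4], cu_seqlens = [0, 3, 4, 8], and request i owns tokens
# cu_seqlens[i]:cu_seqlens[i+1] of the flattened batch.
seq_lens = torch.tensor([3, 1, 4])
cu_seqlens = torch.cat([torch.zeros(1, dtype=seq_lens.dtype),
                        seq_lens.cumsum(0)])
print(cu_seqlens)  # tensor([0, 3, 4, 8])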
@@ -1125,7 +1135,7 @@ def __init__(
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0"
         self.enable_fusion &= not self.enable_attention_dp

-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()

         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
@@ -1284,7 +1294,7 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
             "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
         self.enable_fusion &= not self.enable_attention_dp

-        self.mapping.has_tp()
+        # has_tp = self.mapping.has_tp()
         has_pp = self.mapping.has_pp()

         # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp

tensorrt_llm/_torch/modules/fla/chunk.py

Lines changed: 0 additions & 2 deletions
@@ -90,8 +90,6 @@ def forward(
         cu_seqlens: Optional[torch.LongTensor] = None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
-        pass
-
         if use_qk_l2norm_in_kernel:
             q = l2norm_fwd(q)
             k = l2norm_fwd(k)
