75 changes: 43 additions & 32 deletions tensorrt_llm/_torch/models/modeling_qwen3_next.py
@@ -50,9 +50,10 @@
from ..modules.linear import Linear, TensorParallelMode
from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from ..utils import AuxStreamType
from ..utils import AuxStreamType, EventType
from .modeling_qwen3 import Qwen3Attention
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
@@ -387,6 +388,7 @@ def __init__(
self.mapping = model_config.mapping
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.aux_stream = aux_stream

self.gate = Qwen3NextGate(
hidden_size=self.hidden_dim,
@@ -425,6 +427,11 @@ def __init__(
dtype=config.torch_dtype,
quant_config=None)

self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeShared]
}

def forward(
self,
hidden_states: torch.Tensor,
@@ -450,22 +457,33 @@ def forward(
dim=0,
sizes=all_rank_num_tokens)

router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
do_finalize=do_finalize,
)
def _compute_routed_output():
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
do_finalize=do_finalize,
)
return final_hidden_states

def _compute_shared_output():
shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_expert_output
return shared_expert_output

final_hidden_states, shared_expert_output = maybe_execute_in_parallel(
_compute_routed_output,
_compute_shared_output,
self.event_dict[EventType.Main],
self.event_dict[EventType.MoeShared],
self.aux_stream,
)
if not do_finalize:
return final_hidden_states

shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_expert_output

final_hidden_states = final_hidden_states + shared_expert_output

if not self.enable_attention_dp and self.mapping.tp_size > 1:
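Note: maybe_execute_in_parallel overlaps the routed-expert and shared-expert branches. A minimal sketch of that overlap pattern, assuming a helper of roughly this shape (illustrative names, not the TRT-LLM implementation):

import torch

def overlap_on_aux_stream(fn_main, fn_aux, event_main, event_aux, aux_stream):
    # Without an auxiliary stream, fall back to sequential execution.
    if aux_stream is None:
        return fn_main(), fn_aux()

    # Mark the point on the main stream after which the shared inputs are ready.
    event_main.record()
    with torch.cuda.stream(aux_stream):
        # The aux stream must not read those inputs before the main stream wrote them.
        aux_stream.wait_event(event_main)
        out_aux = fn_aux()
        event_aux.record()

    out_main = fn_main()
    # The main stream must not combine outputs before the aux-stream work finishes.
    torch.cuda.current_stream().wait_event(event_aux)
    return out_main, out_aux

In the forward above, the two callables are _compute_routed_output and _compute_shared_output, and the two events come from self.event_dict keyed by EventType.Main and EventType.MoeShared.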
@@ -1038,8 +1056,6 @@ def forward(
self.head_v_dim,
)
else:
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba)
query, key, value = map(lambda x: x.reshape(x.shape[0], -1),
(query, key, value))
mixed_qkv = torch.cat((query, key, value), dim=-1)
@@ -1061,25 +1077,20 @@
"num_decode": num_decodes,
}

new_implementation = True
if new_implementation:
if num_prefills > 0:
attn_out = self.forward_extend(conv_states, ssm_states,
**kwargs)
else:
attn_out = self.forward_decode(conv_states, ssm_states,
num_decodes,
mamba_metadata.cu_seqlens,
**kwargs)
if num_prefills > 0:
attn_out = self.forward_extend(conv_states, ssm_states, **kwargs)
else:
attn_out = self.forward_decode(conv_states, ssm_states, num_decodes,
mamba_metadata.cu_seqlens, **kwargs)

z_shape_og = z.shape
# reshape input data into 2D tensor
# reshape input data into 2D tensor for norm
batch_size = z.shape[0]
final_dim = z.shape[1] * z.shape[2]
attn_out = attn_out.reshape(-1, attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
attn_out = self.norm(attn_out, z)
attn_out = attn_out.reshape(z_shape_og)
attn_out = attn_out.reshape(*attn_out.shape[:-2], -1)

# directly reshape to final output shape [batch, num_heads_v * head_v]
attn_out = attn_out.reshape(batch_size, final_dim)
output = self.out_proj(attn_out, all_reduce_params=all_reduce_params)
return output
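The reshape path above goes through two views: a 2D [rows, head_v] view for the gated norm, then a [batch, num_heads_v * head_v] view for out_proj. A small shape walk-through with illustrative sizes (a sketch, not the module code):

import torch

batch, num_heads_v, head_v = 2, 4, 64                 # illustrative sizes
attn_out = torch.randn(batch, num_heads_v, head_v)
z = torch.randn(batch, num_heads_v, head_v)
final_dim = num_heads_v * head_v                      # z.shape[1] * z.shape[2]

# Flatten so the gated RMSNorm sees one row per (token, head).
attn_2d = attn_out.reshape(-1, attn_out.shape[-1])    # [batch * num_heads_v, head_v]
z_2d = z.reshape(-1, z.shape[-1])                     # [batch * num_heads_v, head_v]

# normed = norm(attn_2d, z_2d) would run here, shape-preserving

# Collapse heads directly into the projection input shape.
out = attn_2d.reshape(batch, final_dim)               # [batch, num_heads_v * head_v]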

@@ -1125,7 +1136,7 @@ def __init__(
"TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0"
self.enable_fusion &= not self.enable_attention_dp

self.mapping.has_tp()
# has_tp = self.mapping.has_tp()
has_pp = self.mapping.has_pp()

# self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
@@ -1284,7 +1295,7 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
"TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
self.enable_fusion &= not self.enable_attention_dp

self.mapping.has_tp()
# has_tp = self.mapping.has_tp()
has_pp = self.mapping.has_pp()

# self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
2 changes: 0 additions & 2 deletions tensorrt_llm/_torch/modules/fla/chunk.py
@@ -90,8 +90,6 @@ def forward(
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
pass

if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
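The retained branch above L2-normalizes q and k before the chunked kernel. A rough reference for what an L2 normalization like l2norm_fwd computes along the last (head) dimension, as an assumption; exact eps handling in the fla kernel may differ:

import torch

def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale the last dimension to (approximately) unit L2 norm.
    return x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)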