Skip to content
33 changes: 29 additions & 4 deletions cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
* Copyright (c) 2024, Tri Dao.
*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -349,20 +349,45 @@ void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream)
});
}

template <int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_dispatch(ConvParamsBase& params, cudaStream_t stream)
{
    // Candidate thread counts for the forward kernel launch.
    constexpr int kNarrowThreads = 64;
    constexpr int kWideThreads = 128;
    // Elements consumed per thread by the vectorized loads: 4-byte inputs load
    // 4 at a time, narrower inputs load 8.
    constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    // A fixed-length sequence counts as "short" when the 64-thread kernel can
    // cover it with a single vectorized pass per thread.
    constexpr int kShortSeqThreshold = kNarrowThreads * kNElts;

    bool const isVarlen = params.query_start_loc_ptr != nullptr;

    // Varlen prefill launches one block per sequence/channel pair, so the
    // per-sequence work is usually much smaller than params.seqlen suggests.
    // That path also disables the wide vector-load specialization, so the
    // 128-thread kernel tends to overprovision threads for many short chunks.
    // Prefer the narrower launch for varlen and for short fixed-length inputs;
    // keep the wider launch for long dense sequences.
    if (isVarlen || params.seqlen <= kShortSeqThreshold)
    {
        causal_conv1d_fwd_launch<kNarrowThreads, kWidth, input_t, weight_t>(params, stream);
    }
    else
    {
        causal_conv1d_fwd_launch<kWideThreads, kWidth, input_t, weight_t>(params, stream);
    }
}

template <typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream)
{
    // Map the runtime filter width onto the matching compile-time
    // specialization. Widths outside {2, 3, 4} fall through without launching
    // anything, same as the original if/else chain; width validation is
    // expected to happen upstream.
    switch (params.width)
    {
    case 2: causal_conv1d_fwd_dispatch<2, input_t, weight_t>(params, stream); break;
    case 3: causal_conv1d_fwd_dispatch<3, input_t, weight_t>(params, stream); break;
    case 4: causal_conv1d_fwd_dispatch<4, input_t, weight_t>(params, stream); break;
    default: break;
    }
}

Expand Down
62 changes: 38 additions & 24 deletions tensorrt_llm/_torch/models/modeling_qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule
from tensorrt_llm._torch.modules.fla.fused_sigmoid_gating_recurrent import \
fused_sigmoid_gating_delta_rule_update
from tensorrt_llm._torch.modules.mamba.fuse_elementwise_ops import \
extract_transpose_prefill_slice
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm._torch.pyexecutor.config_utils import \
get_qwen3_hybrid_layer_types
Expand Down Expand Up @@ -454,6 +456,10 @@ def __init__(self,
self.head_v_dim = config.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.num_k_heads_per_tp = divide(self.num_k_heads, self.attn_tp_size)
self.num_v_heads_per_tp = divide(self.num_v_heads, self.attn_tp_size)
self.key_dim_per_tp = self.head_k_dim * self.num_k_heads_per_tp
self.value_dim_per_tp = self.head_v_dim * self.num_v_heads_per_tp

self.conv_kernel_size = config.linear_conv_kernel_dim
self.layer_idx = layer_idx
Expand Down Expand Up @@ -618,18 +624,15 @@ def forward_decode(
conv_state_indices=cache_indices,
)

# Direct slicing instead of torch.split for better performance
key_size = self.key_dim // self.attn_tp_size
query = mixed_qkv[..., :key_size]
key = mixed_qkv[..., key_size:key_size * 2]
value = mixed_qkv[..., key_size * 2:]
# Reshape from [l, h*d] to [1, l, h, d]
# Keep q/k/v as views over mixed_qkv so the fused decode kernel can
# consume their native strides without forcing packed copies.
query = mixed_qkv[..., :self.key_dim_per_tp]
key = mixed_qkv[..., self.key_dim_per_tp:self.key_dim_per_tp * 2]
value = mixed_qkv[..., self.key_dim_per_tp * 2:]
seq_len = query.shape[0]
num_heads = query.shape[1] // self.head_k_dim
query = query.view(1, seq_len, num_heads, self.head_k_dim)
key = key.view(1, seq_len, num_heads, self.head_k_dim)
value = value.view(1, seq_len, value.shape[1] // self.head_v_dim,
self.head_v_dim)
query = query.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim)
key = key.view(1, seq_len, self.num_k_heads_per_tp, self.head_k_dim)
value = value.view(1, seq_len, self.num_v_heads_per_tp, self.head_v_dim)

core_attn_out = fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log,
Expand Down Expand Up @@ -679,16 +682,22 @@ def forward_extend(
query_start_loc_p = query_start_loc[:num_prefill + 1]
has_initial_states_p = has_initial_states[:num_prefill]

mixed_qkv_p = causal_conv1d_fn(
mixed_qkv_p.transpose(0, 1),
mixed_qkv_p_t = extract_transpose_prefill_slice(
mixed_qkv_p,
mixed_qkv_p.shape[0],
0,
mixed_qkv_p.shape[1],
)
mixed_qkv_p_t = causal_conv1d_fn(
mixed_qkv_p_t,
self.conv1d.weight,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_states_to_use,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_p,
query_start_loc=query_start_loc_p,
).transpose(0, 1)
)

mixed_qkv_d = causal_conv1d_update(
mixed_qkv_d,
Expand All @@ -698,10 +707,16 @@ def forward_extend(
activation=self.activation,
conv_state_indices=state_indices_d,
)
mixed_qkv = torch.cat((mixed_qkv_p, mixed_qkv_d), dim=0)
mixed_qkv_p.copy_(mixed_qkv_p_t.transpose(0, 1))
else:
mixed_qkv_t = extract_transpose_prefill_slice(
mixed_qkv,
mixed_qkv.shape[0],
0,
mixed_qkv.shape[1],
)
mixed_qkv = causal_conv1d_fn(
mixed_qkv.transpose(0, 1),
mixed_qkv_t,
self.conv1d.weight,
self.conv1d.bias,
activation=self.activation,
Expand Down Expand Up @@ -733,23 +748,22 @@ def forward_extend(
g = g.unsqueeze(0)
beta = beta.unsqueeze(0)

recurrent_state = ssm_states[cache_indices]

core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
core_attn_out, _ = chunk_gated_delta_rule(
q=query,
k=key,
v=value,
g=g,
beta=beta,
initial_state=recurrent_state,
output_final_state=True,
initial_state=ssm_states,
initial_state_indices=cache_indices,
# This path writes recurrent state directly back into the shared
# pool; callers **must** ensure cache_indices do not alias live slots.
inplace_indexed_state_update=True,
output_final_state=False,
cu_seqlens=query_start_loc_long,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
last_recurrent_state = last_recurrent_state.to(ssm_states.dtype,
copy=False)
ssm_states[cache_indices] = last_recurrent_state

return core_attn_out

Expand Down
33 changes: 29 additions & 4 deletions tensorrt_llm/_torch/modules/fla/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def chunk_gated_delta_rule_fwd(
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
initial_state_indices: Optional[torch.Tensor],
inplace_indexed_state_update: bool,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
Expand All @@ -54,7 +56,9 @@ def chunk_gated_delta_rule_fwd(
u=u,
g=g,
initial_state=initial_state,
initial_state_indices=initial_state_indices,
output_final_state=output_final_state,
inplace_indexed_state_update=inplace_indexed_state_update,
cu_seqlens=cu_seqlens,
)
o = chunk_fwd_o(
Expand Down Expand Up @@ -86,6 +90,8 @@ def forward(
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
initial_state_indices: Optional[torch.Tensor],
inplace_indexed_state_update: bool,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
Expand All @@ -102,6 +108,8 @@ def forward(
beta=beta,
scale=scale,
initial_state=initial_state,
initial_state_indices=initial_state_indices,
inplace_indexed_state_update=inplace_indexed_state_update,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
Expand All @@ -117,6 +125,8 @@ def chunk_gated_delta_rule(
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
initial_state_indices: Optional[torch.Tensor] = None,
inplace_indexed_state_update: bool = False,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
Expand All @@ -141,6 +151,13 @@ def chunk_gated_delta_rule(
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
initial_state_indices (Optional[torch.Tensor]):
Optional state-pool indices of shape `[N]` selecting the slots to
read from `initial_state`.
inplace_indexed_state_update (Optional[bool]):
Explicit opt-in for writing indexed final states back into
`initial_state` in-place. Callers are responsible for ensuring the
selected slots are safe to update without aliasing races.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
cu_seqlens (torch.LongTensor):
Expand Down Expand Up @@ -211,12 +228,18 @@ def chunk_gated_delta_rule(
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
if initial_state is not None and initial_state.shape[0] != len(
cu_seqlens) - 1:
num_sequences = len(cu_seqlens) - 1
if initial_state_indices is not None:
if initial_state_indices.shape[0] != num_sequences:
raise ValueError(
f"The number of initial-state indices is expected to be equal to the number of input "
f"sequences, i.e., {num_sequences} rather than {initial_state_indices.shape[0]}."
)
elif initial_state is not None and initial_state.shape[
0] != num_sequences:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
f"i.e., {num_sequences} rather than {initial_state.shape[0]}.")
if scale is None:
scale = k.shape[-1]**-0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
Expand All @@ -227,6 +250,8 @@ def chunk_gated_delta_rule(
beta,
scale,
initial_state,
initial_state_indices,
inplace_indexed_state_update,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
Expand Down
29 changes: 26 additions & 3 deletions tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
@triton.heuristics({
"USE_G": lambda args: args["g"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"USE_INDEXED_STATE": lambda args: args["h0_i"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
Expand All @@ -42,6 +43,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
g,
h,
h0,
h0_i,
ht,
cu_seqlens,
chunk_offsets,
Expand All @@ -54,6 +56,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
USE_INDEXED_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
IS_VARLEN: tl.constexpr,
Expand Down Expand Up @@ -91,10 +94,16 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
stride_h = H * K * V
stride_k = Hg * K
stride_w = H * K
if USE_INDEXED_STATE:
state_index = tl.load(h0_i + i_n).to(tl.int64)
h0 = h0 + state_index * stride_h
ht = h0
if USE_INITIAL_STATE:
h0 = h0 + i_nh * K * V
h0 = h0 + ((i_h if USE_INDEXED_STATE else i_nh) * K * V)
if STORE_FINAL_STATE:
ht = ht + i_nh * K * V
elif USE_INDEXED_STATE:
ht = ht + i_h * K * V

# load initial state
if USE_INITIAL_STATE:
Expand Down Expand Up @@ -209,7 +218,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h4 += tl.dot(b_k, b_v_new)

# epilogue
if STORE_FINAL_STATE:
if STORE_FINAL_STATE or USE_INDEXED_STATE:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV),
(1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
Expand Down Expand Up @@ -239,7 +248,9 @@ def chunk_gated_delta_rule_fwd_h(
u: torch.Tensor,
g: Optional[torch.Tensor] = None,
initial_state: Optional[torch.Tensor] = None,
initial_state_indices: Optional[torch.Tensor] = None,
output_final_state: bool = False,
inplace_indexed_state_update: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
Expand All @@ -262,8 +273,14 @@ def chunk_gated_delta_rule_fwd_h(
assert K <= 256, "current kernel does not support head dimension larger than 256."

h = k.new_empty(B, NT, H, K, V)
use_indexed_state = initial_state is not None and initial_state_indices is not None
if use_indexed_state and not inplace_indexed_state_update:
raise ValueError(
"Indexed chunk state updates require inplace_indexed_state_update=True."
)
store_final_state_in_kernel = output_final_state and not use_indexed_state
final_state = (k.new_empty(N, H, K, V, dtype=torch.float32)
if output_final_state else None)
if store_final_state_in_kernel else None)

v_new = torch.empty_like(u) if save_new_value else None

Expand All @@ -278,6 +295,7 @@ def grid(meta):
g=g,
h=h,
h0=initial_state,
h0_i=initial_state_indices,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
Expand All @@ -291,4 +309,9 @@ def grid(meta):
num_warps=4,
num_stages=2,
)
if output_final_state and use_indexed_state:
# The indexed kernel path updates h0 in-place, so returning
# the final state means gathering those updated slots back out.
final_state = initial_state.index_select(
0, initial_state_indices.to(torch.long))
return h, v_new, final_state
Loading
Loading