
Commit ff0dd60

[TRTLLM-10062][feat] Enable MTP for Nemotron Super (#10754)
Signed-off-by: qgai <qgai@nvidia.com>
1 parent 43b8a55 commit ff0dd60

File tree

17 files changed: +2244, -313 lines


cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h

Lines changed: 5 additions & 0 deletions
@@ -675,6 +675,11 @@ class TllmGenFmhaKernel
         {
             continue;
         }
+        // If tileSizeQ < mNumHeadsQPerKv, this will result in 0, causing division by zero.
+        if (tileSizeQ < params.mNumHeadsQPerKv)
+        {
+            continue;
+        }
 
         // Update the tileSizeQ.
         selectKernelParamsCopy.mTileSizeQ = tileSizeQ;
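
For context on the new guard: the in-code comment says the value "will result in 0, causing division by zero", presumably because a later step divides by a quantity derived from tileSizeQ / mNumHeadsQPerKv, and with integer arithmetic that quotient truncates to 0 whenever tileSizeQ < mNumHeadsQPerKv. A minimal sketch of the same check (hypothetical names and candidate sizes, not the actual kernel-selection code):

def pick_tile_size_q(candidate_tile_sizes, num_heads_q_per_kv):
    """Return the first viable Q tile size and its per-KV grouping quotient."""
    for tile_size_q in candidate_tile_sizes:
        # Mirrors the new `continue`: with integer division, a tile smaller than
        # the Q-per-KV head count makes the quotient 0, and any later division
        # by that quotient would fault, so skip such candidates.
        if tile_size_q < num_heads_q_per_kv:
            continue
        quotient = tile_size_q // num_heads_q_per_kv  # assumed downstream divisor
        return tile_size_q, quotient
    return None


print(pick_tile_size_q([8, 16, 64], num_heads_q_per_kv=32))  # (64, 2)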

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

Lines changed: 24 additions & 19 deletions
@@ -144,8 +144,15 @@ def _triton_cached_ssm(
     num_seq = num_prefill + num_decode
     num_total_tokens = num_prefill_tokens + num_decode
 
-    y_prefill = None
-    y_decode = None
+    # Preallocate output tensor to avoid memcpy cost for merging prefill
+    # and decode outputs
+    preallocated_ssm_out = torch.empty(
+        [bs, num_heads, head_dim],
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+    )
+    preallocated_ssm_out_p = preallocated_ssm_out[:num_prefill_tokens]
+    preallocated_ssm_out_d = preallocated_ssm_out[num_prefill_tokens:num_total_tokens]
 
     # Prefill: concatenate tokens at the front and run combined scan
     if num_prefill > 0:
@@ -165,7 +172,7 @@ def _triton_cached_ssm(
         chunk_indices = None
         chunk_offsets = None
 
-        y_prefill, varlen_states = mamba_chunk_scan_combined(
+        varlen_states = mamba_chunk_scan_combined(
             hs_prefill,
             dt_prefill,
             A,
@@ -184,11 +191,12 @@ def _triton_cached_ssm(
             dt_limit=(time_step_limit[0], time_step_limit[1]),
             return_final_states=False,
             return_varlen_states=True,
-            mamba_ssm_cache_dtype=ssm_state_cache.dtype,
+            out=preallocated_ssm_out_p.unsqueeze(0),
+            state_dtype=ssm_state_cache.dtype,
         )
 
         ssm_state_cache.index_copy_(
-            0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
+            0, slot_idx[:num_prefill].long(), varlen_states.to(ssm_state_cache.dtype)
         )
 
     # Decode: batch single-token updates via selective_state_update
@@ -205,7 +213,7 @@ def _triton_cached_ssm(
         A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size)
         D_full = D[..., None].expand(num_heads, head_dim)
 
-        y_decode = selective_state_update(
+        selective_state_update(
             ssm_state_cache,
             x_decode,
             dt_hp,
@@ -217,19 +225,16 @@ def _triton_cached_ssm(
             dt_bias=dt_bias_hp,
             dt_softplus=True,
             state_batch_indices=slot_idx_decode,
-        )  # [nd, H, D]
-
-    # Dispatch return logic
-    if num_prefill > 0 and num_decode > 0:
-        y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
-        y_flat = y.view(bs, *y.shape[2:])
-        y_flat[:num_prefill_tokens].copy_(y_prefill[0])
-        y_flat[num_prefill_tokens:num_total_tokens].copy_(y_decode)
-        return y
-    elif num_prefill > 0:
-        return y_prefill[0].view(b, s, num_heads, head_dim).to(hidden_states.dtype)
-    elif num_decode > 0:
-        return y_decode.view(b, s, num_heads, head_dim).to(hidden_states.dtype)
+            out=preallocated_ssm_out_d,
+        )
+
+    # Return the preallocated output reshaped to original dimensions
+    if num_total_tokens > 0:
+        return (
+            preallocated_ssm_out[:num_total_tokens]
+            .view(b, s, num_heads, head_dim)
+            .to(hidden_states.dtype)
+        )
     else:
         return torch.empty_like(hidden_states)
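
The rewrite above replaces the old dispatch logic (run prefill and decode into separate tensors, then copy both into a fresh output) with two views into a single preallocated buffer that the kernels fill in place via their out= arguments. A standalone sketch of the pattern, with toy in-place kernels standing in for mamba_chunk_scan_combined and selective_state_update:

import torch


def run_prefill(x: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for mamba_chunk_scan_combined(..., out=...): writes results in place.
    out.copy_(x * 2.0)


def run_decode(x: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for selective_state_update(..., out=...): writes results in place.
    out.copy_(x + 1.0)


def cached_ssm(hidden: torch.Tensor, num_prefill_tokens: int, num_decode: int) -> torch.Tensor:
    num_total_tokens = num_prefill_tokens + num_decode
    # One buffer for both phases; the slices are views, so each kernel writes
    # directly into its final location and no merge copy is needed afterwards.
    out = torch.empty_like(hidden[:num_total_tokens])
    out_p = out[:num_prefill_tokens]
    out_d = out[num_prefill_tokens:num_total_tokens]
    if num_prefill_tokens > 0:
        run_prefill(hidden[:num_prefill_tokens], out_p)
    if num_decode > 0:
        run_decode(hidden[num_prefill_tokens:num_total_tokens], out_d)
    return out


tokens = torch.arange(6, dtype=torch.float32)
print(cached_ssm(tokens, num_prefill_tokens=4, num_decode=2))  # tensor([0., 2., 4., 6., 5., 6.])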

tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py

Lines changed: 11 additions & 0 deletions
@@ -1,5 +1,6 @@
 import torch
 
+import tensorrt_llm.logger as logger
 from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import \
     HfWeightMapper
 from tensorrt_llm._torch.models.modeling_utils import register_mapper
@@ -55,6 +56,16 @@ def _split_mamba2_mixer_in_proj(w: torch.Tensor) -> torch.Tensor:
         if "embeddings" in key:
             key = key.replace("embeddings", "embed_tokens")
 
+        # MTP layers are stored as mtp.layers.0.xxx (sublayer 0, Attention) and mtp.layers.1.xxx (sublayer 1, MoE)
+        if "mtp.layers." in key:
+            import re
+            match = re.match(r'mtp\.layers\.(\d+)\.(.*)', key)
+            if match:
+                sublayer_idx, rest = match.groups()
+                key = f"model.layers.{config.num_hidden_layers}.layers.{sublayer_idx}.{rest}"
+            else:
+                logger.error(f"Failed to match MTP pattern for: {name}")
+
         if "A_log" in key:
             key = key.replace("A_log", "A")
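
The mapper addition above folds the MTP weights into the main layer stack by renaming mtp.layers.<sublayer>.<rest> to model.layers.<num_hidden_layers>.layers.<sublayer>.<rest>, so the MTP module loads as one extra layer appended after the base decoder layers. A quick standalone check of that renaming (the key names and num_hidden_layers=52 below are made up for illustration):

import re

NUM_HIDDEN_LAYERS = 52  # assumption for the example; the real value comes from the model config


def remap_mtp_key(key: str, num_hidden_layers: int = NUM_HIDDEN_LAYERS) -> str:
    """Rewrite mtp.layers.<i>.<rest> to model.layers.<N>.layers.<i>.<rest>."""
    match = re.match(r"mtp\.layers\.(\d+)\.(.*)", key)
    if not match:
        return key  # non-MTP keys pass through unchanged in this sketch
    sublayer_idx, rest = match.groups()
    return f"model.layers.{num_hidden_layers}.layers.{sublayer_idx}.{rest}"


# sublayer 0 = Attention, sublayer 1 = MoE, per the comment in the mapper
print(remap_mtp_key("mtp.layers.0.self_attn.q_proj.weight"))
# -> model.layers.52.layers.0.self_attn.q_proj.weight
print(remap_mtp_key("mtp.layers.1.mixer.up_proj.weight"))
# -> model.layers.52.layers.1.mixer.up_proj.weight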
