
[None][feat] Add Mamba2 MTP SSM cache CUDA kernel for tree-based speculative decoding#12537

Open
JadoTu wants to merge 1 commit into NVIDIA:main from JadoTu:mamba2_tree_based_mtp_CUDA_kernel

Conversation


@JadoTu JadoTu commented Mar 25, 2026

Summary by CodeRabbit

Release Notes

  • New Features
    • Added Mamba2 MTP SSM cache update operation with CUDA acceleration and PyTorch integration for enhanced multi-token state space model computations.
    • Supports float32, float16, and bfloat16 precision formats with configurable gating and caching options.

Description

This commit adds a CUDA kernel (mamba2_mtp_ssm_cache_update) that:

  1. Supports tree-based speculative decoding via retrieve_parent_token — at each step, the kernel can restore the SSM state from an arbitrary parent token's cached entry, enabling non-linear (tree-structured) draft token verification.

  2. The FlashInfer kernel does not currently support tree-based speculative decoding. Even without tree-based speculative decoding, this PR's kernel achieves a ~10% performance improvement over the FlashInfer kernel.

     (Benchmark screenshots comparing the FlashInfer kernel and this PR's kernel were attached here.)

  3. With tree-based speculative decoding, this PR's kernel is 2x faster than the initial Triton kernel (183 us vs. 369 us).
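For reviewers, the retrieve_parent_token semantics can be modeled on the host in a few lines of Python. This is a simplified scalar recurrence, not the real kernel; all names here are illustrative:

```python
def tree_ssm_cache_update_ref(x_steps, parent, decay=0.5, s0=0.0):
    """Scalar reference model of the tree-based cache update: before each
    step's update, the state may be restored from an already-cached
    ancestor step via the parent index (stand-in for the SSM dynamics)."""
    intermediate = []
    s = s0
    for t, x in enumerate(x_steps):
        p = parent[t]
        if t > 0 and 0 <= p < t:   # only already-materialized steps
            s = intermediate[p]    # jump back to the parent branch's state
        s = decay * s + x          # simplified stand-in for the SSM recurrence
        intermediate.append(s)     # cache this step's state
    return intermediate
```

A linear chain (parent[t] = t - 1) reproduces ordinary multi-step decoding, while e.g. parent = [-1, 0, 0] makes steps 1 and 2 siblings that both branch from step 0.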

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

…ulative decoding

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
@JadoTu JadoTu requested review from a team as code owners March 25, 2026 09:55

coderabbitai bot commented Mar 25, 2026

📝 Walkthrough

This pull request introduces a new Mamba2 MTP (Multi-Token Prediction) SSM cache update operation. It adds a complete CUDA kernel implementation with templated specializations, PyTorch custom operator bindings, and Python API wrappers to enable multi-token state caching for Mamba2 models.

Changes

Cohort / File(s) Summary
CUDA Kernel Core
cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h, mamba2MTPSSMCacheKernel.cuh
New header defining Mamba2Dtype enum, Mamba2MTPSSMCacheParams struct, and kernel function declarations. Core kernel implements per-head-batch SSM state update with optional parent-token retrieval, fused B/C accumulation, and data type dispatch across bfloat16, float16, and float32.
Kernel Specializations
cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.cu, mamba2MTPSSMCacheVec4.cu, mamba2MTPSSMCacheVec8.cu, mamba2MTPSSMCacheVec16.cu
Runtime dispatch function validates head_dim divisibility and selects compile-time vector size (4/8/16) based on ssm_dim (128/256/512). Three translation units explicitly instantiate launchMamba2MTPSSMCacheKernel templates for each vector size.
PyTorch Operator
cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp
New Torch custom operator mamba2_mtp_ssm_cache_update validates tensor shapes, device placement, and stride requirements for all inputs (including optional D, z, dt_bias, batch indices tensors). Constructs kernel parameters, handles dtype mapping, and dispatches to invokeMamba2MTPSSMCacheUpdate.
Python API
tensorrt_llm/_torch/modules/mamba/selective_state_update.py
Added selective_state_update_mtp_ssm_cache_trtllm() wrapper function forwarding inputs to the custom Torch operator with multi-token caching parameters and intermediate state buffering.
Build Configuration
cpp/tensorrt_llm/CMakeLists.txt, cpp/tensorrt_llm/kernels/CMakeLists.txt, cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/CMakeLists.txt, cpp/tensorrt_llm/thop/CMakeLists.txt
Extended main build system to link mamba2_mtp_ssm_cache_src library. Added new subdirectory and excluded its .cu files from main globbing. New CMake file defines object library with position-independent code and device symbol resolution enabled.
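For orientation, the ssm_dim to vector-size mapping described above can be sketched as follows. This is an illustrative Python model; the real dispatch lives in mamba2MTPSSMCache.cu:

```python
def select_vec_size(ssm_dim: int) -> int:
    # Map ssm_dim to the compile-time vector width selected at launch,
    # following the 128/256/512 -> 4/8/16 mapping summarized above.
    table = {128: 4, 256: 8, 512: 16}
    if ssm_dim not in table:
        raise ValueError(f"unsupported ssm_dim: {ssm_dim}")
    return table[ssm_dim]
```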

Sequence Diagram(s)

sequenceDiagram
    participant PY as Python Code
    participant TORCH as PyTorch Layer
    participant OP as Custom Operator<br/>(mamba2MTPSSMCacheOp)
    participant DISPATCH as Runtime Dispatch<br/>(mamba2MTPSSMCache.cu)
    participant KERNEL as CUDA Kernel<br/>(mamba2MTPSSMCacheKernel)

    PY->>TORCH: Call selective_state_update_mtp_ssm_cache_trtllm()
    TORCH->>OP: torch.ops.trtllm.mamba2_mtp_ssm_cache_update(...)
    OP->>OP: Validate tensor shapes & device placement
    OP->>OP: Extract raw pointers & optional parameters
    OP->>OP: Map PyTorch dtypes → Mamba2Dtype
    OP->>DISPATCH: invokeMamba2MTPSSMCacheUpdate(params, stream)
    DISPATCH->>DISPATCH: Validate head_dim divisibility
    DISPATCH->>DISPATCH: Select VEC_SIZE from ssm_dim
    DISPATCH->>KERNEL: Launch kernel with grid/block/stream
    KERNEL->>KERNEL: Process per (head_id, bs_id) pair
    KERNEL->>KERNEL: Load state rows & iterate cache_steps
    KERNEL->>KERNEL: Compute dt dynamics & accumulate B/C
    KERNEL->>KERNEL: Optional D gating & z modulation
    KERNEL->>KERNEL: Write results to intermediate_states & state
    KERNEL-->>PY: Return updated tensors

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 30.77%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check — ❓ Inconclusive. The PR description addresses key features and performance metrics but lacks clarity in critical sections required by the template. Resolution: complete the 'Test Coverage' section with specific test names/locations; clarify which checklist items were actually completed versus marked as assumed complete.

✅ Passed checks (1 passed)

  • Title check — ✅ Passed. The title clearly and concisely summarizes the main change: adding a Mamba2 MTP SSM cache CUDA kernel for tree-based speculative decoding, which aligns with the substantial code additions across multiple kernel files and wrapper functions.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.cu`:
- Around line 56-65: Add a precondition check in invokeMamba2MTPSSMCacheUpdate
to validate params.ngroups and the head/group ratio before dispatch: ensure
params.ngroups > 0 and params.nheads % params.ngroups == 0 (so
heads_groups_ratio = nheads/ngroups is integral and non-zero) to prevent
division-by-zero and incorrect grouping used by mamba2MTPSSMCacheKernel.cuh; use
the existing TLLM_CHECK_WITH_INFO pattern (same as the head_dim check) to fail
fast with a clear message referencing ngroups and nheads before calling
MTP_DISPATCH_VEC_SIZE and launchMamba2MTPSSMCacheKernel.

In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h`:
- Around line 35-79: Update the Mamba2MTPSSMCacheParams struct documentation to
reflect the actual tensor layouts used by mamba2MTPSSMCacheKernel: change dt, A,
D, and dt_bias comments to indicate they are indexed only by head (no
head_dim/ssm_dim axes) and correct their shapes accordingly (e.g., dt: [bs,
cache_steps, nheads], A: [nheads], D/dt_bias: [nheads] or similarly 1D per head
as used in kernel), and convert all public comments to Doxygen C++ style (use
//! for single-line comments and //!< for member annotations) including the
prototype for invokeMamba2MTPSSMCacheUpdate; reference struct name
Mamba2MTPSSMCacheParams and the kernel usage locations
(mamba2MTPSSMCacheKernel.cuh lines mentioned) so callers provision tensors with
the right layouts.

In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheKernel.cuh`:
- Around line 372-383: The kernel is restoring parent state from
retrieve_parent_token even when parent_step_idx == t or > t (future/unwritten
entries); update the guard in the RETRIEVE_PARENT_TOKEN block so you only load
when parent_step_idx >= 0 && parent_step_idx < t (not < cache_steps), e.g. in
the block around retrieve_parent_token[bs_id * cache_steps + t] before calling
mtp_load_vec_to_float for state_4_a/state_4_b with inter_base_a/inter_base_b and
stride_nheads_hdim_ssm_dim ensure the check uses t as the upper bound to only
restore from already-materialized steps.
- Around line 113-116: The current mtp_softplus uses __logf(1 +
__expf(dt_value)) which overflows for large dt_value (causing
xdt_val_a/xdt_val_b to become inf); replace it with a numerically stable branch:
for dt_value > 0 return dt_value + __logf(1.f + __expf(-dt_value)), otherwise
return __logf(1.f + __expf(dt_value)) so exponentials are computed on non-large
positive inputs and downstream state/output corruption is avoided; update the
mtp_softplus device function accordingly (it’s the function to change).

In `@cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp`:
- Around line 78-82: B.size(2) can be zero, causing undefined host behavior when
evaluating nheads % ngroups; add an explicit check that rejects ngroups == 0
before performing the modulo. Concretely, after computing int const ngroups =
B.size(2); add a TORCH_CHECK(ngroups > 0, "ngroups must be > 0") (or similar)
and then keep the existing TORCH_CHECK(nheads % ngroups == 0, "unsupported pair
of nheads and ngroups") so the modulo is only executed on a non-zero ngroups.
- Around line 50-63: Add a host-side validation that when ssm_batch_indices is
not provided (i.e., kernel will use direct bs_id indexing) the first dimension
of ssm is at least the batch size of x: check ssm.size(0) >= x.size(0) and raise
a TORCH_CHECK with a clear message if not; locate this guard near the existing
tensor shape/device checks around ssm, x, and intermediate_states in
mamba2MTPSSMCacheOp.cpp so it runs before launching the kernel and prevents
out-of-bounds access when ssm_batch_indices is null.
- Around line 209-210: The code calls at::cuda::getCurrentCUDAStream() without
setting the device, which can pick the thread-local device instead of the device
associated with params; fix by including <c10/cuda/CUDAGuard.h> and creating a
c10::cuda::CUDAGuard guard(params.device()) (or the appropriate device accessor
from params) immediately before calling at::cuda::getCurrentCUDAStream(), then
get the stream and call tk::invokeMamba2MTPSSMCacheUpdate(params, stream) so the
kernel is launched on the correct device.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 73658ce3-902d-4e56-bd1e-53fc964d780d

📥 Commits

Reviewing files that changed from the base of the PR and between 2b5c434 and 1d2bb3b.

📒 Files selected for processing (12)
  • cpp/tensorrt_llm/CMakeLists.txt
  • cpp/tensorrt_llm/kernels/CMakeLists.txt
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/CMakeLists.txt
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.cu
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheKernel.cuh
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheVec16.cu
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheVec4.cu
  • cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCacheVec8.cu
  • cpp/tensorrt_llm/thop/CMakeLists.txt
  • cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp
  • tensorrt_llm/_torch/modules/mamba/selective_state_update.py

Comment on lines +56 to +65
void invokeMamba2MTPSSMCacheUpdate(Mamba2MTPSSMCacheParams const& params, cudaStream_t stream)
{
    TLLM_CHECK_WITH_INFO(
        (params.head_dim % (MTP_NUM_WARPS * MTP_HDIMS_PER_WARP)) == 0, "head_dim should be a multiple of 8");

    dim3 block(MTP_NUM_BLOCK_THREADS);
    dim3 grid(params.head_dim / (MTP_NUM_WARPS * MTP_HDIMS_PER_WARP), params.nheads, params.bs);

    MTP_DISPATCH_VEC_SIZE(
        params.ssm_dim, DISPATCH_VS, [&] { launchMamba2MTPSSMCacheKernel<DISPATCH_VS>(params, grid, block, stream); });

⚠️ Potential issue | 🟠 Major

Guard the head/group ratio before dispatch.

The launcher computes heads_groups_ratio = nheads / ngroups in mamba2MTPSSMCacheKernel.cuh at Line 492 and uses it for the B/C group lookup at Line 359 and Line 362. Without a precondition here, ngroups == 0 faults and non-divisible shapes silently route heads to the wrong group.

Suggested fix
 void invokeMamba2MTPSSMCacheUpdate(Mamba2MTPSSMCacheParams const& params, cudaStream_t stream)
 {
+    TLLM_CHECK_WITH_INFO(
+        params.ngroups > 0 && (params.nheads % params.ngroups) == 0,
+        "ngroups must be > 0 and divide nheads");
     TLLM_CHECK_WITH_INFO(
         (params.head_dim % (MTP_NUM_WARPS * MTP_HDIMS_PER_WARP)) == 0, "head_dim should be a multiple of 8");

Comment on lines +35 to +79
struct Mamba2MTPSSMCacheParams
{
    // State tensors
    void* ssm;                 // [batch, nheads, head_dim, ssm_dim]
    void* intermediate_states; // [batch, cache_steps, nheads, head_dim, ssm_dim]

    // Input tensors
    void const* x;  // [bs, cache_steps, nheads, head_dim]
    void const* dt; // [bs, cache_steps, nheads, head_dim] (strided)
    void const* A;  // [nheads, head_dim, ssm_dim] (strided)
    void const* B;  // [bs, cache_steps, ngroups, ssm_dim]
    void const* C;  // [bs, cache_steps, ngroups, ssm_dim]

    // Output
    void* out; // [bs, cache_steps, nheads, head_dim]

    // Optional tensors (nullptr if not present)
    void const* D;       // [nheads, head_dim]
    void const* z;       // [bs, cache_steps, nheads, head_dim]
    void const* dt_bias; // [nheads, head_dim]
    int32_t const* ssm_batch_indices;           // [bs]
    int32_t const* intermediate_states_indices; // [bs]
    int32_t const* retrieve_parent_token;       // [bs, cache_steps]

    // Flags
    bool dt_softplus;

    // Dimensions
    int cache_steps;
    int pad_slot_id;
    bool disable_state_update;
    int bs;
    int nheads;
    int head_dim;
    int ssm_dim;
    int ngroups;

    // Data types
    Mamba2Dtype ssm_dtype;
    Mamba2Dtype in_out_dtype;
    Mamba2Dtype weight_dtype;
    Mamba2Dtype a_dtype;
};

void invokeMamba2MTPSSMCacheUpdate(Mamba2MTPSSMCacheParams const& params, cudaStream_t stream);

⚠️ Potential issue | 🟠 Major

Clarify the real tensor layout in this public header.

dt, A, D, and dt_bias are documented as if they carry extra head_dim/ssm_dim axes, but the kernel only indexes them by head_id in mamba2MTPSSMCacheKernel.cuh at Line 348, Line 351, Line 355, and Line 395. Please fix the shapes and add proper Doxygen on this struct/prototype so future callers do not provision the wrong layout.

As per coding guidelines, "Follow Doxygen rules for documenting new C++ class interfaces and function prototypes. Use //! for C++-style single-line comments and //!< for class members".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/mamba2MTPSSMCache.h` around lines
35 - 79, Update the Mamba2MTPSSMCacheParams struct documentation to reflect the
actual tensor layouts used by mamba2MTPSSMCacheKernel: change dt, A, D, and
dt_bias comments to indicate they are indexed only by head (no head_dim/ssm_dim
axes) and correct their shapes accordingly (e.g., dt: [bs, cache_steps, nheads],
A: [nheads], D/dt_bias: [nheads] or similarly 1D per head as used in kernel),
and convert all public comments to Doxygen C++ style (use //! for single-line
comments and //!< for member annotations) including the prototype for
invokeMamba2MTPSSMCacheUpdate; reference struct name Mamba2MTPSSMCacheParams and
the kernel usage locations (mamba2MTPSSMCacheKernel.cuh lines mentioned) so
callers provision tensors with the right layouts.

Comment on lines +113 to +116
__device__ __forceinline__ float mtp_softplus(float dt_value)
{
    return __logf(1.f + __expf(dt_value));
}

⚠️ Potential issue | 🟠 Major

Use a stable softplus implementation here.

log(1 + exp(x)) overflows for large positive dt_value, and this target is compiled with --use_fast_math in cpp/tensorrt_llm/kernels/mamba2MTPSSMCache/CMakeLists.txt at Line 21 to Line 24. Once that happens, xdt_val_a and xdt_val_b become inf, and every downstream state/output value for the step is corrupted.

Suggested fix
 __device__ __forceinline__ float mtp_softplus(float dt_value)
 {
-    return __logf(1.f + __expf(dt_value));
+    float const absDtValue = fabsf(dt_value);
+    return fmaxf(dt_value, 0.0f) + __logf(1.0f + __expf(-absDtValue));
 }
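To see the failure mode concretely, here is the same comparison in float32 NumPy. This is a host-side sketch of the device math, not the kernel itself:

```python
import numpy as np

def softplus_naive(x):
    # Mirrors __logf(1 + __expf(x)): exp() overflows float32 for large x.
    with np.errstate(over="ignore"):
        return np.log1p(np.exp(np.float32(x)))

def softplus_stable(x):
    # max(x, 0) + log1p(exp(-|x|)): exp() only ever sees non-positive input.
    x = np.float32(x)
    return np.maximum(x, np.float32(0)) + np.log1p(np.exp(-np.abs(x)))
```

softplus_naive(100.0) returns inf in float32, while softplus_stable(100.0) returns 100.0; the two agree closely for moderate inputs.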

Comment on lines +372 to +383
if constexpr (RETRIEVE_PARENT_TOKEN)
{
    if (t != 0)
    {
        int parent_step_idx = retrieve_parent_token[bs_id * cache_steps + t];
        if (parent_step_idx >= 0 && parent_step_idx < cache_steps)
        {
            mtp_load_vec_to_float<VEC_SIZE>(
                state_4_a, inter_base_a + parent_step_idx * stride_nheads_hdim_ssm_dim);
            mtp_load_vec_to_float<VEC_SIZE>(
                state_4_b, inter_base_b + parent_step_idx * stride_nheads_hdim_ssm_dim);
        }

⚠️ Potential issue | 🟠 Major

Only restore from already-materialized parent steps.

At step t, the kernel has only written intermediate states [0, t). The current parent_step_idx < cache_steps check still accepts retrieve_parent_token[t] == t or any future step, which reads unwritten cache entries and silently corrupts the branch state.

Suggested fix
-                if (parent_step_idx >= 0 && parent_step_idx < cache_steps)
+                if (parent_step_idx >= 0 && parent_step_idx < t)
                 {
                     mtp_load_vec_to_float<VEC_SIZE>(
                         state_4_a, inter_base_a + parent_step_idx * stride_nheads_hdim_ssm_dim);
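The off-by-one in the guard is easy to confirm with a pure-Python model of the two bounds (names here are illustrative):

```python
def parent_restore_allowed(parent_idx: int, t: int, cache_steps: int,
                           fixed: bool = True) -> bool:
    # At step t only intermediate states [0, t) have been written, so the
    # upper bound must be t (fixed guard), not cache_steps (original guard).
    upper = t if fixed else cache_steps
    return 0 <= parent_idx < upper
```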

Comment on lines +50 to +63
TORCH_CHECK(ssm.dim() == 4 && ssm.is_cuda() && ssm.is_contiguous(), "ssm should be a 4D contiguous CUDA tensor");
TORCH_CHECK(x.dim() == 4, "x should be a 4D tensor");
auto device = ssm.device();
int const bs = x.size(0);
int const nheads = ssm.size(1);
int const head_dim = ssm.size(2);
int const ssm_dim = ssm.size(3);

TORCH_CHECK(intermediate_states.dim() == 5 && intermediate_states.size(0) == ssm.size(0)
        && intermediate_states.size(1) == cache_steps && intermediate_states.size(2) == nheads
        && intermediate_states.size(3) == head_dim && intermediate_states.size(4) == ssm_dim,
    "intermediate_states shape check failed");
TORCH_CHECK(intermediate_states.device() == device && intermediate_states.is_contiguous(),
    "intermediate_states is not a contiguous tensor of the same device as ssm");

⚠️ Potential issue | 🔴 Critical


Add host-side validation for minimum SSM buffer size when batch indices are not provided.

When ssm_batch_indices is not provided, the kernel uses direct batch indexing (bs_id) instead of remapping. This requires ssm.size(0) >= x.size(0) to prevent out-of-bounds access during kernel execution.

Required guard
     auto device = ssm.device();
     int const bs = x.size(0);
+    TORCH_CHECK(
+        ssm_batch_indices.has_value() || ssm.size(0) >= bs,
+        "ssm must have at least x.size(0) slots when ssm_batch_indices is not provided");
     int const nheads = ssm.size(1);
     int const head_dim = ssm.size(2);
     int const ssm_dim = ssm.size(3);
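The indexing rule this check protects can be modeled like this (host-side sketch; names are illustrative):

```python
def resolve_ssm_slot(bs_id, ssm_rows, ssm_batch_indices=None):
    # Mirrors the kernel's slot selection: remap through ssm_batch_indices
    # when provided, otherwise index directly with bs_id. Without the
    # proposed host-side check, ssm_rows smaller than the batch size lets
    # the direct path read out of bounds.
    slot = ssm_batch_indices[bs_id] if ssm_batch_indices is not None else bs_id
    if not (0 <= slot < ssm_rows):
        raise IndexError(f"ssm slot {slot} out of range [0, {ssm_rows})")
    return slot
```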

Comment on lines +78 to +82
TORCH_CHECK(B.dim() == 4 && B.size(0) == bs && B.size(1) == cache_steps && B.size(3) == ssm_dim,
"B tensor has incorrect shapes");
TORCH_CHECK(B.device() == device && B.is_contiguous(), "B is not a contiguous tensor of the same device as ssm");
int const ngroups = B.size(2);
TORCH_CHECK(nheads % ngroups == 0, "unsupported pair of nheads and ngroups");

⚠️ Potential issue | 🔴 Critical

Reject ngroups == 0 before the modulo.

B.size(2) can be zero for an empty [batch, steps, 0, ssm_dim] tensor. Line 82 then evaluates nheads % ngroups, which is host-side undefined behavior instead of a clean TORCH_CHECK.

🐛 Proposed fix
     TORCH_CHECK(B.dim() == 4 && B.size(0) == bs && B.size(1) == cache_steps && B.size(3) == ssm_dim,
         "B tensor has incorrect shapes");
     TORCH_CHECK(B.device() == device && B.is_contiguous(), "B is not a contiguous tensor of the same device as ssm");
     int const ngroups = B.size(2);
+    TORCH_CHECK(ngroups > 0, "B.size(2) (ngroups) must be > 0");
     TORCH_CHECK(nheads % ngroups == 0, "unsupported pair of nheads and ngroups");
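The same check-before-use ordering, sketched in Python (where dividing by zero raises an exception rather than invoking C++ undefined behavior; the helper name is illustrative):

```python
def heads_groups_ratio(nheads: int, ngroups: int) -> int:
    # Reject ngroups <= 0 before the modulo, mirroring the proposed
    # TORCH_CHECK ordering; in C++, nheads % 0 is undefined behavior.
    if ngroups <= 0:
        raise ValueError("ngroups must be > 0")
    if nheads % ngroups != 0:
        raise ValueError("unsupported pair of nheads and ngroups")
    return nheads // ngroups
```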

Comment on lines +209 to +210
auto stream = at::cuda::getCurrentCUDAStream().stream();
tk::invokeMamba2MTPSSMCacheUpdate(params, stream);

⚠️ Potential issue | 🔴 Critical

🌐 Web query:

In the PyTorch C++ / ATen CUDA API, when at::cuda::getCurrentCUDAStream()is called without a device argument, does it use the current CUDA device rather than an input tensor's device? What is the recommended custom-op pattern for launching on the same device as a tensor:CUDAGuard/OptionalCUDAGuard, getCurrentCUDAStream(device.index()), or both?

💡 Result:

Yes, when at::cuda::getCurrentCUDAStream is called without a device argument (device_index = -1), it uses the current CUDA device. The official PyTorch C++ documentation states: "Acquire the current CUDA stream, for the CUDA device with index device_index, or for the current device if no device index is passed." The source code in CUDAStream.cpp confirms this: if device_index == -1, it sets device_index = current_device before returning the stream. It does not use an input tensor's device; it always uses the current device context. For the recommended custom-op pattern to launch on the same device as a tensor: Use CUDAGuard or OptionalCUDAGuard to set the current device to tensor.device.index, then use at::cuda::getCurrentCUDAStream (without device_index, or explicitly with tensor.device.index). The PyTorch C++ docs recommend CUDAGuard and CUDAStreamGuard over direct setCurrentCUDAStream because guards handle device switching and restoration properly via RAII. Examples in the docs show using CUDAGuard to switch devices before acquiring/setting streams for tensor operations. OptionalCUDAGuard is useful when the device might be optional (e.g., CPU fallback). Both device guard + current stream is the safe pattern; stream alone assumes correct current device.

Citations:


Use CUDAGuard to ensure the kernel launches on the correct device.

Line 209 uses at::cuda::getCurrentCUDAStream() without a device argument, which uses the thread-local current CUDA device instead of the tensor's device. On multi-GPU calls, this launches the kernel on a stream that does not belong to the tensor's device, causing incorrect behavior.

The fix requires setting the current device context before acquiring the stream. The correct include header is <c10/cuda/CUDAGuard.h> (not ATen/cuda/CUDAGuard.h), matching the pattern used elsewhere in the codebase.

Proposed fix
+#include <c10/cuda/CUDAGuard.h>
...
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    at::cuda::CUDAGuard deviceGuard{static_cast<signed char>(ssm.get_device())};
+    auto const stream = at::cuda::getCurrentCUDAStream().stream();
     tk::invokeMamba2MTPSSMCacheUpdate(params, stream);
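The RAII shape of the proposed guard — set the current device, restore the previous one on scope exit — can be mimicked in Python with a context manager (purely illustrative; the real fix uses c10::cuda::CUDAGuard):

```python
from contextlib import contextmanager

@contextmanager
def device_guard(state: dict, device: int):
    # Set the "current device" for the enclosed block and restore the
    # previous value on exit, like c10::cuda::CUDAGuard's RAII behavior.
    prev = state["current"]
    state["current"] = device
    try:
        yield
    finally:
        state["current"] = prev
```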
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/thop/mamba2MTPSSMCacheOp.cpp` around lines 209 - 210, The
code calls at::cuda::getCurrentCUDAStream() without setting the device, which
can pick the thread-local device instead of the device associated with params;
fix by including <c10/cuda/CUDAGuard.h> and creating a c10::cuda::CUDAGuard
guard(params.device()) (or the appropriate device accessor from params)
immediately before calling at::cuda::getCurrentCUDAStream(), then get the stream
and call tk::invokeMamba2MTPSSMCacheUpdate(params, stream) so the kernel is
launched on the correct device.
