
[TRTLLM-11289][feat] Integrate CuteDSL's bf16 dense GEMMs #12074

Open
peaceh-nv wants to merge 10 commits into NVIDIA:main from peaceh-nv:peaceh-bf16-gemm

Conversation

@peaceh-nv
Collaborator

@peaceh-nv peaceh-nv commented Mar 10, 2026

Add a CuTe DSL BF16 persistent GEMM kernel as an alternative BMM implementation for MLA (Multi-head Latent Attention) on Blackwell GPUs. The feature is gated behind the use_cute_dsl_bf16_bmm flag and is_sm_100f(), so it has zero impact on existing code paths when disabled.

Integration:

  • CuteDSLBf16BlackwellBmmRunner + trtllm::cute_dsl_bf16_bmm_blackwell op in cute_dsl_custom_ops.py with AutoTuner tactic selection.
  • use_cute_dsl_bf16_bmm config plumbed through LlmArgs -> ModelConfig -> model_loader -> MLA attention (6 BMM call sites: k_b_proj and v_b_proj in generation, context, and sparse-MLA paths).
  • --use_cute_dsl_bf16_bmm CLI flag in quickstart_advanced.py.
  • Integration tests: single-GPU and 4-GPU (tp4/ep4) accuracy tests with GSM8K for DeepSeek-V3-Lite BF16 in test_llm_api_pytorch.py.

Non-contiguous tensor handling: the runner makes inputs contiguous before extracting data pointers since the kernel layout assumes contiguous [B,M,K].
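As a minimal sketch of that handling, here is the same copy-then-take-pointer pattern in NumPy (the helper name `to_packed_ptr` is illustrative; the actual runner operates on torch tensors and CUDA device pointers):

```python
import numpy as np

def to_packed_ptr(x):
    # Illustrative helper: the kernel assumes a packed [B, M, K] layout,
    # so copy any non-contiguous input before extracting its data pointer.
    if not x.flags['C_CONTIGUOUS']:
        x = np.ascontiguousarray(x)
    return x, x.ctypes.data

a = np.zeros((2, 8, 4), dtype=np.float32)
view = a.transpose(0, 2, 1)        # non-contiguous [B, K, M] view
packed, ptr = to_packed_ptr(view)  # packed copy, safe to hand to a kernel
assert packed.flags['C_CONTIGUOUS']
```

Note that the returned `packed` array must be kept alive for as long as the kernel may read through the pointer, since the copy owns the storage.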

Perf result on GB200 (1k/1k, 1 ctx + 2 gen, DEP8, bs512, DeepSeek-FP4), tok/s per user:

  • native bf16 bmm + gemm: 25.04 tok/s
  • cute dsl bf16 bmm + cute dsl bf16 gemm: 24.47 tok/s
  • cute dsl bf16 bmm + cublas bf16 gemm: 25.26 tok/s
  • cublas bf16 bmm + cute dsl bf16 gemm: 25.25 tok/s

Summary by CodeRabbit

  • New Features

    • Added CuTe DSL BF16 persistent batched matrix multiply optimization for Blackwell GPUs.
    • Introduced command-line flag to enable/disable the new optimization.
    • Integrated feature into attention and multi-head latent attention modules with automatic fallback for non-Blackwell hardware.
  • Tests

    • Added integration tests for the new optimization on single and multi-GPU configurations.

@coderabbitai
Contributor

coderabbitai bot commented Mar 10, 2026

📝 Walkthrough

This PR introduces CuTe DSL-based BF16 persistent batched matrix multiplication support for Blackwell GPUs. Changes include a new CLI flag, a persistent GEMM kernel implementation, a custom Torch operation for execution, model configuration updates to enable the feature, and integration into attention modules with conditional dispatch logic.

Changes

  • Configuration & CLI (examples/llm-api/quickstart_advanced.py, tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/model_config.py): Added new use_cute_dsl_bf16_bmm flag to CLI arguments, LLM args configuration, and model config dataclass to control feature enablement.
  • Persistent GEMM Kernel (tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py): Implemented new PersistentDenseGemmKernel class with multi-stage GEMM orchestration, TMA load/store handling, cooperative group synchronization, epilogue logic, and comprehensive validation utilities (dtype, alignment, tiling shape checks).
  • Custom Operation Integration (tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py): Registered new Torch custom op trtllm::cute_dsl_bf16_bmm_blackwell with CuteDSLBf16BlackwellBmmRunner kernel runner, autotuning support, TVM-FFI path, and fake op registration for compatibility.
  • Attention Module Integration (tensorrt_llm/_torch/modules/attention.py): Added conditional dispatch in Attention and MLA classes to use the new cute_dsl_bf16_bmm_blackwell operation for BF16 BMM operations on SM 100f devices, with fallback to existing paths when unavailable.
  • Model Loader (tensorrt_llm/_torch/pyexecutor/model_loader.py): Propagated the new use_cute_dsl_bf16_bmm flag through model loading config initialization.
  • Integration Tests (tests/integration/defs/accuracy/test_llm_api_pytorch.py): Added two new test cases covering single-GPU and 4-GPU distributed scenarios for CuTe DSL BF16 BMM with GSM8K evaluation.

Sequence Diagram

sequenceDiagram
    participant User as User/CLI
    participant Config as ModelConfig
    participant Attention as Attention Module
    participant Op as Custom Op<br/>(cute_dsl_bf16_bmm_blackwell)
    participant Runner as CuteDSLBf16<br/>BlackwellBmmRunner
    participant Kernel as PersistentDenseGemmKernel

    User->>Config: Create model with<br/>use_cute_dsl_bf16_bmm=True
    Config->>Attention: Initialize with feature flag
    Attention->>Attention: Store use_cute_dsl_bf16_bmm config
    
    Note over Attention: During forward pass (BF16 BMM needed)
    Attention->>Attention: Check if use_cute_dsl_bf16_bmm<br/>and is_sm_100f()
    alt Feature Enabled & Compatible Hardware
        Attention->>Op: Call cute_dsl_bf16_bmm_blackwell<br/>(input, weight, output)
        Op->>Op: Validate SM compatibility
        Op->>Op: Select autotuner tactic
        Op->>Runner: Execute with tactic
        Runner->>Runner: Prepare pointers &<br/>compile kernel
        Runner->>Kernel: Launch kernel
        Kernel->>Kernel: Perform persistent GEMM<br/>with TMA/epilogue
        Kernel-->>Runner: Return result
        Runner-->>Op: Buffered output
        Op-->>Attention: BF16 result
    else Fall Back
        Attention->>Attention: Use existing bmm_out path
    end
    Attention-->>User: Compute output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the PR title clearly and concisely summarizes the main feature: integrating CuteDSL's BF16 dense GEMMs as an alternative BMM implementation. The title is specific, follows the required format with JIRA ID and feature type, and accurately reflects the primary change.
  • Description check — ✅ Passed: the PR description is detailed and comprehensive, covering motivation, key integration points, implementation details, and performance results, well-aligned with the template structure.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@peaceh-nv peaceh-nv requested a review from liyuhannnnn March 10, 2026 07:04
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🧹 Nitpick comments (2)
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (1)

323-324: For new imports, follow the module-namespace rule.

This continues the direct-class import pattern, but the repo guideline for Python imports is to keep the module namespace. Please import dense_gemm_persistent and reference dense_gemm_persistent.PersistentDenseGemmKernel from kernel_class instead.

♻️ Proposed refactor
-    from ..cute_dsl_kernels.blackwell.dense_gemm_persistent import \
-        PersistentDenseGemmKernel
+    from ..cute_dsl_kernels.blackwell import dense_gemm_persistent
...
-        kernel_class = PersistentDenseGemmKernel
+        kernel_class = dense_gemm_persistent.PersistentDenseGemmKernel

As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."

Also applies to: 3750-3750

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py` around lines 323 -
324, The import pulls PersistentDenseGemmKernel directly; change it to import
the module and reference the class via the module namespace: replace the direct
class import of PersistentDenseGemmKernel from
..cute_dsl_kernels.blackwell.dense_gemm_persistent with an import of the module
(dense_gemm_persistent) and update any usages (e.g., kernel_class) to use
dense_gemm_persistent.PersistentDenseGemmKernel; apply the same module-namespace
change for the other occurrence referenced in the comment.
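For illustration only, the same module-namespace convention with a stdlib module (the real imports in question are the CuTe DSL kernel modules named above):

```python
# Preferred: import the module and keep the namespace at the call site.
from os import path

joined = path.join("a", "b")  # qualified access makes the origin obvious
print(joined)

# Flagged pattern (pulls the symbol into the local namespace):
# from os.path import join
```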
tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py (1)

54-58: Prefer module-qualified imports here.

The direct symbol imports on Lines 54-58 make this already DSL-heavy module harder to trace and go against the repo's Python import rule. Please import the modules and qualify the call sites instead of pulling the symbols into the local namespace.

As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions (e.g., use from package.subpackage import foo then foo.SomeClass() instead of from package.subpackage.foo import SomeClass)."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py`
around lines 54 - 58, The file currently uses direct symbol imports
(pipeline_init_arrive, pipeline_init_wait, PipelineTmaUmma, PipelineUmmaAsync,
TRTLLM_ENABLE_PDL, griddepcontrol_launch_dependents, griddepcontrol_wait,
is_power_of_2); change these to module-qualified imports and update all call
sites to use the module namespace (e.g., import cutlass.pipeline as pipeline and
call pipeline.pipeline_init_arrive/pipeline.pipeline_init_wait; import the
custom_pipeline module and reference custom_pipeline.PipelineTmaUmma and
custom_pipeline.PipelineUmmaAsync; import utils as utils and reference
utils.TRTLLM_ENABLE_PDL, utils.griddepcontrol_launch_dependents,
utils.griddepcontrol_wait, utils.is_power_of_2) so the DSL-heavy symbols remain
namespaced and follow the repo's import rule.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2a192f6b-07d3-4e7f-9aa0-014efdf36cc8

📥 Commits

Reviewing files that changed from the base of the PR and between 39d294b and 8454f55.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
  • tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py
  • tensorrt_llm/_torch/model_config.py
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py

Comment on lines +3858 to +3864
            # For the output, use a contiguous buffer so TMA store sees a
            # standard layout; copy back afterwards if the original was
            # non-contiguous.
            c_needs_copy = not c_tensor.is_contiguous()
            if c_needs_copy:
                c_buf = torch.empty_like(c_tensor)
            else:
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

head -3870 tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py | tail -20

🏁 Script executed:

sed -n '3850,3880p' tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

🏁 Script executed:

cat -n tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py | sed -n '3850,3880p'

🏁 Script executed:

# Check PyTorch documentation and verify torch.empty_like behavior
python3 << 'EOF'
import torch

# Test 1: Does empty_like preserve strides?
base_tensor = torch.randn(2, 3, 4)
transposed = base_tensor.transpose(0, 1)
print("Original contiguous:", base_tensor.is_contiguous())
print("Transposed contiguous:", transposed.is_contiguous())
print("Transposed strides:", transposed.stride())

# Create empty_like from non-contiguous tensor
empty = torch.empty_like(transposed)
print("empty_like(transposed) contiguous:", empty.is_contiguous())
print("empty_like(transposed) strides:", empty.stride())

# Test 2: With memory_format parameter
empty_contiguous = torch.empty_like(transposed, memory_format=torch.contiguous_format)
print("empty_like with contiguous_format contiguous:", empty_contiguous.is_contiguous())
print("empty_like with contiguous_format strides:", empty_contiguous.stride())
EOF

🏁 Script executed:

# Search for how this function is used with transposed tensors
rg -A 3 -B 3 "cute_dsl_custom_ops" tensorrt_llm --type py | head -50

🏁 Script executed:

# Find call sites that might pass transposed tensors
rg "\.transpose\(" tensorrt_llm/_torch --type py -B 2 -A 2 | grep -A 5 -B 5 "cute_dsl" | head -40

🏁 Script executed:

# Check PyTorch documentation references in codebase
rg "memory_format" tensorrt_llm --type py -B 2 -A 2 | head -60

🏁 Script executed:

# Check if there are other uses of empty_like with memory_format in the codebase
rg "empty_like.*memory_format" tensorrt_llm --type py

🏁 Script executed:

# Check the surrounding code context to understand better what happens with c_tmp
sed -n '3858,3900p' tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

🏁 Script executed:

# Look for contiguous() calls to see the pattern used elsewhere
rg "\.contiguous\(" tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py -B 2 -A 2 | head -60


Use memory_format=torch.contiguous_format with torch.empty_like() to guarantee contiguous buffer for TMA store.

torch.empty_like(c_tensor) preserves the stride pattern of c_tensor, so when non-contiguous views (from .transpose() operations) are passed in, c_buf remains non-contiguous despite the comment's intent. This violates the assumption for the CuTe TMA store operation, which expects standard [B, M, N] layout.

The fix matches the established pattern elsewhere in the codebase (e.g., mamba and FLA backends):

             if c_needs_copy:
-                c_buf = torch.empty_like(c_tensor)
+                c_buf = torch.empty_like(
+                    c_tensor,
+                    memory_format=torch.contiguous_format,
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py` around lines 3858 -
3864, The output buffer creation incorrectly uses torch.empty_like(c_tensor)
which preserves c_tensor's stride pattern; modify the c_buf allocation in the
block guarded by c_needs_copy (variables c_needs_copy, c_buf, c_tensor) to
request a contiguous layout by calling torch.empty_like with
memory_format=torch.contiguous_format so the CuTe TMA store sees a standard [B,
M, N] layout; keep the rest of the copy-back logic unchanged.
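The stride-preserving default can be reproduced with NumPy's analogous empty_like, whose order='K' default mirrors torch's preserve_format (a sketch of the behavior, not the PR code):

```python
import numpy as np

t = np.zeros((3, 4)).T               # non-contiguous (F-ordered) view
buf = np.empty_like(t)               # order='K' default keeps t's layout
buf_c = np.empty_like(t, order='C')  # explicitly request a C-contiguous buffer

assert not buf.flags['C_CONTIGUOUS']  # layout was preserved, still non-contiguous
assert buf_c.flags['C_CONTIGUOUS']    # explicit request fixes it
```

This is exactly the failure mode described above: allocating "like" a transposed view silently reproduces its strides unless a contiguous layout is requested explicitly.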

Comment on lines +1 to +14
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Contributor

⚠️ Potential issue | 🟡 Minor

Update the Apache header year to 2026.

This file is being added/meaningfully modified in this March 2026 PR, so the repo header should carry the latest modification year. Please bump the Apache header on Line 1 to 2026; the upstream BSD notice can remain as attribution if needed.

As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification. The header should be an Apache 2.0 license block as specified."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py`
around lines 1 - 14, Update the Apache header year by changing the copyright
line "Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES." to "Copyright (c)
2026 NVIDIA CORPORATION & AFFILIATES." at the top of the file (the Apache 2.0
header block in dense_gemm_persistent.py) so the header reflects the latest
meaningful modification year.

Comment on lines +116 to +125
    def __init__(
        self,
        acc_dtype: Type[cutlass.Numeric],
        use_2cta_instrs: bool,
        mma_tiler_mn: Tuple[int, int],
        cluster_shape_mn: Tuple[int, int],
        use_tma_store: bool = True,
        swizzle_size: int = 1,
        raster_along: Literal["m", "n"] = "m",
    ):
Contributor

⚠️ Potential issue | 🟠 Major

Fail fast when use_tma_store is false.

The constructor exposes a non-TMA epilogue mode, but the rest of the implementation still assumes TMA store everywhere: _compute_stages() is given c_smem_layout=None, and the epilogue later hard-asserts self.use_tma_store on Line 659. A false value therefore creates a kernel object that cannot launch successfully.

🚧 Suggested guard
     def __init__(
         self,
         acc_dtype: Type[cutlass.Numeric],
         use_2cta_instrs: bool,
         mma_tiler_mn: Tuple[int, int],
         cluster_shape_mn: Tuple[int, int],
         use_tma_store: bool = True,
         swizzle_size: int = 1,
         raster_along: Literal["m", "n"] = "m",
     ):
+        if not use_tma_store:
+            raise NotImplementedError(
+                "PersistentDenseGemmKernel currently requires use_tma_store=True"
+            )
         self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py`
around lines 116 - 125, The constructor (__init__) allows use_tma_store=False
but the rest of the implementation (e.g., _compute_stages called with
c_smem_layout=None and the epilogue later asserting self.use_tma_store) requires
TMA; add a fail-fast check in __init__ that raises a clear exception (ValueError
or AssertionError) when use_tma_store is False to prevent constructing an
unusable kernel object; reference __init__, _compute_stages, and the epilogue
that asserts self.use_tma_store so reviewers can locate the guard and ensure the
invalid configuration is rejected early.
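The fail-fast pattern suggested above, as a self-contained sketch (the class name and message are illustrative stand-ins, not the real kernel):

```python
class GemmKernelConfig:
    """Illustrative stand-in for a kernel constructor that only supports
    one epilogue mode."""

    def __init__(self, use_tma_store: bool = True):
        # Reject unsupported configurations at construction time rather
        # than failing later in stage computation or the epilogue assert.
        if not use_tma_store:
            raise NotImplementedError(
                "kernel currently requires use_tma_store=True")
        self.use_tma_store = use_tma_store

try:
    GemmKernelConfig(use_tma_store=False)
except NotImplementedError as e:
    print("rejected early:", e)
```

Raising in __init__ moves the error to the point where the bad flag is supplied, which is far easier to debug than a deep assertion during kernel launch.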

Comment on lines +2215 to 2223
        if self.use_cute_dsl_bf16_bmm and is_sm_100f():
            bmm_fn = lambda: torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell(
                q_nope_t, self.k_b_proj_trans, q_nope_out)
        else:
            bmm_fn = lambda: torch.ops.trtllm.bmm_out(
                q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out)
        maybe_execute_in_parallel(
            lambda: torch.ops.trtllm.bmm_out(
                q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out),
            bmm_fn,
            lambda: self.mqa.mla_rope_generation(
Contributor

⚠️ Potential issue | 🟠 Major

Guard the CuTe BF16 path on op registration, not just SM.

cute_dsl_bf16_bmm_blackwell is only registered when IS_CUTLASS_DSL_AVAILABLE is true in tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py. If a user enables use_cute_dsl_bf16_bmm on an SM100 machine without that build-time support, these branches will raise a missing torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell attribute instead of taking the existing bmm_out fallback. Please gate this behind a shared availability predicate, or disable the flag up front.

Also applies to: 2319-2326, 2377-2383, 2440-2447, 2510-2516, 2594-2601

🧰 Tools
🪛 Ruff (0.15.5)

[error] 2216-2217: Do not assign a lambda expression, use a def

Rewrite bmm_fn as a def

(E731)


[error] 2219-2220: Do not assign a lambda expression, use a def

Rewrite bmm_fn as a def

(E731)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/attention.py` around lines 2215 - 2223, The
branch using CuTe BF16 bmm must be guarded by op-registration availability, not
only by SM check: update the condition that currently reads "if
self.use_cute_dsl_bf16_bmm and is_sm_100f()" to also verify the CuTe DSL op
exists (e.g., check a module-level flag like IS_CUTLASS_DSL_AVAILABLE or use
hasattr(torch.ops.trtllm, 'cute_dsl_bf16_bmm_blackwell')) so that when the op is
not registered the code falls back to torch.ops.trtllm.bmm_out; apply the same
change for the other occurrences noted (the blocks using
cute_dsl_bf16_bmm_blackwell at the locations around lines 2319-2326, 2377-2383,
2440-2447, 2510-2516, 2594-2601) and ensure maybe_execute_in_parallel still
receives the correct bmm_fn fallback.
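The availability-guarded dispatch the review asks for can be sketched as follows (pure-Python stand-ins for torch.ops.trtllm; the helper name and namespace object are illustrative):

```python
from types import SimpleNamespace

def select_bmm_fn(ops, use_cute_dsl_bf16_bmm, sm_100f):
    # Prefer the CuTe DSL op only when the flag is set, the GPU is SM100f,
    # AND the op is actually registered; otherwise fall back to bmm_out.
    cute_op = getattr(ops, "cute_dsl_bf16_bmm_blackwell", None)
    if use_cute_dsl_bf16_bmm and sm_100f and cute_op is not None:
        return cute_op
    return ops.bmm_out

# Op missing (build without CUTLASS DSL): falls back despite flag + SM match.
ops = SimpleNamespace(bmm_out=lambda *args: "bmm_out")
assert select_bmm_fn(ops, True, True) is ops.bmm_out

# Op registered and conditions met: CuTe path is chosen.
ops.cute_dsl_bf16_bmm_blackwell = lambda *args: "cute"
assert select_bmm_fn(ops, True, True) is ops.cute_dsl_bf16_bmm_blackwell
```

The getattr-with-default check is the key point: it turns a missing op registration into a silent fallback instead of an AttributeError at forward time.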

Comment on lines +1753 to +1769
@skip_pre_blackwell
@parametrize_with_ids("cuda_graph", [False, True])
def test_cute_dsl_bf16_bmm(self, cuda_graph):
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
    pytorch_config = dict(
        disable_overlap_scheduler=True,
        cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
        use_cute_dsl_bf16_bmm=True,
    )

    with LLM(
            self.MODEL_PATH,
            kv_cache_config=kv_cache_config,
            **pytorch_config,
    ) as llm:
        task = GSM8K(self.MODEL_NAME)
        task.evaluate(llm)
Contributor

⚠️ Potential issue | 🟡 Minor

Gate these tests to SM100f so they actually hit the CuteDSL BF16 path.

This feature is runtime-gated by is_sm_100f(). With only @skip_pre_blackwell, these tests can still pass on non-SM100 Blackwell parts via the fallback BMM path, so they don't reliably validate the new kernel.

Suggested fix
     @skip_pre_blackwell
     @parametrize_with_ids("cuda_graph", [False, True])
     def test_cute_dsl_bf16_bmm(self, cuda_graph):
+        if not is_sm_100f():
+            pytest.skip("CuteDSL BF16 BMM is only exercised on SM100f")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         pytorch_config = dict(
             disable_overlap_scheduler=True,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             use_cute_dsl_bf16_bmm=True,
@@
     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
     @parametrize_with_ids("cuda_graph", [False, True])
     @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 1), (4, 1, 4)],
                              ids=["tp4", "ep4"])
     def test_cute_dsl_bf16_bmm_4gpus(self, tp_size, pp_size, ep_size,
                                      cuda_graph):
+        if not is_sm_100f():
+            pytest.skip("CuteDSL BF16 BMM is only exercised on SM100f")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         pytorch_config = dict(
             disable_overlap_scheduler=True,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             use_cute_dsl_bf16_bmm=True,

Also applies to: 1937-1960

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 1753 -
1769, The test test_cute_dsl_bf16_bmm is only decorated with `@skip_pre_blackwell`
but the CuteDSL BF16 path is runtime-gated by is_sm_100f(), so add a skip/guard
to ensure the test only runs on SM100f silicon (e.g., add a
pytest.mark.skipif(not is_sm_100f(), reason=...) or a guard that calls
is_sm_100f() before constructing the LLM) so the test actually exercises the
CuteDSL BF16 kernel; do the same change for the duplicate tests around lines
1937-1960 (same test names/logic) so all CuteDSL BF16 BMM tests are gated by
is_sm_100f().

⚠️ Potential issue | 🟡 Minor

Add the same memory gate used by the other single-GPU DeepSeek-V3-Lite BF16 tests.

The adjacent BF16 tests for this model already require at least 60 GB. Without that decorator here, this test is more likely to fail because the model does not fit, not because of the CuteDSL BMM path.

Suggested fix
+    @pytest.mark.skip_less_device_memory(60000)
     @skip_pre_blackwell
     @parametrize_with_ids("cuda_graph", [False, True])
     def test_cute_dsl_bf16_bmm(self, cuda_graph):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 1753 -
1769, Add the same GPU memory gate decorator used by the other single-GPU
DeepSeek-V3-Lite BF16 tests to the test_cute_dsl_bf16_bmm function so it
requires at least ~60 GB of GPU memory; place that decorator above
`@parametrize_with_ids` and reuse the existing helper/decorator used by adjacent
BF16 tests (so the test skips when the machine lacks the required memory),
leaving the rest of the test (KvCacheConfig, pytorch_config, LLM context and
GSM8K evaluation) unchanged.

@peaceh-nv peaceh-nv requested a review from a team as a code owner March 10, 2026 09:00
@peaceh-nv peaceh-nv requested a review from xxi-nv March 10, 2026 09:00
@peaceh-nv peaceh-nv requested review from a team as code owners March 16, 2026 05:50
@peaceh-nv peaceh-nv requested a review from brb-nv March 16, 2026 05:50
@peaceh-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #39369 [ run ] triggered by Bot. Commit: 33ac890 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39369 [ run ] completed with state FAILURE. Commit: 33ac890
/LLM/main/L0_MergeRequest_PR pipeline #30611 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@peaceh-nv peaceh-nv requested a review from a team as a code owner March 19, 2026 06:43
@peaceh-nv peaceh-nv requested a review from a team as a code owner March 23, 2026 05:47
@peaceh-nv peaceh-nv requested a review from suyoggupta March 23, 2026 05:47
@peaceh-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #39887 [ run ] triggered by Bot. Commit: 2cbe5da Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39887 [ run ] completed with state SUCCESS. Commit: 2cbe5da
/LLM/main/L0_MergeRequest_PR pipeline #31056 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@peaceh-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #40120 [ run ] triggered by Bot. Commit: 2cbe5da Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40120 [ run ] completed with state FAILURE. Commit: 2cbe5da
/LLM/main/L0_MergeRequest_PR pipeline #31267 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@peaceh-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #40143 [ run ] triggered by Bot. Commit: 2cbe5da Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40143 [ run ] completed with state SUCCESS. Commit: 2cbe5da
/LLM/main/L0_MergeRequest_PR pipeline #31289 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

peaceh-nv and others added 9 commits March 29, 2026 18:47
Add a CuTe DSL BF16 persistent GEMM kernel as an alternative BMM
implementation for MLA (Multi-head Latent Attention) on Blackwell GPUs.
Gated behind the `use_cute_dsl_bf16_bmm` flag and `is_sm_100f()` so it
has zero impact on existing code paths when disabled.

New files:
- dense_gemm_persistent.py: Blackwell SM100 warp-specialized kernel with
  TMA loads, TMEM accumulators, and TMA store epilogue. Adapted from
  CUTLASS example with API compatibility fixes for the installed DSL.

Integration:
- CuteDSLBf16BlackwellBmmRunner + trtllm::cute_dsl_bf16_bmm_blackwell op
  in cute_dsl_custom_ops.py with AutoTuner tactic selection.
- use_cute_dsl_bf16_bmm config plumbed through LlmArgs -> ModelConfig ->
  model_loader -> MLA attention (6 BMM call sites: k_b_proj and v_b_proj
  in generation, context, and sparse-MLA paths).
- --use_cute_dsl_bf16_bmm CLI flag in quickstart_advanced.py.
- Integration tests: single-GPU and 4-GPU (tp4/ep4) accuracy tests with
  GSM8K for DeepSeek-V3-Lite BF16 in test_llm_api_pytorch.py.

Non-contiguous tensor handling: the runner makes inputs contiguous before
extracting data pointers since the kernel layout assumes contiguous [B,M,K].

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
When mma_inst_tile_k > 1, cute.gemm() generates multiple sub-MMA
instructions that all share the same ACCUMULATE flag. With
ACCUMULATE=False on the first K tile, every sub-MMA cleared the
accumulator so only the last sub-MMA's result survived, losing
(mma_inst_tile_k - 1) * mma_inst_shape_k elements per output tile.

This caused GSM8K accuracy to drop from 64.7% to 28.5%.

Fix by adding an inner kblock loop that iterates sub-MMA instructions
individually and sets ACCUMULATE=True after the first cute.gemm() call,
matching the pattern used by blockscaled_contiguous_grouped_gemm.py.

GSM8K accuracy restored to 64.86% (reference: 64.74%).

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
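The accumulate-flag bug described in the commit above can be illustrated in plain NumPy. This is a hypothetical simulation of the control flow, not the CuTe DSL API: each K tile is computed as `n_sub` sub-MMA instructions, and the buggy variant flips the ACCUMULATE flag only once per K tile, so every sub-MMA of the first tile overwrites the accumulator and only the last one survives:

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, K = 4, 4, 8
n_k_tiles, n_sub = 2, 2  # K split into tiles, each tile into sub-MMAs
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))


def run(fix_per_sub):
    acc = np.zeros((M, N))
    accumulate = False
    for a_t, b_t in zip(np.split(A, n_k_tiles, axis=1),
                        np.split(B, n_k_tiles, axis=0)):
        for a_s, b_s in zip(np.split(a_t, n_sub, axis=1),
                            np.split(b_t, n_sub, axis=0)):
            # ACCUMULATE=False clears the accumulator; True adds into it.
            acc = a_s @ b_s if not accumulate else acc + a_s @ b_s
            if fix_per_sub:
                accumulate = True  # fixed: flip after the first sub-MMA
        accumulate = True          # buggy: flag shared across sub-MMAs,
                                   # flipped only per K tile
    return acc


ref = A @ B
assert np.allclose(run(True), ref)       # fixed path matches the reference
assert not np.allclose(run(False), ref)  # buggy path drops partial sums
```

The fixed variant mirrors the commit's inner kblock loop: the flag is set to True immediately after the first gemm call, so later sub-MMAs accumulate instead of clearing.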
…kwell

Add use_cute_dsl_bf16_gemm flag to enable CuTe DSL BF16 persistent GEMM
for unquantized Linear layers in MLA attention (kv_a_proj_with_mqa,
q_b_proj, kv_b_proj). This complements the existing BF16 BMM support.

Changes:
- Add CuteDSLBf16BlackwellGemmRunner class and custom op in cute_dsl_custom_ops.py
- Add use_cute_dsl_bf16_gemm parameter to Linear class and UnquantizedLinearMethod
- Wire use_cute_dsl_bf16_gemm through ModelConfig, LlmArgs, and model_loader
- Pass flag to MLA Linear layers in attention.py
- Add --use_cute_dsl_bf16_gemm CLI argument to quickstart_advanced.py
- Add integration tests for single GPU and 4 GPU configurations

Signed-off-by: Pei He <peih@nvidia.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
…GEMM (FP32 output)

Enable CuTe DSL BF16 GEMM kernel for DeepseekV3Gate router GEMM on Blackwell.
The router computes BF16 input @ BF16 weight -> FP32 logits, which our
persistent GEMM kernel already supports via FP32 accumulator and FP32 output.

Key changes:
- Support FP32 output dtype in CuteDSLBf16BlackwellGemmRunner (detect from
  output tensor instead of hardcoding BF16, add c_dtype to kernel cache key)
- Relax cute_dsl_bf16_gemm_blackwell custom op to accept BF16 or FP32 output
- Add CuTe DSL dispatch in DeepseekV3Gate.forward() gated by
  use_cute_dsl_bf16_gemm flag, with fallback to dsv3_router_gemm_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
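The commit above relies on the router computing BF16 input @ BF16 weight with an FP32 accumulator and FP32 output. NumPy has no bfloat16, so this sketch uses FP16 as a stand-in (an assumption, purely illustrative) to show why a low-precision accumulator would degrade the logits while an FP32 accumulator stays close to the exact dot product:

```python
import numpy as np

rng = np.random.default_rng(0)
K = 4096
a = rng.standard_normal(K).astype(np.float16)
b = rng.standard_normal(K).astype(np.float16)

# Low-precision accumulator: rounding error compounds at every add.
acc16 = np.float16(0)
for x, y in zip(a, b):
    acc16 = np.float16(acc16 + np.float16(x * y))

# FP32 accumulator with FP32 output, as the router GEMM uses.
acc32 = np.float32(np.sum(a.astype(np.float32) * b.astype(np.float32)))

# Float64 reference for measuring the error of each accumulator.
ref = float(np.dot(a.astype(np.float64), b.astype(np.float64)))
assert abs(acc32 - ref) < abs(float(acc16) - ref)
```

The FP16-to-FP32 conversion and the per-element products are exact here; the difference between the two results is purely accumulation error, which is the quantity the FP32 accumulator path eliminates.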
… path

Add wrapper_strided to PersistentDenseGemmKernel that accepts explicit A
tensor strides, enabling non-contiguous views (e.g. from .transpose()) to
be passed directly to TMA without .contiguous() copies. Update the BMM
runner to compute and pass A strides instead of forcing contiguous tensors,
removing the direct_copy_kernel_cuda overhead between attention and BMM.

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
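The stride-passing idea in the commit above can be sketched with NumPy (an analogy, not the wrapper_strided API): a `.transpose()` view is non-contiguous but fully described by its strides, so a kernel that accepts explicit strides can walk the original buffer directly instead of paying for a `.contiguous()` copy:

```python
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
xt = x.T  # transpose is a strided view over the same buffer, no copy
assert not xt.flags["C_CONTIGUOUS"]

# Element strides (in elements, not bytes) that a strided kernel wrapper
# could consume directly instead of forcing a contiguous copy:
elem_strides = tuple(s // x.itemsize for s in xt.strides)
assert elem_strides == (1, 3)  # column-major walk over the row-major buffer

# Indexing through the strides reproduces the transposed element xt[2, 1]:
flat = x.ravel()
val = flat[2 * elem_strides[0] + 1 * elem_strides[1]]
assert val == xt[2, 1]
```

This is exactly the trade the BMM runner makes: compute the A strides once and hand them to the kernel, removing the direct_copy overhead between attention and BMM.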
…ction_core

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
…ed IDs

The previous entries lacked pytest parameter brackets, which wouldn't
match actual test node IDs. Expand to all 12 parametrized variants.

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
…BMM code

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Apply ruff format/lint fixes:
- Convert multi-line docstrings to single-line where appropriate (D200)
- Remove f-string prefix on strings without placeholders (F541)
- Remove unused import
- Use consistent double-quote docstrings instead of single-quotes
- Fix indentation in docstrings

Signed-off-by: Peace He <103117813+peaceh-nv@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@peaceh-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #40627 [ run ] triggered by Bot. Commit: 3e709a4 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40627 [ run ] completed with state SUCCESS. Commit: 3e709a4
/LLM/main/L0_MergeRequest_PR pipeline #31665 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@peaceh-nv
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #40691 [ run ] triggered by Bot. Commit: 3e709a4 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40691 [ run ] completed with state SUCCESS. Commit: 3e709a4
/LLM/main/L0_MergeRequest_PR pipeline #31720 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation
