[TRTLLM-11289][feat] Integrate CuteDSL's bf16 dense GEMMs #12074
peaceh-nv wants to merge 10 commits into NVIDIA:main from
Conversation
📝 Walkthrough

This PR introduces CuTe DSL-based BF16 persistent batched matrix multiplication support for Blackwell GPUs. Changes include a new CLI flag, a persistent GEMM kernel implementation, a custom Torch operation for execution, model configuration updates to enable the feature, and integration into attention modules with conditional dispatch logic.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User/CLI
    participant Config as ModelConfig
    participant Attention as Attention Module
    participant Op as Custom Op<br/>(cute_dsl_bf16_bmm_blackwell)
    participant Runner as CuteDSLBf16<br/>BlackwellBmmRunner
    participant Kernel as PersistentDenseGemmKernel
    User->>Config: Create model with<br/>use_cute_dsl_bf16_bmm=True
    Config->>Attention: Initialize with feature flag
    Attention->>Attention: Store use_cute_dsl_bf16_bmm config
    Note over Attention: During forward pass (BF16 BMM needed)
    Attention->>Attention: Check if use_cute_dsl_bf16_bmm<br/>and is_sm_100f()
    alt Feature Enabled & Compatible Hardware
        Attention->>Op: Call cute_dsl_bf16_bmm_blackwell<br/>(input, weight, output)
        Op->>Op: Validate SM compatibility
        Op->>Op: Select autotuner tactic
        Op->>Runner: Execute with tactic
        Runner->>Runner: Prepare pointers &<br/>compile kernel
        Runner->>Kernel: Launch kernel
        Kernel->>Kernel: Perform persistent GEMM<br/>with TMA/epilogue
        Kernel-->>Runner: Return result
        Runner-->>Op: Buffered output
        Op-->>Attention: BF16 result
    else Fall Back
        Attention->>Attention: Use existing bmm_out path
    end
    Attention-->>User: Compute output
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 6
🧹 Nitpick comments (2)
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (1)
323-324: For new imports, follow the module-namespace rule.

This continues the direct-class import pattern, but the repo guideline for Python imports is to keep the module namespace. Please import `dense_gemm_persistent` and reference `dense_gemm_persistent.PersistentDenseGemmKernel` from `kernel_class` instead.

♻️ Proposed refactor

```diff
-from ..cute_dsl_kernels.blackwell.dense_gemm_persistent import \
-    PersistentDenseGemmKernel
+from ..cute_dsl_kernels.blackwell import dense_gemm_persistent
 ...
-kernel_class = PersistentDenseGemmKernel
+kernel_class = dense_gemm_persistent.PersistentDenseGemmKernel
```

As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."

Also applies to: 3750-3750

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py (1)
54-58: Prefer module-qualified imports here.

The direct symbol imports on lines 54-58 make this already DSL-heavy module harder to trace and go against the repo's Python import rule. Please import the modules and qualify the call sites instead of pulling the symbols into the local namespace.

As per coding guidelines, "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions (e.g., use `from package.subpackage import foo` then `foo.SomeClass()` instead of `from package.subpackage.foo import SomeClass`)."
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 2a192f6b-07d3-4e7f-9aa0-014efdf36cc8
📒 Files selected for processing (8)

- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
- tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py
- tensorrt_llm/_torch/model_config.py
- tensorrt_llm/_torch/modules/attention.py
- tensorrt_llm/_torch/pyexecutor/model_loader.py
- tensorrt_llm/llmapi/llm_args.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
```python
# For the output, use a contiguous buffer so TMA store sees a
# standard layout; copy back afterwards if the original was
# non-contiguous.
c_needs_copy = not c_tensor.is_contiguous()
if c_needs_copy:
    c_buf = torch.empty_like(c_tensor)
else:
```
Use memory_format=torch.contiguous_format with torch.empty_like() to guarantee contiguous buffer for TMA store.
torch.empty_like(c_tensor) preserves the stride pattern of c_tensor, so when non-contiguous views (from .transpose() operations) are passed in, c_buf remains non-contiguous despite the comment's intent. This violates the assumption for the CuTe TMA store operation, which expects standard [B, M, N] layout.
The fix matches the established pattern elsewhere in the codebase (e.g., mamba and FLA backends):
```diff
 if c_needs_copy:
-    c_buf = torch.empty_like(c_tensor)
+    c_buf = torch.empty_like(
+        c_tensor,
+        memory_format=torch.contiguous_format,
+    )
```
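The stride-preservation behavior the comment above relies on can be illustrated without PyTorch; this pure-Python sketch (helper names are hypothetical) mimics how `torch.empty_like(view)` propagates the view's strides, whereas `memory_format=torch.contiguous_format` would request row-major strides for the same shape:

```python
def contiguous_strides(shape):
    """Row-major strides for a shape (innermost dim has stride 1)."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def transpose_view(shape, strides, d0, d1):
    """Swap two dims of a (shape, strides) view, as tensor.transpose does."""
    shape, strides = list(shape), list(strides)
    shape[d0], shape[d1] = shape[d1], shape[d0]
    strides[d0], strides[d1] = strides[d1], strides[d0]
    return tuple(shape), tuple(strides)


# A [B, M, N] = [2, 3, 4] output viewed through .transpose(1, 2):
shape, strides = (2, 3, 4), contiguous_strides((2, 3, 4))      # (12, 4, 1)
t_shape, t_strides = transpose_view(shape, strides, 1, 2)      # (2, 4, 3), (12, 1, 4)

# empty_like keeps the view's stride pattern -> still non-contiguous,
# while the contiguous_format request would yield row-major strides.
assert t_strides != contiguous_strides(t_shape)
assert contiguous_strides(t_shape) == (12, 3, 1)
```

The assertion failing to hold only for contiguous inputs is exactly why the `c_needs_copy` branch needs the explicit `memory_format` argument.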
```python
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
Update the Apache header year to 2026.
This file is being added/meaningfully modified in this March 2026 PR, so the repo header should carry the latest modification year. Please bump the Apache header on Line 1 to 2026; the upstream BSD notice can remain as attribution if needed.
As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification. The header should be an Apache 2.0 license block as specified."
```python
def __init__(
    self,
    acc_dtype: Type[cutlass.Numeric],
    use_2cta_instrs: bool,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    use_tma_store: bool = True,
    swizzle_size: int = 1,
    raster_along: Literal["m", "n"] = "m",
):
```
Fail fast when use_tma_store is false.
The constructor exposes a non-TMA epilogue mode, but the rest of the implementation still assumes TMA store everywhere: _compute_stages() is given c_smem_layout=None, and the epilogue later hard-asserts self.use_tma_store on Line 659. A false value therefore creates a kernel object that cannot launch successfully.
🚧 Suggested guard

```diff
 def __init__(
     self,
     acc_dtype: Type[cutlass.Numeric],
     use_2cta_instrs: bool,
     mma_tiler_mn: Tuple[int, int],
     cluster_shape_mn: Tuple[int, int],
     use_tma_store: bool = True,
     swizzle_size: int = 1,
     raster_along: Literal["m", "n"] = "m",
 ):
+    if not use_tma_store:
+        raise NotImplementedError(
+            "PersistentDenseGemmKernel currently requires use_tma_store=True"
+        )
     self.acc_dtype: Type[cutlass.Numeric] = acc_dtype
```
```python
if self.use_cute_dsl_bf16_bmm and is_sm_100f():
    bmm_fn = lambda: torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell(
        q_nope_t, self.k_b_proj_trans, q_nope_out)
else:
    bmm_fn = lambda: torch.ops.trtllm.bmm_out(
        q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out)
maybe_execute_in_parallel(
    bmm_fn,
    lambda: self.mqa.mla_rope_generation(
```
Guard the CuTe BF16 path on op registration, not just SM.
cute_dsl_bf16_bmm_blackwell is only registered when IS_CUTLASS_DSL_AVAILABLE is true in tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py. If a user enables use_cute_dsl_bf16_bmm on an SM100 machine without that build-time support, these branches will raise a missing torch.ops.trtllm.cute_dsl_bf16_bmm_blackwell attribute instead of taking the existing bmm_out fallback. Please gate this behind a shared availability predicate, or disable the flag up front.
Also applies to: 2319-2326, 2377-2383, 2440-2447, 2510-2516, 2594-2601
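The suggested guard can be sketched independently of torch; the flag name `IS_CUTLASS_DSL_AVAILABLE` and the op path come from the review above, while the helper below and the `SimpleNamespace` stand-ins for `torch.ops.trtllm` are hypothetical:

```python
from types import SimpleNamespace


def cute_bf16_bmm_available(ops_namespace, feature_enabled, on_sm_100f):
    """Dispatch to the CuTe DSL op only when the feature flag, the SM
    check, and the op registration all hold; otherwise callers fall back
    to the existing bmm_out path."""
    return (feature_enabled and on_sm_100f
            and hasattr(ops_namespace, "cute_dsl_bf16_bmm_blackwell"))


# Simulate torch.ops.trtllm with and without the op registered.
with_op = SimpleNamespace(cute_dsl_bf16_bmm_blackwell=object(), bmm_out=object())
without_op = SimpleNamespace(bmm_out=object())

assert cute_bf16_bmm_available(with_op, True, True)
assert not cute_bf16_bmm_available(without_op, True, True)   # falls back
assert not cute_bf16_bmm_available(with_op, False, True)     # flag off
```

The key point is that the `hasattr` check keeps a missing build-time registration from raising an `AttributeError` at dispatch time.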
🧰 Tools

🪛 Ruff (0.15.5)

[error] 2216-2217: Do not assign a lambda expression, use a def. Rewrite bmm_fn as a def (E731)

[error] 2219-2220: Do not assign a lambda expression, use a def. Rewrite bmm_fn as a def (E731)
```python
@skip_pre_blackwell
@parametrize_with_ids("cuda_graph", [False, True])
def test_cute_dsl_bf16_bmm(self, cuda_graph):
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
    pytorch_config = dict(
        disable_overlap_scheduler=True,
        cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
        use_cute_dsl_bf16_bmm=True,
    )

    with LLM(
            self.MODEL_PATH,
            kv_cache_config=kv_cache_config,
            **pytorch_config,
    ) as llm:
        task = GSM8K(self.MODEL_NAME)
        task.evaluate(llm)
```
Gate these tests to SM100f so they actually hit the CuteDSL BF16 path.
This feature is runtime-gated by is_sm_100f(). With only @skip_pre_blackwell, these tests can still pass on non-SM100 Blackwell parts via the fallback BMM path, so they don't reliably validate the new kernel.
Suggested fix

```diff
 @skip_pre_blackwell
 @parametrize_with_ids("cuda_graph", [False, True])
 def test_cute_dsl_bf16_bmm(self, cuda_graph):
+    if not is_sm_100f():
+        pytest.skip("CuteDSL BF16 BMM is only exercised on SM100f")
     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
     pytorch_config = dict(
         disable_overlap_scheduler=True,
         cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
         use_cute_dsl_bf16_bmm=True,
@@
 @pytest.mark.skip_less_device(4)
 @skip_pre_blackwell
 @parametrize_with_ids("cuda_graph", [False, True])
 @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 1), (4, 1, 4)],
                          ids=["tp4", "ep4"])
 def test_cute_dsl_bf16_bmm_4gpus(self, tp_size, pp_size, ep_size,
                                  cuda_graph):
+    if not is_sm_100f():
+        pytest.skip("CuteDSL BF16 BMM is only exercised on SM100f")
     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
     pytorch_config = dict(
         disable_overlap_scheduler=True,
         cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
         use_cute_dsl_bf16_bmm=True,
```

Also applies to: 1937-1960
Add the same memory gate used by the other single-GPU DeepSeek-V3-Lite BF16 tests.
The adjacent BF16 tests for this model already require at least 60 GB. Without that decorator here, this test is more likely to fail because the model does not fit, not because of the CuteDSL BMM path.
Suggested fix

```diff
+@pytest.mark.skip_less_device_memory(60000)
 @skip_pre_blackwell
 @parametrize_with_ids("cuda_graph", [False, True])
 def test_cute_dsl_bf16_bmm(self, cuda_graph):
```
Force-pushed 8454f55 to 748d50a

Force-pushed 748d50a to 8f624e9

Force-pushed 8f624e9 to 33ac890
/bot run

PR_Github #39369 [ run ] triggered by Bot. Commit:

PR_Github #39369 [ run ] completed with state

Force-pushed 33ac890 to 1353e16

/bot run

PR_Github #39887 [ run ] triggered by Bot. Commit:

PR_Github #39887 [ run ] completed with state

/bot run

PR_Github #40120 [ run ] triggered by Bot. Commit:

PR_Github #40120 [ run ] completed with state

/bot run

PR_Github #40143 [ run ] triggered by Bot. Commit:

PR_Github #40143 [ run ] completed with state

Force-pushed 2cbe5da to 2f2647d
Add a CuTe DSL BF16 persistent GEMM kernel as an alternative BMM implementation for MLA (Multi-head Latent Attention) on Blackwell GPUs. Gated behind the `use_cute_dsl_bf16_bmm` flag and `is_sm_100f()` so it has zero impact on existing code paths when disabled.

New files:
- dense_gemm_persistent.py: Blackwell SM100 warp-specialized kernel with TMA loads, TMEM accumulators, and TMA store epilogue. Adapted from CUTLASS example with API compatibility fixes for the installed DSL.

Integration:
- CuteDSLBf16BlackwellBmmRunner + trtllm::cute_dsl_bf16_bmm_blackwell op in cute_dsl_custom_ops.py with AutoTuner tactic selection.
- use_cute_dsl_bf16_bmm config plumbed through LlmArgs -> ModelConfig -> model_loader -> MLA attention (6 BMM call sites: k_b_proj and v_b_proj in generation, context, and sparse-MLA paths).
- --use_cute_dsl_bf16_bmm CLI flag in quickstart_advanced.py.
- Integration tests: single-GPU and 4-GPU (tp4/ep4) accuracy tests with GSM8K for DeepSeek-V3-Lite BF16 in test_llm_api_pytorch.py.

Non-contiguous tensor handling: the runner makes inputs contiguous before extracting data pointers since the kernel layout assumes contiguous [B,M,K].

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
When mma_inst_tile_k > 1, cute.gemm() generates multiple sub-MMA instructions that all share the same ACCUMULATE flag. With ACCUMULATE=False on the first K tile, every sub-MMA cleared the accumulator so only the last sub-MMA's result survived, losing (mma_inst_tile_k - 1) * mma_inst_shape_k elements per output tile. This caused GSM8K accuracy to drop from 64.7% to 28.5%.

Fix by adding an inner kblock loop that iterates sub-MMA instructions individually and sets ACCUMULATE=True after the first cute.gemm() call, matching the pattern used by blockscaled_contiguous_grouped_gemm.py. GSM8K accuracy restored to 64.86% (reference: 64.74%).

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
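The accumulator-clearing bug described in this commit can be reproduced with plain integers; this sketch models sub-MMA instructions sharing one ACCUMULATE flag (it is not the CuTe DSL API, and the function names are illustrative):

```python
def mma(acc, partial, accumulate):
    """Model one MMA instruction: accumulate into acc, or overwrite it."""
    return acc + partial if accumulate else partial


def k_tile_broken(partials):
    """Buggy: every sub-MMA of the first K tile sees ACCUMULATE=False,
    so each instruction overwrites the accumulator and only the last
    sub-MMA's contribution survives."""
    acc = 0
    for p in partials:
        acc = mma(acc, p, accumulate=False)
    return acc


def k_tile_fixed(partials):
    """Fixed: iterate sub-MMA instructions individually and flip
    ACCUMULATE to True after the first issue, as the inner kblock
    loop in the commit does."""
    acc = 0
    for i, p in enumerate(partials):
        acc = mma(acc, p, accumulate=(i > 0))
    return acc


partials = [3, 5, 7]                 # mma_inst_tile_k = 3 sub-MMA results
assert k_tile_broken(partials) == 7  # only the last sub-MMA survives
assert k_tile_fixed(partials) == 15  # full sum over the K tile
```

The broken variant silently drops the first `mma_inst_tile_k - 1` partial products, which is consistent with an accuracy regression rather than a crash.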
…kwell

Add use_cute_dsl_bf16_gemm flag to enable CuTe DSL BF16 persistent GEMM for unquantized Linear layers in MLA attention (kv_a_proj_with_mqa, q_b_proj, kv_b_proj). This complements the existing BF16 BMM support.

Changes:
- Add CuteDSLBf16BlackwellGemmRunner class and custom op in cute_dsl_custom_ops.py
- Add use_cute_dsl_bf16_gemm parameter to Linear class and UnquantizedLinearMethod
- Wire use_cute_dsl_bf16_gemm through ModelConfig, LlmArgs, and model_loader
- Pass flag to MLA Linear layers in attention.py
- Add --use_cute_dsl_bf16_gemm CLI argument to quickstart_advanced.py
- Add integration tests for single GPU and 4 GPU configurations

Signed-off-by: Pei He <peih@nvidia.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
…GEMM (FP32 output)

Enable CuTe DSL BF16 GEMM kernel for DeepseekV3Gate router GEMM on Blackwell. The router computes BF16 input @ BF16 weight -> FP32 logits, which our persistent GEMM kernel already supports via FP32 accumulator and FP32 output.

Key changes:
- Support FP32 output dtype in CuteDSLBf16BlackwellGemmRunner (detect from output tensor instead of hardcoding BF16, add c_dtype to kernel cache key)
- Relax cute_dsl_bf16_gemm_blackwell custom op to accept BF16 or FP32 output
- Add CuTe DSL dispatch in DeepseekV3Gate.forward() gated by use_cute_dsl_bf16_gemm flag, with fallback to dsv3_router_gemm_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
… path

Add wrapper_strided to PersistentDenseGemmKernel that accepts explicit A tensor strides, enabling non-contiguous views (e.g. from .transpose()) to be passed directly to TMA without .contiguous() copies. Update the BMM runner to compute and pass A strides instead of forcing contiguous tensors, removing the direct_copy_kernel_cuda overhead between attention and BMM.

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
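Why passing explicit strides avoids the copy can be shown with a toy flat buffer; this sketch (names hypothetical, not the wrapper_strided API) reads a transposed view in place by indexing through swapped strides instead of materializing a contiguous transpose:

```python
def element(flat, strides, idx):
    """Read one logical element from a flat buffer using explicit strides."""
    return flat[sum(i * s for i, s in zip(idx, strides))]


# A contiguous [M, K] = [2, 3] buffer stored row-major:
flat = [0, 1, 2, 3, 4, 5]            # A[m][k] = flat[m * 3 + k]

# The transpose A^T with shape [K, M] = [3, 2] is just the same buffer
# with swapped strides -- no data movement:
t_strides = (1, 3)                   # stride 1 along k, stride 3 along m

assert element(flat, t_strides, (2, 1)) == 5   # A^T[2][1] == A[1][2]
assert element(flat, t_strides, (0, 1)) == 3   # A^T[0][1] == A[1][0]
```

A kernel that accepts the stride tuple alongside the base pointer can consume the `.transpose()` view directly, which is the copy the commit removes.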
…ction_core

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
…ed IDs

The previous entries lacked pytest parameter brackets, which wouldn't match actual test node IDs. Expand to all 12 parametrized variants.

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
…BMM code

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Force-pushed a02c00b to 54fdce9
Apply ruff format/lint fixes:
- Convert multi-line docstrings to single-line where appropriate (D200)
- Remove f-string prefix on strings without placeholders (F541)
- Remove unused import
- Use consistent double-quote docstrings instead of single-quotes
- Fix indentation in docstrings

Signed-off-by: Peace He <103117813+peaceh-nv@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Force-pushed 54fdce9 to 3e709a4
/bot run

PR_Github #40627 [ run ] triggered by Bot. Commit:

PR_Github #40627 [ run ] completed with state

/bot run

PR_Github #40691 [ run ] triggered by Bot. Commit:

PR_Github #40691 [ run ] completed with state
Add a CuTe DSL BF16 persistent GEMM kernel as an alternative BMM implementation for MLA (Multi-head Latent Attention) on Blackwell GPUs. Gated behind the `use_cute_dsl_bf16_bmm` flag and `is_sm_100f()` so it has zero impact on existing code paths when disabled.

Integration:
Non-contiguous tensor handling: the runner makes inputs contiguous before extracting data pointers since the kernel layout assumes contiguous [B,M,K].
Perf result on GB200 1k/1k 1ctx + 2gen DEP8 bs512 DeepSeek-FP4 tps/user:

| Configuration | tok/s |
| --- | --- |
| native bf16 bmm + gemm | 25.04 |
| cute dsl bf16 bmm + cute dsl bf16 gemm | 24.47 |
| cute dsl bf16 bmm + cublas bf16 gemm | 25.26 |
| cublas bf16 bmm + cute dsl bf16 gemm | 25.25 |
Summary by CodeRabbit
New Features
Tests