Skip to content

Commit 9ca464b

Browse files
committed
[TRTLLM-11289][feat] Integrate CuteDSL's bf16 dense GEMMs
Add a CuTe DSL BF16 persistent GEMM kernel as an alternative BMM implementation for MLA (Multi-head Latent Attention) on Blackwell GPUs. Gated behind the `use_cute_dsl_bf16_bmm` flag and `is_sm_100f()` so it has zero impact on existing code paths when disabled. New files: - dense_gemm_persistent.py: Blackwell SM100 warp-specialized kernel with TMA loads, TMEM accumulators, and TMA store epilogue. Adapted from CUTLASS example with API compatibility fixes for the installed DSL. Integration: - CuteDSLBf16BlackwellBmmRunner + trtllm::cute_dsl_bf16_bmm_blackwell op in cute_dsl_custom_ops.py with AutoTuner tactic selection. - use_cute_dsl_bf16_bmm config plumbed through LlmArgs -> ModelConfig -> model_loader -> MLA attention (6 BMM call sites: k_b_proj and v_b_proj in generation, context, and sparse-MLA paths). - --use_cute_dsl_bf16_bmm CLI flag in quickstart_advanced.py. - Integration tests: single-GPU and 4-GPU (tp4/ep4) accuracy tests with GSM8K for DeepSeek-V3-Lite BF16 in test_llm_api_pytorch.py. Non-contiguous tensor handling: the runner makes inputs contiguous before extracting data pointers since the kernel layout assumes contiguous [B,M,K]. Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
1 parent 1fef88e commit 9ca464b

File tree

8 files changed

+1377
-17
lines changed

8 files changed

+1377
-17
lines changed

examples/llm-api/quickstart_advanced.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,12 @@ def add_llm_args(parser):
184184
parser.add_argument('--relaxed_topk', type=int, default=1)
185185
parser.add_argument('--relaxed_delta', type=float, default=0.)
186186

187+
# CuTe DSL
188+
parser.add_argument('--use_cute_dsl_bf16_bmm',
189+
default=False,
190+
action='store_true',
191+
help='Use CuTe DSL bf16 persistent GEMM for BMM on Blackwell.')
192+
187193
# HF
188194
parser.add_argument('--trust_remote_code',
189195
default=False,
@@ -311,6 +317,7 @@ def setup_llm(args, **kwargs):
311317
gather_generation_logits=args.return_generation_logits,
312318
max_beam_width=args.max_beam_width,
313319
orchestrator_type=args.orchestrator_type,
320+
use_cute_dsl_bf16_bmm=args.use_cute_dsl_bf16_bmm,
314321
**kwargs)
315322

316323
use_beam_search = args.max_beam_width > 1

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,8 @@ def get_dense_gemm_approximate_cta_nums(
320320
Sm100BlockScaledPersistentDenseGemmKernel
321321
from ..cute_dsl_kernels.blackwell.top_k.filtered_top_k_decode_varlen import \
322322
FilteredTopKKernelVarlenDecode
323+
from ..cute_dsl_kernels.blackwell.dense_gemm_persistent import \
324+
PersistentDenseGemmKernel
323325
from ..cute_dsl_kernels.blackwell.utils import make_ptr
324326

325327
class CuteDSLNVFP4BlackwellRunner(TunableRunner):
@@ -3739,3 +3741,273 @@ def warmup_cute_dsl_topk_kernels(
37393741

37403742
logger.info(f"Warmup: pre-compiled {count} CuTE DSL top-k kernels "
37413743
f"(dtype={dtype}, top_k={top_k}, next_n={next_n})")
3744+
3745+
# ======================================================================
3746+
# BF16 Dense Persistent BMM (CuTe DSL) for Blackwell
3747+
# ======================================================================
3748+
3749+
class CuteDSLBf16BlackwellBmmRunner(TunableRunner):
    """AutoTuner runner for the CuTe DSL BF16 persistent dense BMM on Blackwell.

    Wraps ``PersistentDenseGemmKernel`` (SM100 warp-specialized dense GEMM)
    behind the ``TunableRunner`` interface so the AutoTuner can select a
    tactic ``(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)``.

    Expected input layouts (dense, row-major within each batch):
        inputs[0]: A, shape (batch_size, m, k), bf16
        inputs[1]: B, shape (batch_size, n, k), bf16
        inputs[2]: C (output), shape (batch_size, m, n), bf16
    """

    kernel_class = PersistentDenseGemmKernel
    # Compiled-kernel cache shared across runner instances. Keyed on the
    # tactic plus the tvm-ffi flag only — problem shapes (m, n, k, batch)
    # are passed at call time, so one compiled kernel serves all shapes.
    # NOTE(review): this assumes the compiled wrapper is truly
    # shape-dynamic for both launch paths — confirm against the kernel.
    kernel_cache = dict()

    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            0, 1, get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2), ),
    )

    def __init__(self, use_tvm_ffi: bool = True):
        super().__init__()
        # When True, launch via the TVM-FFI entry point (raw data pointers,
        # env stream); otherwise build CuTe pointers/tensors per call.
        self.use_tvm_ffi = use_tvm_ffi

    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
        profile: OptimizationProfile,
        **kwargs,
    ) -> List[tuple]:
        """Enumerate all (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
        combinations the kernel reports it can implement for this problem.

        Returns an empty list (no valid tactics) on non-SM100 hardware.
        """
        if not is_sm_100f():
            logger.debug(
                f"CuteDSL: SM version {get_sm_version()} is not supported. "
                f"CuteDSL BF16 BMM only supports SM 100 family. Skipping all tactics."
            )
            return []

        # A: [b, m, k]
        batch_size, m, k = inputs[0].shape[0], inputs[0].shape[
            1], inputs[0].shape[2]
        # B: [b, n, k]
        n = inputs[1].shape[1]
        # Majorness of each operand: A is (m, k) k-major, B is (n, k)
        # k-major, C is (m, n) n-major.
        a_major = "k"
        b_major = "k"
        c_major = "n"

        use_2cta_instrs_candi = [False, True]
        mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
        cluster_shape_mn_candi = [
            (1, 1),
            (1, 2),
            (1, 4),
            (2, 1),
            (2, 2),
            (2, 4),
            (4, 1),
            (4, 2),
            (4, 4),
        ]
        # Cross product of all candidates, filtered by the kernel's own
        # feasibility check for this dtype/shape/layout combination.
        return [
            (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
            for use_2cta_instrs in use_2cta_instrs_candi
            for mma_tiler_mn in mma_tiler_mn_candi
            for cluster_shape_mn in cluster_shape_mn_candi
            if self.__class__.kernel_class.can_implement(
                cutlass.BFloat16,  # ab_dtype
                cutlass.Float32,  # acc_dtype
                cutlass.BFloat16,  # c_dtype
                use_2cta_instrs,
                mma_tiler_mn,
                cluster_shape_mn,
                m,
                n,
                k,
                batch_size,
                a_major,
                b_major,
                c_major,
            )
        ]

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic,
    ) -> None:
        """
        Performs bf16 dense persistent batched gemm using CuTe DSL.

        Args:
            inputs (List[torch.Tensor]):
                inputs[0]: Input tensor of shape (batch_size, m, k), dtype: bf16.
                inputs[1]: Weight tensor of shape (batch_size, n, k), dtype: bf16.
                inputs[2]: Output tensor of shape (batch_size, m, n), dtype: bf16.
            tactic: Tiling and cluster strategy, typically a tuple
                (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
        """
        if isinstance(tactic, tuple):
            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
        else:
            # Non-tuple tactic (e.g. the AutoTuner's -1 fallback): use a
            # conservative default configuration.
            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
                False,
                (128, 128),
                (1, 1),
            ]

        a_tensor, b_tensor, c_tensor = inputs

        # Ensure A and B are contiguous — the kernel constructs CuTe
        # layouts via make_ordered_layout assuming contiguous [B, M, K]
        # and [B, N, K]. Transpose views (e.g. from .transpose(0,1))
        # have swapped batch/seq strides which would cause the kernel
        # to read from wrong memory locations.
        a_tensor = a_tensor.contiguous()
        b_tensor = b_tensor.contiguous()

        # For the output, use a contiguous buffer so TMA store sees a
        # standard layout; copy back afterwards if the original was
        # non-contiguous.
        c_needs_copy = not c_tensor.is_contiguous()
        if c_needs_copy:
            c_buf = torch.empty_like(c_tensor)
        else:
            c_buf = c_tensor

        # c_buf is [B, M, N], permute to [M, N, B] for cute layout
        c_tmp = c_buf.permute(1, 2, 0)

        batch_size = a_tensor.shape[0]
        m = a_tensor.shape[1]
        k = a_tensor.shape[2]
        n = b_tensor.shape[1]

        # Non-FFI launch needs explicit CuTe pointers/tensors and the
        # current CUDA stream on every call (pointers change per call).
        if not self.use_tvm_ffi:
            a_ptr = make_ptr(
                cutlass.BFloat16,
                a_tensor.data_ptr(),
                cute.AddressSpace.gmem,
                assumed_align=16,
            )
            b_ptr = make_ptr(
                cutlass.BFloat16,
                b_tensor.data_ptr(),
                cute.AddressSpace.gmem,
                assumed_align=16,
            )
            c_cute_tensor = cute.runtime.from_dlpack(
                c_tmp).mark_layout_dynamic(leading_dim=1)

            stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

        cache_key = (
            use_2cta_instrs,
            mma_tiler_mn,
            cluster_shape_mn,
            self.use_tvm_ffi,
        )
        if cache_key not in self.__class__.kernel_cache:
            # First use of this tactic: compile the kernel once and cache
            # it. The tvm-ffi path only needs pointers/tensors here, at
            # compile time; later launches pass raw data pointers.
            if self.use_tvm_ffi:
                a_ptr = make_ptr(
                    cutlass.BFloat16,
                    a_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                b_ptr = make_ptr(
                    cutlass.BFloat16,
                    b_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                c_cute_tensor = cute.runtime.from_dlpack(
                    c_tmp).mark_layout_dynamic(leading_dim=1)
                # The compiled artifact resolves the stream from the
                # tvm-ffi environment at launch time.
                stream = cute.runtime.make_fake_stream(
                    use_tvm_ffi_env_stream=True)

            gemm = self.__class__.kernel_class(
                cutlass.Float32,  # acc_dtype
                use_2cta_instrs=use_2cta_instrs,
                mma_tiler_mn=mma_tiler_mn,
                cluster_shape_mn=cluster_shape_mn,
            )
            hardware_info = cutlass.utils.HardwareInfo()
            max_active_clusters = hardware_info.get_max_active_clusters(
                cluster_shape_mn[0] * cluster_shape_mn[1])

            compiled_gemm = cute.compile(
                gemm.wrapper,
                m,
                n,
                k,
                batch_size,
                a_ptr,
                b_ptr,
                c_cute_tensor,
                max_active_clusters=max_active_clusters,
                stream=stream,
                # Plain string: no interpolation needed here (was a
                # placeholder-less f-string).
                options="--opt-level 2 --enable-tvm-ffi"
                if self.use_tvm_ffi else "--opt-level 2",
            )
            self.__class__.kernel_cache[cache_key] = compiled_gemm
        else:
            compiled_gemm = self.__class__.kernel_cache[cache_key]

        # launch gemm kernel
        if self.use_tvm_ffi:
            compiled_gemm(
                m,
                n,
                k,
                batch_size,
                a_tensor.data_ptr(),
                b_tensor.data_ptr(),
                c_tmp,
            )
        else:
            compiled_gemm(
                m,
                n,
                k,
                batch_size,
                a_ptr,
                b_ptr,
                c_cute_tensor,
                stream=stream,
            )

        # Copy result back if original output was non-contiguous
        if c_needs_copy:
            c_tensor.copy_(c_buf)
3972+
3973+
# a/b: bf16, output: bf16
@torch.library.custom_op("trtllm::cute_dsl_bf16_bmm_blackwell",
                         mutates_args=("output", ),
                         device_types="cuda")
def cute_dsl_bf16_bmm_blackwell(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    use_tvm_ffi: bool = True,
) -> None:
    """Batched bf16 GEMM via the CuTe DSL persistent kernel (Blackwell only).

    Writes ``input @ weight^T`` per batch into ``output`` in place.
    ``input`` is (b, m, k), ``weight`` is (b, n, k), ``output`` is (b, m, n),
    all bf16. The AutoTuner picks the tiling/cluster tactic.

    Raises:
        ValueError: when the current GPU is not in the SM 100 family.
    """
    if not is_sm_100f():
        raise ValueError(
            f"CuteDSL: SM version {get_sm_version()} is not supported. "
            f"CuteDSL BF16 BMM only supports SM 100 family.")

    runner = CuteDSLBf16BlackwellBmmRunner(use_tvm_ffi=use_tvm_ffi)
    gemm_inputs = [input, weight, output]

    # Let the AutoTuner choose the best (tiler, cluster) tactic for these
    # shapes, then run the kernel with it.
    _, best_tactic = AutoTuner.get().choose_one(
        "trtllm::cute_dsl_bf16_bmm_blackwell::gemm",
        [runner],
        CuteDSLBf16BlackwellBmmRunner.tuning_config,
        gemm_inputs,
    )
    runner(gemm_inputs, tactic=best_tactic)
4001+
4002+
@torch.library.register_fake("trtllm::cute_dsl_bf16_bmm_blackwell")
def _(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    use_tvm_ffi: bool = True,
) -> None:
    """Fake (meta) implementation: validate dtypes/shapes, compute nothing.

    Parameter names mirror the real op's signature (input/weight/output),
    as torch.library expects the fake to match the op schema.
    """
    # input: (b, m, k); weight: (b, n, k)
    batch_size, m, k = input.shape[0], input.shape[1], input.shape[2]
    n = weight.shape[1]
    assert output.dtype == torch.bfloat16, "CuTe DSL bf16 bmm output dtype must be bf16"
    assert output.shape == (batch_size, m,
                            n), "CuTe DSL bf16 bmm output shape is incorrect"

0 commit comments

Comments
 (0)