Skip to content

Commit b7a5e72

Browse files
committed
[TRTLLM-11289][feat] Add CuTe DSL BF16 GEMM for Linear layers on Blackwell
Add use_cute_dsl_bf16_gemm flag to enable CuTe DSL BF16 persistent GEMM for unquantized Linear layers in MLA attention (kv_a_proj_with_mqa, q_b_proj, kv_b_proj). This complements the existing BF16 BMM support.

Changes:
- Add CuteDSLBf16BlackwellGemmRunner class and custom op in cute_dsl_custom_ops.py
- Add use_cute_dsl_bf16_gemm parameter to Linear class and UnquantizedLinearMethod
- Wire use_cute_dsl_bf16_gemm through ModelConfig, LlmArgs, and model_loader
- Pass flag to MLA Linear layers in attention.py
- Add --use_cute_dsl_bf16_gemm CLI argument to quickstart_advanced.py
- Add integration tests for single GPU and 4 GPU configurations

Signed-off-by: Pei He <peih@nvidia.com>
Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
1 parent fb79f39 commit b7a5e72

File tree

8 files changed

+369
-6
lines changed

8 files changed

+369
-6
lines changed

examples/llm-api/quickstart_advanced.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ def add_llm_args(parser):
203203
default=False,
204204
action='store_true',
205205
help='Use CuTe DSL bf16 persistent GEMM for BMM on Blackwell.')
206+
parser.add_argument('--use_cute_dsl_bf16_gemm',
207+
default=False,
208+
action='store_true',
209+
help='Use CuTe DSL bf16 persistent GEMM for Linear layers on Blackwell.')
206210

207211
# HF
208212
parser.add_argument('--trust_remote_code',
@@ -334,6 +338,7 @@ def setup_llm(args, **kwargs):
334338
max_beam_width=args.max_beam_width,
335339
orchestrator_type=args.orchestrator_type,
336340
use_cute_dsl_bf16_bmm=args.use_cute_dsl_bf16_bmm,
341+
use_cute_dsl_bf16_gemm=args.use_cute_dsl_bf16_gemm,
337342
**kwargs)
338343

339344
use_beam_search = args.max_beam_width > 1

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4011,3 +4011,284 @@ def _(
40114011
assert output.dtype == torch.bfloat16, "CuTe DSL bf16 bmm output dtype must be bf16"
40124012
assert output.shape == (batch_size, m,
40134013
n), "CuTe DSL bf16 bmm output shape is incorrect"
4014+
4015+
# ======================================================================
4016+
# BF16 Dense Persistent GEMM (CuTe DSL) for Blackwell - Linear layers
4017+
# ======================================================================
4018+
4019+
class CuteDSLBf16BlackwellGemmRunner(TunableRunner):
4020+
"""
4021+
CuTe DSL BF16 GEMM runner for Linear layers.
4022+
4023+
Unlike BMM which operates on [B, M, K] @ [B, N, K] -> [B, M, N],
4024+
GEMM operates on [M, K] @ [N, K]^T -> [M, N] (standard Linear).
4025+
4026+
We reuse PersistentDenseGemmKernel with batch_size=1.
4027+
"""
4028+
kernel_class = PersistentDenseGemmKernel
4029+
kernel_cache = dict()
4030+
4031+
tuning_config = TuningConfig(
4032+
dynamic_tensor_specs=(DynamicTensorSpec(
4033+
0, 0, get_last_power_of_2_num_tokens_buckets,
4034+
last_positive_power_of_2), ),
4035+
)
4036+
4037+
def __init__(self, use_tvm_ffi: bool = True):
4038+
super().__init__()
4039+
self.use_tvm_ffi = use_tvm_ffi
4040+
4041+
def get_valid_tactics(
4042+
self,
4043+
inputs: List[torch.Tensor],
4044+
profile: OptimizationProfile,
4045+
**kwargs,
4046+
) -> List[int]:
4047+
4048+
if not is_sm_100f():
4049+
logger.debug(
4050+
f"CuteDSL: SM version {get_sm_version()} is not supported. "
4051+
f"CuteDSL BF16 GEMM only supports SM 100 family. Skipping all tactics."
4052+
)
4053+
return []
4054+
4055+
# input: [M, K], weight: [N, K]
4056+
m, k = inputs[0].shape[0], inputs[0].shape[1]
4057+
n = inputs[1].shape[0]
4058+
batch_size = 1
4059+
4060+
# Layouts: A is [M, K] K-major, B is [N, K] K-major
4061+
a_major = "k"
4062+
b_major = "k"
4063+
c_major = "n"
4064+
4065+
use_2cta_instrs_candi = [False, True]
4066+
mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
4067+
cluster_shape_mn_candi = [
4068+
(1, 1),
4069+
(1, 2),
4070+
(1, 4),
4071+
(2, 1),
4072+
(2, 2),
4073+
(2, 4),
4074+
(4, 1),
4075+
(4, 2),
4076+
(4, 4),
4077+
]
4078+
return [
4079+
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
4080+
for use_2cta_instrs in use_2cta_instrs_candi
4081+
for mma_tiler_mn in mma_tiler_mn_candi
4082+
for cluster_shape_mn in cluster_shape_mn_candi
4083+
if self.__class__.kernel_class.can_implement(
4084+
cutlass.BFloat16, # ab_dtype
4085+
cutlass.Float32, # acc_dtype
4086+
cutlass.BFloat16, # c_dtype
4087+
use_2cta_instrs,
4088+
mma_tiler_mn,
4089+
cluster_shape_mn,
4090+
m,
4091+
n,
4092+
k,
4093+
batch_size,
4094+
a_major,
4095+
b_major,
4096+
c_major,
4097+
)
4098+
]
4099+
4100+
def forward(
4101+
self,
4102+
inputs: List[torch.Tensor],
4103+
tactic,
4104+
) -> None:
4105+
"""
4106+
Performs bf16 dense persistent GEMM using CuTe DSL.
4107+
4108+
Args:
4109+
inputs (List[torch.Tensor]):
4110+
inputs[0]: Input tensor of shape (m, k), dtype: bf16.
4111+
inputs[1]: Weight tensor of shape (n, k), dtype: bf16.
4112+
inputs[2]: Output tensor of shape (m, n), dtype: bf16.
4113+
tactic: Tiling and cluster strategy, typically a tuple
4114+
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
4115+
"""
4116+
if isinstance(tactic, tuple):
4117+
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
4118+
else:
4119+
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
4120+
False,
4121+
(128, 128),
4122+
(1, 1),
4123+
]
4124+
4125+
a_tensor, b_tensor, c_tensor = inputs
4126+
4127+
# Input: [M, K], Weight: [N, K], Output: [M, N]
4128+
m, k = a_tensor.shape[0], a_tensor.shape[1]
4129+
n = b_tensor.shape[0]
4130+
batch_size = 1
4131+
4132+
# Ensure inputs are contiguous
4133+
a_tensor = a_tensor.contiguous()
4134+
b_tensor = b_tensor.contiguous()
4135+
4136+
# For output, use contiguous buffer if needed
4137+
c_needs_copy = not c_tensor.is_contiguous()
4138+
if c_needs_copy:
4139+
c_buf = torch.empty_like(c_tensor)
4140+
else:
4141+
c_buf = c_tensor
4142+
4143+
# Reshape to [1, M, K], [1, N, K], [1, M, N] for the batched kernel
4144+
a_batched = a_tensor.unsqueeze(0) # [1, M, K]
4145+
b_batched = b_tensor.unsqueeze(0) # [1, N, K]
4146+
# c_buf is [M, N], permute to [M, N, 1] for cute layout
4147+
c_tmp = c_buf.unsqueeze(-1) # [M, N, 1]
4148+
4149+
if not self.use_tvm_ffi:
4150+
a_ptr = make_ptr(
4151+
cutlass.BFloat16,
4152+
a_batched.data_ptr(),
4153+
cute.AddressSpace.gmem,
4154+
assumed_align=16,
4155+
)
4156+
b_ptr = make_ptr(
4157+
cutlass.BFloat16,
4158+
b_batched.data_ptr(),
4159+
cute.AddressSpace.gmem,
4160+
assumed_align=16,
4161+
)
4162+
c_cute_tensor = cute.runtime.from_dlpack(
4163+
c_tmp).mark_layout_dynamic(leading_dim=1)
4164+
4165+
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
4166+
4167+
cache_key = (
4168+
use_2cta_instrs,
4169+
mma_tiler_mn,
4170+
cluster_shape_mn,
4171+
self.use_tvm_ffi,
4172+
)
4173+
if cache_key not in self.__class__.kernel_cache:
4174+
if self.use_tvm_ffi:
4175+
a_ptr = make_ptr(
4176+
cutlass.BFloat16,
4177+
a_batched.data_ptr(),
4178+
cute.AddressSpace.gmem,
4179+
assumed_align=16,
4180+
)
4181+
b_ptr = make_ptr(
4182+
cutlass.BFloat16,
4183+
b_batched.data_ptr(),
4184+
cute.AddressSpace.gmem,
4185+
assumed_align=16,
4186+
)
4187+
c_cute_tensor = cute.runtime.from_dlpack(
4188+
c_tmp).mark_layout_dynamic(leading_dim=1)
4189+
stream = cute.runtime.make_fake_stream(
4190+
use_tvm_ffi_env_stream=True)
4191+
4192+
gemm = self.__class__.kernel_class(
4193+
cutlass.Float32, # acc_dtype
4194+
use_2cta_instrs=use_2cta_instrs,
4195+
mma_tiler_mn=mma_tiler_mn,
4196+
cluster_shape_mn=cluster_shape_mn,
4197+
)
4198+
hardware_info = cutlass.utils.HardwareInfo()
4199+
max_active_clusters = hardware_info.get_max_active_clusters(
4200+
cluster_shape_mn[0] * cluster_shape_mn[1])
4201+
4202+
compiled_gemm = cute.compile(
4203+
gemm.wrapper,
4204+
m,
4205+
n,
4206+
k,
4207+
batch_size,
4208+
a_ptr,
4209+
b_ptr,
4210+
c_cute_tensor,
4211+
max_active_clusters=max_active_clusters,
4212+
stream=stream,
4213+
options=f"--opt-level 2 --enable-tvm-ffi"
4214+
if self.use_tvm_ffi else "--opt-level 2",
4215+
)
4216+
self.__class__.kernel_cache[cache_key] = compiled_gemm
4217+
else:
4218+
compiled_gemm = self.__class__.kernel_cache[cache_key]
4219+
4220+
# launch gemm kernel
4221+
if self.use_tvm_ffi:
4222+
compiled_gemm(
4223+
m,
4224+
n,
4225+
k,
4226+
batch_size,
4227+
a_batched.data_ptr(),
4228+
b_batched.data_ptr(),
4229+
c_tmp,
4230+
)
4231+
else:
4232+
compiled_gemm(
4233+
m,
4234+
n,
4235+
k,
4236+
batch_size,
4237+
a_ptr,
4238+
b_ptr,
4239+
c_cute_tensor,
4240+
stream=stream,
4241+
)
4242+
4243+
# Copy result back if original output was non-contiguous
4244+
if c_needs_copy:
4245+
c_tensor.copy_(c_buf)
4246+
4247+
# input: [M, K], weight: [N, K], output: [M, N]
4248+
@torch.library.custom_op("trtllm::cute_dsl_bf16_gemm_blackwell",
4249+
mutates_args=("output", ),
4250+
device_types="cuda")
4251+
def cute_dsl_bf16_gemm_blackwell(
4252+
input: torch.Tensor,
4253+
weight: torch.Tensor,
4254+
output: torch.Tensor,
4255+
use_tvm_ffi: bool = True,
4256+
) -> None:
4257+
"""
4258+
CuTe DSL BF16 GEMM for Linear layers on Blackwell.
4259+
4260+
Computes: output = input @ weight^T
4261+
- input: [M, K] (num_tokens, in_features)
4262+
- weight: [N, K] (out_features, in_features)
4263+
- output: [M, N] (num_tokens, out_features)
4264+
"""
4265+
if not is_sm_100f():
4266+
raise ValueError(
4267+
f"CuteDSL: SM version {get_sm_version()} is not supported. "
4268+
f"CuteDSL BF16 GEMM only supports SM 100 family.")
4269+
4270+
tuner = AutoTuner.get()
4271+
4272+
runner = CuteDSLBf16BlackwellGemmRunner(use_tvm_ffi=use_tvm_ffi)
4273+
4274+
inputs = [input, weight, output]
4275+
4276+
_, best_tactic = tuner.choose_one(
4277+
"trtllm::cute_dsl_bf16_gemm_blackwell::gemm",
4278+
[runner],
4279+
runner.__class__.tuning_config,
4280+
inputs,
4281+
)
4282+
runner(inputs, tactic=best_tactic)
4283+
4284+
@torch.library.register_fake("trtllm::cute_dsl_bf16_gemm_blackwell")
4285+
def _(
4286+
mat_a: torch.Tensor,
4287+
mat_b: torch.Tensor,
4288+
output: torch.Tensor,
4289+
use_tvm_ffi: bool = True,
4290+
) -> None:
4291+
m, k = mat_a.shape[0], mat_a.shape[1]
4292+
n = mat_b.shape[0]
4293+
assert output.dtype == torch.bfloat16, "CuTe DSL bf16 gemm output dtype must be bf16"
4294+
assert output.shape == (m, n), "CuTe DSL bf16 gemm output shape is incorrect"

tensorrt_llm/_torch/model_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class ModelConfig(Generic[TConfig]):
126126
use_cute_dsl_blockscaling_mm: bool = False
127127
use_cute_dsl_blockscaling_bmm: bool = False
128128
use_cute_dsl_bf16_bmm: bool = False
129+
use_cute_dsl_bf16_gemm: bool = False
129130

130131
_frozen: bool = field(default=False, init=False, repr=False)
131132

tensorrt_llm/_torch/modules/attention.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ def __init__(
462462
self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
463463
self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm
464464
self.use_cute_dsl_bf16_bmm = config.use_cute_dsl_bf16_bmm
465+
self.use_cute_dsl_bf16_gemm = config.use_cute_dsl_bf16_gemm
465466

466467
qkv_shard_indices_mapping = {
467468
"q": (0, self.q_size * (2 if self.attn_output_gate else 1)),
@@ -1124,6 +1125,7 @@ def __init__(
11241125
self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
11251126
self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm
11261127
self.use_cute_dsl_bf16_bmm = config.use_cute_dsl_bf16_bmm
1128+
self.use_cute_dsl_bf16_gemm = config.use_cute_dsl_bf16_gemm
11271129

11281130
if not self.is_lite:
11291131
self.kv_a_proj_with_mqa = Linear(
@@ -1135,7 +1137,8 @@ def __init__(
11351137
skip_create_weights_in_init=config.skip_create_weights_in_init,
11361138
use_custom_cublas_mm=True,
11371139
force_dynamic_quantization=config.force_dynamic_quantization,
1138-
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
1140+
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm,
1141+
use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm)
11391142

11401143
self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank,
11411144
eps=rms_norm_eps,
@@ -1152,7 +1155,8 @@ def __init__(
11521155
skip_create_weights_in_init=config.skip_create_weights_in_init,
11531156
allreduce_strategy=config.allreduce_strategy,
11541157
force_dynamic_quantization=config.force_dynamic_quantization,
1155-
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
1158+
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm,
1159+
use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm)
11561160
else:
11571161
self.kv_a_proj_with_mqa = Linear(
11581162
hidden_size,
@@ -1163,7 +1167,8 @@ def __init__(
11631167
skip_create_weights_in_init=config.skip_create_weights_in_init,
11641168
use_custom_cublas_mm=True,
11651169
force_dynamic_quantization=config.force_dynamic_quantization,
1166-
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
1170+
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm,
1171+
use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm)
11671172

11681173
self.q_proj = Linear(
11691174
self.q_lora_rank,
@@ -1176,7 +1181,8 @@ def __init__(
11761181
skip_create_weights_in_init=config.skip_create_weights_in_init,
11771182
allreduce_strategy=config.allreduce_strategy,
11781183
force_dynamic_quantization=config.force_dynamic_quantization,
1179-
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
1184+
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm,
1185+
use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm)
11801186
self.q_b_proj = self.q_proj
11811187

11821188
self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
@@ -1194,7 +1200,8 @@ def __init__(
11941200
skip_create_weights_in_init=config.skip_create_weights_in_init,
11951201
allreduce_strategy=config.allreduce_strategy,
11961202
force_dynamic_quantization=config.force_dynamic_quantization,
1197-
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
1203+
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm,
1204+
use_cute_dsl_bf16_gemm=self.use_cute_dsl_bf16_gemm)
11981205
# This parameter will view into self.kv_b_proj.weight after loading weights.
11991206
# For dummy weight initialization, this parameter is initialized with empty tensor.
12001207
# Used in forward_absorption only

0 commit comments

Comments (0)