Skip to content

Commit 33ac890

Browse files
committed
[TRTLLM-11289][perf] Eliminate contiguous copies in CuTe DSL BF16 BMM path

Add wrapper_strided to PersistentDenseGemmKernel that accepts explicit A tensor strides, enabling non-contiguous views (e.g. from .transpose()) to be passed directly to TMA without .contiguous() copies. Update the BMM runner to compute and pass A strides instead of forcing contiguous tensors, removing the direct_copy_kernel_cuda overhead between attention and BMM.

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
1 parent 22261e9 commit 33ac890

File tree

2 files changed

+81
-24
lines changed

2 files changed

+81
-24
lines changed

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3847,31 +3847,27 @@ def forward(
38473847

38483848
a_tensor, b_tensor, c_tensor = inputs
38493849

3850-
# Ensure A and B are contiguous — the kernel constructs CuTe
3851-
# layouts via make_ordered_layout assuming contiguous [B, M, K]
3852-
# and [B, N, K]. Transpose views (e.g. from .transpose(0,1))
3853-
# have swapped batch/seq strides which would cause the kernel
3854-
# to read from wrong memory locations.
3855-
a_tensor = a_tensor.contiguous()
3856-
b_tensor = b_tensor.contiguous()
3857-
3858-
# For the output, use a contiguous buffer so TMA store sees a
3859-
# standard layout; copy back afterwards if the original was
3860-
# non-contiguous.
3861-
c_needs_copy = not c_tensor.is_contiguous()
3862-
if c_needs_copy:
3863-
c_buf = torch.empty_like(c_tensor)
3864-
else:
3865-
c_buf = c_tensor
3866-
3867-
# c_buf is [B, M, N], permute to [M, N, B] for cute layout
3868-
c_tmp = c_buf.permute(1, 2, 0)
3850+
# Permute C from [B, M, N] to [M, N, B] for CuTe layout.
3851+
# from_dlpack captures the actual strides, so non-contiguous
3852+
# views (e.g. from .transpose(0,1)) are handled natively by
3853+
# TMA without an extra copy.
3854+
c_tmp = c_tensor.permute(1, 2, 0)
38693855

38703856
batch_size = a_tensor.shape[0]
38713857
m = a_tensor.shape[1]
38723858
k = a_tensor.shape[2]
38733859
n = b_tensor.shape[1]
38743860

3861+
# Compute A strides so the kernel can handle non-contiguous
3862+
# views (e.g. [M,B,K].transpose(0,1) → [B,M,K] with
3863+
# non-standard strides) without a .contiguous() copy.
3864+
# CuTe tensor is (M, K, B) so strides map as:
3865+
# M stride = a_tensor.stride(1)
3866+
# K stride = 1 (always innermost)
3867+
# B stride = a_tensor.stride(0)
3868+
a_stride_m = a_tensor.stride(1)
3869+
a_stride_batch = a_tensor.stride(0)
3870+
38753871
if not self.use_tvm_ffi:
38763872
a_ptr = make_ptr(
38773873
cutlass.BFloat16,
@@ -3926,14 +3922,16 @@ def forward(
39263922
cluster_shape_mn[0] * cluster_shape_mn[1])
39273923

39283924
compiled_gemm = cute.compile(
3929-
gemm.wrapper,
3925+
gemm.wrapper_strided,
39303926
m,
39313927
n,
39323928
k,
39333929
batch_size,
39343930
a_ptr,
39353931
b_ptr,
39363932
c_cute_tensor,
3933+
a_stride_m,
3934+
a_stride_batch,
39373935
max_active_clusters=max_active_clusters,
39383936
stream=stream,
39393937
options=f"--opt-level 2 --enable-tvm-ffi"
@@ -3953,6 +3951,8 @@ def forward(
39533951
a_tensor.data_ptr(),
39543952
b_tensor.data_ptr(),
39553953
c_tmp,
3954+
a_stride_m,
3955+
a_stride_batch,
39563956
)
39573957
else:
39583958
compiled_gemm(
@@ -3963,13 +3963,11 @@ def forward(
39633963
a_ptr,
39643964
b_ptr,
39653965
c_cute_tensor,
3966+
a_stride_m,
3967+
a_stride_batch,
39663968
stream=stream,
39673969
)
39683970

3969-
# Copy result back if original output was non-contiguous
3970-
if c_needs_copy:
3971-
c_tensor.copy_(c_buf)
3972-
39733971
# a/b: bf16, output: bf16
39743972
@torch.library.custom_op("trtllm::cute_dsl_bf16_bmm_blackwell",
39753973
mutates_args=("output", ),

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_gemm_persistent.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,3 +1010,62 @@ def wrapper(
10101010
max_active_clusters,
10111011
stream,
10121012
)
1013+
1014+
@cute.jit
1015+
def wrapper_strided(
1016+
self,
1017+
m: cutlass.Int32,
1018+
n: cutlass.Int32,
1019+
k: cutlass.Int32,
1020+
batch_size: cutlass.Int32,
1021+
a_ptr: cute.Pointer,
1022+
b_ptr: cute.Pointer,
1023+
c_tensor: cute.Tensor,
1024+
a_stride_m: cutlass.Int32,
1025+
a_stride_batch: cutlass.Int32,
1026+
max_active_clusters: cutlass.Constexpr,
1027+
stream: cuda.CUstream,
1028+
):
1029+
"""Executes the GEMM kernel with explicit A tensor strides.
1030+
1031+
Like ``wrapper`` but allows non-contiguous A tensors by accepting
1032+
the M and batch strides directly. The K stride is assumed to be 1
1033+
(row-major in K). B is always contiguous.
1034+
1035+
Args:
1036+
m: The M dimension of the GEMM problem.
1037+
n: The N dimension of the GEMM problem.
1038+
k: The K dimension of the GEMM problem.
1039+
batch_size: The batch dimension.
1040+
a_ptr: Pointer to the A tensor data.
1041+
b_ptr: Pointer to the B tensor data.
1042+
c_tensor: Output tensor as cute.Tensor.
1043+
a_stride_m: Stride of A along the M dimension (in elements).
1044+
a_stride_batch: Stride of A along the batch dimension (in elements).
1045+
max_active_clusters: Maximum number of active clusters.
1046+
stream: CUDA stream for the operation.
1047+
"""
1048+
# A with explicit strides: (M, K, batch_size), K stride = 1
1049+
a_tensor = cute.make_tensor(
1050+
a_ptr,
1051+
layout=cute.make_layout(
1052+
(m, k, batch_size),
1053+
stride=(a_stride_m, 1, a_stride_batch),
1054+
),
1055+
)
1056+
# B is always contiguous: (N, K, batch_size) with K innermost
1057+
b_tensor = cute.make_tensor(
1058+
b_ptr,
1059+
layout=cute.make_ordered_layout(
1060+
(n, k, batch_size),
1061+
order=(1, 0, 2),
1062+
),
1063+
)
1064+
1065+
self(
1066+
a_tensor,
1067+
b_tensor,
1068+
c_tensor,
1069+
max_active_clusters,
1070+
stream,
1071+
)

0 commit comments

Comments (0)