16 changes: 11 additions & 5 deletions tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -1864,10 +1864,13 @@ def get_valid_tactics(
        mma_tiler_mn_candidates = [(self.tile_size, 128),
                                   (self.tile_size, 256)]
        cluster_shape_mn_candidates = [(self.tile_size // 128, 1)]
+        # TODO: Add raster_along_m=True if we find it more performant in some cases.
+        raster_along_m_candidates = [False]

        valid_tactics = []
-        for mma_tiler_mn, cluster_shape_mn in itertools.product(
-                mma_tiler_mn_candidates, cluster_shape_mn_candidates):
+        for mma_tiler_mn, cluster_shape_mn, raster_along_m in itertools.product(
+                mma_tiler_mn_candidates, cluster_shape_mn_candidates,
+                raster_along_m_candidates):
            if self.__class__.kernel_class.can_implement(
                ab_dtype=cutlass.Float4E2M1FN,
                sf_dtype=cutlass.Float8E4M3FN,
@@ -1883,7 +1886,8 @@
                b_major="k",
                c_major="n",
            ):
-                valid_tactics.append((mma_tiler_mn, cluster_shape_mn))
+                valid_tactics.append(
+                    (mma_tiler_mn, cluster_shape_mn, raster_along_m))

        return valid_tactics
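
For illustration (not part of the diff): with tile_size=128, the sweep above yields at most two tactics, a sketch assuming can_implement accepts both tile shapes:

    import itertools

    mma_tiler_mn_candidates = [(128, 128), (128, 256)]
    cluster_shape_mn_candidates = [(1, 1)]
    raster_along_m_candidates = [False]

    tactics = list(itertools.product(mma_tiler_mn_candidates,
                                     cluster_shape_mn_candidates,
                                     raster_along_m_candidates))
    # [((128, 128), (1, 1), False), ((128, 256), (1, 1), False)]

Each 3-tuple is later unpacked in forward() as (mma_tiler_mn, cluster_shape_mn, raster_along_m).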

@@ -2013,22 +2017,24 @@ def forward(self, inputs: List[torch.Tensor],
        stream = cuda.CUstream(torch_stream.cuda_stream)

        if isinstance(tactic, tuple):
-            mma_tiler_mn, cluster_shape_mn = tactic
+            mma_tiler_mn, cluster_shape_mn, raster_along_m = tactic
        else:
            mma_tiler_mn = (self.tile_size, 128)
            cluster_shape_mn = (self.tile_size // 128, 1)
+            raster_along_m = False
        assert mma_tiler_mn[
            0] == self.tile_size, f"Tactic ({tactic}) is incompatible with tile size ({self.tile_size})"

        cache_key = (self.scaling_vector_size, self.tile_size, self.top_k,
-                     mma_tiler_mn, cluster_shape_mn)
+                     mma_tiler_mn, cluster_shape_mn, raster_along_m)
        if cache_key not in self.__class__.kernel_cache:
            gemm = self.__class__.kernel_class(
                sf_vec_size=self.scaling_vector_size,
                mma_tiler_mn=mma_tiler_mn,
                cluster_shape_mn=cluster_shape_mn,
                vectorized_f32=True,
                topk=self.top_k,
+                raster_along_m=raster_along_m,
            )
            # Compute max active clusters on current device
            hardware_info = cutlass.utils.HardwareInfo()
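
The cache above compiles one kernel per distinct tactic and reuses it across calls. A stripped-down sketch of the same pattern (hypothetical names, not this module's API):

    from typing import Any, Callable, Dict, Tuple

    class TacticKernelCache:
        """Compile once per (tile, cluster, raster) tactic, then reuse."""
        _cache: Dict[Tuple, Any] = {}

        @classmethod
        def get(cls, key: Tuple, build: Callable[[], Any]) -> Any:
            if key not in cls._cache:
                cls._cache[key] = build()  # the expensive JIT compile runs once per key
            return cls._cache[key]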
@@ -37,6 +37,7 @@
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass._mlir.dialects import math
from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass.cutlass_dsl import Int32

from .custom_pipeline import PipelineCpAsyncUmma
from .utils import (
@@ -154,6 +155,144 @@
"""


# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.4 is released.
def hooked_PersistentTileSchedulerParams_init(
    self,
    problem_shape_ntile_mnl: cute.Shape,
    cluster_shape_mnk: cute.Shape,
    swizzle_size: int = 1,
    raster_along_m: bool = True,
    *,
    loc=None,
    ip=None,
):
    if cluster_shape_mnk[2] != 1:
        raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}")
    if swizzle_size < 1:
        raise ValueError(f"expect swizzle_size >= 1, but got {swizzle_size}")

    self.problem_shape_ntile_mnl = problem_shape_ntile_mnl
    # cluster_shape_mnk is kept for reconstruction
    self._cluster_shape_mnk = cluster_shape_mnk
    self.cluster_shape_mn = cluster_shape_mnk[:2]
    self.swizzle_size = swizzle_size
    self._raster_along_m = raster_along_m
    self._loc = loc

    # Apply swizzle if swizzle_size > 1
    if swizzle_size > 1:
        problem_shape_ncluster_mnl = cute.round_up(
            cute.ceil_div(self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip),
            (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1),
        )

        if raster_along_m:
            self.problem_layout_ncluster_mnl = cute.make_layout(
                (
                    problem_shape_ncluster_mnl[0],
                    (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size),
                    problem_shape_ncluster_mnl[2],
                ),
                stride=(
                    swizzle_size,
                    (1, swizzle_size * problem_shape_ncluster_mnl[0]),
                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
                ),
                loc=loc,
                ip=ip,
            )
        else:
            self.problem_layout_ncluster_mnl = cute.make_layout(
                (
                    (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size),
                    problem_shape_ncluster_mnl[1],
                    problem_shape_ncluster_mnl[2],
                ),
                stride=(
                    (1, swizzle_size * problem_shape_ncluster_mnl[1]),
                    swizzle_size,
                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
                ),
                loc=loc,
                ip=ip,
            )

    # Create FastDivmod divisors only when swizzle_size == 1, for correctness:
    # FastDivmod assumes a simple col-major/row-major layout and is incompatible
    # with swizzled layouts.
    if swizzle_size == 1:
        problem_shape_ncluster_mnl = cute.ceil_div(
            self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip
        )
        if raster_along_m:
            self.problem_layout_ncluster_mnl = cute.make_layout(
                problem_shape_ncluster_mnl,
                stride=(
                    1,
                    problem_shape_ncluster_mnl[0],
                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
                ),
                loc=loc,
                ip=ip,
            )
        else:
            self.problem_layout_ncluster_mnl = cute.make_layout(
                problem_shape_ncluster_mnl,
                stride=(
                    problem_shape_ncluster_mnl[1],
                    1,
                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
                ),
                loc=loc,
                ip=ip,
            )
        problem_layout_size = cute.size(self.problem_layout_ncluster_mnl, loc=loc, ip=ip)
        cluster_count_m = self.problem_layout_ncluster_mnl.shape[0]
        cluster_count_n = self.problem_layout_ncluster_mnl.shape[1]

        # batch_fdd: maps linear_idx to work_unit_id (handles persistent scheduling)
        self.batch_fdd = cute.fast_divmod_create_divisor(problem_layout_size, loc=loc, ip=ip)

        # cluster_shape_m_fdd: decodes work_unit_id into cluster coordinates
        self.cluster_shape_m_fdd = cute.fast_divmod_create_divisor(cluster_count_m, loc=loc, ip=ip)

        # cluster_shape_n_fdd: used for the second-level decomposition
        self.cluster_shape_n_fdd = cute.fast_divmod_create_divisor(cluster_count_n, loc=loc, ip=ip)
    else:
        # FastDivmod is not applicable with swizzling; set the divisors to None
        self.batch_fdd = None
        self.cluster_shape_m_fdd = None
        self.cluster_shape_n_fdd = None


def hooked_get_cluster_work_idx_with_fastdivmod(
    self, current_work_linear_idx: Int32, *, loc=None, ip=None
) -> Tuple[Int32, Int32, Int32]:
    work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd)

    if self.params._raster_along_m:
        # raster_along_m=True means column major (m is fastest)
        # First, get cluster_m using cluster_shape_m_fdd
        cluster_n_batch, cluster_m = divmod(work_unit_id, self.params.cluster_shape_m_fdd)

        # Then decode cluster_n_batch into cluster_n and batch_l using FastDivmod
        batch_l, cluster_n = divmod(cluster_n_batch, self.params.cluster_shape_n_fdd)
    else:
        # raster_along_m=False means row major (n is fastest)
        # First, get cluster_n using cluster_shape_n_fdd
        cluster_m_batch, cluster_n = divmod(work_unit_id, self.params.cluster_shape_n_fdd)

        # Then decode cluster_m_batch into cluster_m and batch_l using FastDivmod
        batch_l, cluster_m = divmod(cluster_m_batch, self.params.cluster_shape_m_fdd)

    return (cluster_m, cluster_n, batch_l)


cutlass.utils.PersistentTileSchedulerParams.__init__ = hooked_PersistentTileSchedulerParams_init
cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmod = (
    hooked_get_cluster_work_idx_with_fastdivmod
)
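
For intuition (not part of the diff): the two raster orders differ only in which coordinate varies fastest as the linear work index grows. A minimal pure-Python model of the decode above, for a 3x2 cluster grid and a single batch:

    def decode(linear_idx, cluster_count_m, cluster_count_n, raster_along_m):
        if raster_along_m:
            # column major: m varies fastest
            rest, m = divmod(linear_idx, cluster_count_m)
            l, n = divmod(rest, cluster_count_n)
        else:
            # row major: n varies fastest
            rest, n = divmod(linear_idx, cluster_count_n)
            l, m = divmod(rest, cluster_count_m)
        return (m, n, l)

    [decode(i, 3, 2, True) for i in range(6)]
    # [(0,0,0), (1,0,0), (2,0,0), (0,1,0), (1,1,0), (2,1,0)]
    [decode(i, 3, 2, False) for i in range(6)]
    # [(0,0,0), (0,1,0), (1,0,0), (1,1,0), (2,0,0), (2,1,0)]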


class BlockScaledContiguousGatherGroupedGemmKernel:
    """This class implements contiguous grouped matrix multiplication with a gather operation and SwiGLU
    fusion for the FC1-layer computation (C = up * silu(gate), where up/gate come from the interleaved
    GEMM result).
@@ -245,6 +384,7 @@ def __init__(
        cluster_shape_mn: Tuple[int, int],
        vectorized_f32: bool,
        topk: cutlass.Int64,
+        raster_along_m: bool = False,
    ):
        """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with
        gather operation and SwiGLU fusion.
@@ -289,6 +429,7 @@ def __init__(
        self.cluster_shape_mn = cluster_shape_mn
        # K dimension is deferred in _setup_attributes
        self.mma_tiler = (*mma_tiler_mn, 1)
+        self.raster_along_m = raster_along_m

        self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
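
As a point of reference (not part of the diff), the SwiGLU fusion named in the class docstring computes C = up * silu(gate) from an FC1 output whose columns interleave gate and up. A PyTorch sketch, where the block-interleaving granularity and the gate/up order are assumptions for illustration:

    import torch
    import torch.nn.functional as F

    def swiglu_from_interleaved(fc1_out: torch.Tensor, block: int) -> torch.Tensor:
        # fc1_out: [tokens, 2*N], gate/up interleaved in block-wide column groups
        cols = fc1_out.unflatten(-1, (-1, 2, block))  # [tokens, N // block, 2, block]
        gate, up = cols[..., 0, :], cols[..., 1, :]   # which half is gate is an assumption
        return (up * F.silu(gate)).flatten(-2)        # [tokens, N]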

@@ -743,7 +884,11 @@ def __call__(

        # Compute grid size
        self.tile_sched_params, grid = self._compute_grid(
-            c, self.cta_tile_shape_mnk_c, self.cluster_shape_mn, max_active_clusters
+            c,
+            self.cta_tile_shape_mnk_c,
+            self.cluster_shape_mn,
+            max_active_clusters,
+            self.raster_along_m,
        )

        self.buffer_align_bytes = 1024
@@ -1254,34 +1399,69 @@ def kernel(
            pipeline.PipelineUserType.Producer, self.num_tile_stage
        )

-        while work_tile.is_valid_tile:
-            cur_tile_coord = work_tile.tile_idx
-            mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
-            if mma_tile_coord_m < num_non_exiting_tiles[0]:
-                tile_info_pipeline.producer_acquire(tile_info_producer_state)
-                expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
-                mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
-                with cute.arch.elect_one():
-                    sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
-                    sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
-                    sInfo[(2, tile_info_producer_state.index)] = expert_idx
-                    sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
-                        work_tile.is_valid_tile
-                    )
-                    sInfo[(4, tile_info_producer_state.index)] = mn_limit
-                # fence view async shared
-                cute.arch.fence_proxy(
-                    cute.arch.ProxyKind.async_shared,
-                    space=cute.arch.SharedSpace.shared_cta,
-                )
-
-                self.sched_sync_barrier.arrive_and_wait()
-                tile_info_pipeline.producer_commit(tile_info_producer_state)
-                tile_info_producer_state.advance()
-
-            tile_sched.advance_to_next_work()
-            work_tile = tile_sched.get_current_work()
+        num_non_exiting_tiles_value = num_non_exiting_tiles[0]
+
+        if cutlass.const_expr(self.raster_along_m):
+            while work_tile.is_valid_tile:
+                cur_tile_coord = work_tile.tile_idx
+                mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+                if mma_tile_coord_m < num_non_exiting_tiles_value:
+                    tile_info_pipeline.producer_acquire(tile_info_producer_state)
+                    cur_tile_coord = work_tile.tile_idx
+                    expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+                    mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+                    with cute.arch.elect_one():
+                        sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+                        sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+                        sInfo[(2, tile_info_producer_state.index)] = expert_idx
+                        sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+                            work_tile.is_valid_tile
+                        )
+                        sInfo[(4, tile_info_producer_state.index)] = mn_limit
+                    # fence view async shared
+                    cute.arch.fence_proxy(
+                        cute.arch.ProxyKind.async_shared,
+                        space=cute.arch.SharedSpace.shared_cta,
+                    )
+
+                    self.sched_sync_barrier.arrive_and_wait()
+                    tile_info_pipeline.producer_commit(tile_info_producer_state)
+                    tile_info_producer_state.advance()
+
+                tile_sched.advance_to_next_work()
+                work_tile = tile_sched.get_current_work()
+        else:
+            is_continue = cutlass.Boolean(1)
+            while work_tile.is_valid_tile and is_continue:
+                cur_tile_coord = work_tile.tile_idx
+                mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+                if mma_tile_coord_m < num_non_exiting_tiles_value:
+                    tile_info_pipeline.producer_acquire(tile_info_producer_state)
+                    cur_tile_coord = work_tile.tile_idx
+                    expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+                    mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+                    with cute.arch.elect_one():
+                        sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+                        sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+                        sInfo[(2, tile_info_producer_state.index)] = expert_idx
+                        sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+                            work_tile.is_valid_tile
+                        )
+                        sInfo[(4, tile_info_producer_state.index)] = mn_limit
+                    # fence view async shared
+                    cute.arch.fence_proxy(
+                        cute.arch.ProxyKind.async_shared,
+                        space=cute.arch.SharedSpace.shared_cta,
+                    )
+
+                    self.sched_sync_barrier.arrive_and_wait()
+                    tile_info_pipeline.producer_commit(tile_info_producer_state)
+                    tile_info_producer_state.advance()
+                else:
+                    is_continue = cutlass.Boolean(0)
+
+                tile_sched.advance_to_next_work()
+                work_tile = tile_sched.get_current_work()
        tile_info_pipeline.producer_acquire(tile_info_producer_state)
        with cute.arch.elect_one():
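
Why the two branches differ (an inference from the scheduler hook above, not stated in the diff): with raster_along_m=False the n coordinate varies fastest, so the tile m coordinate never decreases across successive work units; the first tile with mma_tile_coord_m >= num_non_exiting_tiles implies every later tile is past the limit too, which is what the is_continue early exit exploits. With raster_along_m=True the m coordinate cycles, so every work unit must be inspected. A toy model of that control flow:

    def produce(tiles_m, limit, can_early_exit):
        # tiles_m: tile m coordinates in schedule order; limit: num_non_exiting_tiles
        emitted = []
        for m in tiles_m:
            if m < limit:
                emitted.append(m)
            elif can_early_exit:
                break  # m is non-decreasing in this raster order; nothing left to emit
        return emitted

    assert produce([0, 1, 2, 3], limit=2, can_early_exit=True) == [0, 1]   # n fastest
    assert produce([0, 2, 1, 3], limit=2, can_early_exit=False) == [0, 1]  # m fastest: m cycles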
@@ -2781,6 +2961,7 @@ def _compute_grid(
        cta_tile_shape_mnk: Tuple[int, int, int],
        cluster_shape_mn: Tuple[int, int],
        max_active_clusters: cutlass.Constexpr,
+        raster_along_m: bool = False,
    ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
        """Use persistent tile scheduler to compute the grid size for the output tensor C.

@@ -2803,7 +2984,9 @@
        num_ctas_mnl = gc[(0, (None, None, None))].shape
        cluster_shape_mnl = (*cluster_shape_mn, 1)

-        tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl)
+        tile_sched_params = utils.PersistentTileSchedulerParams(
+            num_ctas_mnl, cluster_shape_mnl, raster_along_m=raster_along_m
+        )
        grid = utils.StaticPersistentTileScheduler.get_grid_shape(
            tile_sched_params, max_active_clusters
        )
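
For orientation (not part of the diff): a persistent tile scheduler launches at most max_active_clusters clusters and has each of them loop over the remaining tiles. A toy sketch of that sizing logic, with hypothetical names:

    def persistent_cluster_count(num_ctas_mnl, cluster_shape_mn, max_active_clusters):
        clusters_m = (num_ctas_mnl[0] + cluster_shape_mn[0] - 1) // cluster_shape_mn[0]
        clusters_n = (num_ctas_mnl[1] + cluster_shape_mn[1] - 1) // cluster_shape_mn[1]
        total_clusters = clusters_m * clusters_n * num_ctas_mnl[2]
        # Persistent scheduling: clamp to what the device can keep resident;
        # each launched cluster then iterates over several work tiles.
        return min(total_clusters, max_active_clusters)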
@@ -3209,3 +3392,33 @@ def wrapper(
        stream=stream,
        epilogue_op=epilogue_op,
    )


@cute.jit
def cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
    sf_ref_tensor: cute.Tensor,
    sf_mma_tensor: cute.Tensor,
):
    """Convert a scale factor tensor from MKL layout to the MMA-specification
    M(32x4xrest_m) x K(4xrest_k) x L layout."""
    # sf_mma_tensor has the flattened shape (32, 4, rest_m, 4, rest_k, l);
    # group it to ((32, 4, rest_m), (4, rest_k), l)
    sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3)
    sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3)
    for i in cutlass.range(cute.size(sf_ref_tensor)):
        mkl_coord = sf_ref_tensor.layout.get_hier_coord(i)
        sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord]


@cute.jit
def cvt_sf_M32x4xrm_K4xrk_L_to_MKL(
    sf_swizzled_tensor: cute.Tensor,
    sf_unswizzled_tensor: cute.Tensor,
):
    """Convert a scale factor tensor from the MMA-specification
    M(32x4xrest_m) x K(4xrest_k) x L layout to MKL layout."""
    # sf_swizzled_tensor has the flattened shape (32, 4, rest_m, 4, rest_k, l);
    # group it to ((32, 4, rest_m), (4, rest_k), l)
    sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 0, 3)
    sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 1, 3)
    for i in cutlass.range(cute.size(sf_unswizzled_tensor)):
        mkl_coord = sf_unswizzled_tensor.layout.get_hier_coord(i)
        sf_unswizzled_tensor[mkl_coord] = sf_swizzled_tensor[mkl_coord]
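
For intuition (not part of the diff): the grouped M(32x4xrest_m) x K(4xrest_k) x L view is a plain regrouping of the M and K axes. A NumPy sketch, assuming the grouping decomposes m as (m % 32, (m // 32) % 4, m // 128) and k as (k % 4, k // 4):

    import numpy as np

    M, K, L = 256, 8, 1
    sf_mkl = np.arange(M * K * L, dtype=np.float32).reshape(M, K, L)

    rest_m, rest_k = M // 128, K // 4
    sf_mma = sf_mkl.reshape(rest_m, 4, 32, rest_k, 4, L).transpose(2, 1, 0, 4, 3, 5)
    # sf_mma.shape == (32, 4, rest_m, 4, rest_k, L), and
    # sf_mma[m % 32, (m // 32) % 4, m // 128, k % 4, k // 4, l] == sf_mkl[m, k, l]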