What is your question?
The kernel always hangs when using TMA to copy data. Reference doc & code: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays
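For reference, the synchronization ordering the demo tries to follow (my reading of that section, condensed with the same CuTe DSL primitives used in the demo; bar and smem_bytes stand in for bar_storage and cute.size_in_bytes(...) below, and this is only an outline, not a complete kernel):

    # Intended mbarrier/TMA ordering, condensed from the demo kernel below.
    cute.arch.mbarrier_init(bar, cnt=TPB)                      # one thread sets the arrive count
    cute.arch.mbarrier_init_fence()
    cute.arch.sync_threads()
    cute.copy(g2s_copy_atom, gt, st, tma_bar_ptr=bar)          # elected thread issues the G2S bulk copy
    cute.arch.mbarrier_arrive_and_expect_tx(bar, smem_bytes)   # ...and posts the expected byte count
    cute.arch.mbarrier_arrive(bar)                             # every other thread just arrives
    cute.arch.mbarrier_wait(bar, 0)                            # all threads wait on phase 0 before reading smem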
Demo code in the CuTe DSL:
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.torch as cutlass_torch
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack


@cute.struct
class SharedStorage:
    barriers: cute.struct.MemRange[cutlass.Int64, 2]


@cute.kernel
def demo_kernel(
    src: cute.Tensor,
    g2s_copy_atom: cute.CopyAtom,
    g2s_tma_coord: cute.Tensor,
    s2g_copy_atom: cute.CopyAtom,
    s2g_tma_coord: cute.Tensor,
    smem_layout: cute.Layout,
    rows_per_block: cutlass.Constexpr,
    TPB: cutlass.Constexpr,
):
    # Allocate the mbarrier storage and the shared-memory staging tile.
    allocator = cutlass.utils.SmemAllocator()
    bar_storage = allocator.allocate(
        SharedStorage, byte_alignment=8
    ).barriers.data_ptr()
    smem = allocator.allocate_tensor(
        src.element_type,
        smem_layout,
        byte_alignment=128,
    )

    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)
    bid_x, bid_y, _ = cute.arch.block_idx()
    tid_x, _, _ = cute.arch.thread_idx()

    if warp_idx == 0:
        cpasync.prefetch_descriptor(g2s_copy_atom)

    # Initialize the mbarrier with an arrive count of TPB (every thread arrives once).
    if tid_x == 0:
        cute.arch.mbarrier_init(bar_storage, cnt=TPB)
    cute.arch.mbarrier_init_fence()
    cute.arch.sync_threads()

    if tid_x == 0:
        # Slice out this CTA's tile of the TMA coordinate tensor and partition it.
        g_in = cute.local_tile(
            g2s_tma_coord, (rows_per_block, TPB), (bid_x, bid_y), (1, 1)
        )
        st, gt = cpasync.tma_partition(
            atom=g2s_copy_atom,
            cta_coord=0,
            cta_layout=cute.make_layout(1),
            smem_tensor=smem,
            gmem_tensor=g_in,
        )
        print(f"st: {st}")
        print(f"gt: {gt}")
        # Issue the G2S bulk-tensor copy and post the expected transaction bytes.
        cute.copy(
            g2s_copy_atom,
            gt,
            st,
            tma_bar_ptr=bar_storage,
        )
        print(f"g_in: {g_in}")
        cute.arch.mbarrier_arrive_and_expect_tx(
            bar_storage,
            cute.size_in_bytes(src.element_type, smem_layout),
        )
    else:
        cute.arch.mbarrier_arrive(bar_storage)

    # cute.printf("tid {} before sync\n", tid_x)
    # All threads wait on phase 0 of the mbarrier before reading smem.
    cute.arch.mbarrier_wait(bar_storage, 0)
    cute.arch.sync_threads()

    if tid_x == 0:
        cute.print_tensor(smem)
@cute.jit
def demo_host_func(
    src: cute.Tensor,
    dst: cute.Tensor,
    rows_per_block: cutlass.Constexpr,
    TPB: cutlass.Constexpr,
):
    print(f"src: {src}")
    # Shared-memory layout for one (rows_per_block, TPB) tile, ordered (1, 0).
    smem_layout = cute.make_ordered_layout((rows_per_block, TPB), order=(1, 0))

    # Build the G2S and S2G bulk-tensor (TMA) copy atoms and their coordinate tensors.
    g2s_tma_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
    s2g_tma_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp()
    g2s_atom, g2s_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
        g2s_tma_op,
        src,
        smem_layout,
        (rows_per_block, TPB),
    )
    s2g_atom, s2g_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
        s2g_tma_op,
        src,
        smem_layout,
        (rows_per_block, TPB),
    )
    print(f"op: {g2s_tma_op}")
    print(f"g2s_atom: {g2s_atom}")
    print(f"s2g_atom: {s2g_atom}")
    print(f"g2s_tma_tensor: {g2s_tma_tensor}")
    # cute.print_tensor(g2s_tma_tensor)

    # Single-CTA launch with TPB threads.
    grid = (1, 1, 1)
    block = (TPB, 1, 1)
    demo_kernel(
        src,
        g2s_atom, g2s_tma_tensor,
        s2g_atom, s2g_tma_tensor,
        smem_layout,
        rows_per_block,
        TPB,
    ).launch(
        grid=grid,
        block=block,
    )