Skip to content

[QST][CuteDSL] Deadlock when using TMA loads data #2879

@LRlr239

Description

@LRlr239

What is your question?

kernel always hang when using TMA to copy data, ref doc &code: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays

demo code in cute dsl:

import torch

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.torch as cutlass_torch
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack

@cute.struct
class SharedStorage:
    barriers: cute.struct.MemRange[cutlass.Int64, 2]

@cute.kernel
def demo_kernel(
    src: cute.Tensor,
    g2s_copy_atom: cute.CopyAtom,
    g2s_tma_coord: cute.Tensor,
    s2g_copy_atom: cute.CopyAtom,
    s2g_tma_coord: cute.Tensor,
    smem_layout: cute.Layout,

    rows_per_block: cutlass.Constexpr,
    TPB: cutlass.Constexpr,
):
    allocator = cutlass.utils.SmemAllocator()
    bar_storage = allocator.allocate(
        SharedStorage, byte_alignment=8
    ).barriers.data_ptr()
    smem = allocator.allocate_tensor(
        src.element_type,
        smem_layout,
        byte_alignment=128,
    )

    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)
    bid_x, bid_y, _ = cute.arch.block_idx()
    tid_x, _, _ = cute.arch.thread_idx()

    if warp_idx == 0:
        cpasync.prefetch_descriptor(g2s_copy_atom)
    
    if tid_x == 0:
        cute.arch.mbarrier_init(bar_storage, cnt=TPB)
        cute.arch.mbarrier_init_fence()
    
    cute.arch.sync_threads()

    if tid_x == 0:
        g_in = cute.local_tile(
            g2s_tma_coord, (rows_per_block, TPB), (bid_x, bid_y), (1, 1)
        )
        st, gt = cpasync.tma_partition(
            atom=g2s_copy_atom,
            cta_coord=0,
            cta_layout=cute.make_layout(1),
            smem_tensor=smem,
            gmem_tensor=g_in,
        )
        print(f"st: {st}")
        print(f"gt: {gt}")
        cute.copy(
            g2s_copy_atom,
            gt,
            st,
            tma_bar_ptr=bar_storage
        )
        print(f"g_in: {g_in}")        
        cute.arch.mbarrier_arrive_and_expect_tx(
            bar_storage,
            cute.size_in_bytes(src.element_type, smem_layout)
        )
    else:
        cute.arch.mbarrier_arrive(bar_storage)
    # cute.printf("tid {} before sync\n", tid_x)
    cute.arch.mbarrier_wait(bar_storage, 0)
    
    cute.arch.sync_threads()
    if tid_x == 0:
        cute.print_tensor(smem)


@cute.jit
def demo_host_func(src: cute.Tensor, dst: cute.Tensor, rows_per_block: cutlass.Constexpr,  TPB: cutlass.Constexpr):
    print(f"src: {src}")
    smem_layout = cute.make_ordered_layout((rows_per_block, TPB), order=(1, 0))
    g2s_tma_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp()
    s2g_tma_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp()
    g2s_atom, g2s_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
        g2s_tma_op,
        src,
        smem_layout,
        (rows_per_block, TPB)
    )
    s2g_atom, s2g_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
        s2g_tma_op,
        src,
        smem_layout,
        (rows_per_block, TPB)
    )
    print(f"op: {g2s_tma_op}")
    print(f"g2s_atom: {g2s_atom}")
    print(f"s2g_atom: {s2g_atom}")
    print(f"g2s_tma_tensor: {g2s_tma_tensor}")
    # cute.print_tensor(g2s_tma_tensor)

    grid = (1,1,1)
    block = (TPB, 1, 1)

    demo_kernel(
        src,
        g2s_atom, g2s_tma_tensor,
        s2g_atom, s2g_tma_tensor,
        smem_layout,
        rows_per_block,
        TPB
    ).launch(
        grid=grid,
        block=block
    )

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions