Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions third_party/tlx/tutorials/blackwell_gemm_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,11 +730,10 @@ def _process_tile_epilogue_inner(
[BLOCK_M_SPLIT, slice_size],
)
result = tlx.local_load(acc_tmem_subslice)
# Signal MMA consumer after each slice
tlx.barrier_arrive(tmem_empty_bars[buf_idx], 1)
c = result.to(tlx.dtype_of(c_desc))
c_smem = c_smem_buffers[group_id]
tlx.async_descriptor_store_wait(0)
c_smem = c_smem_buffers[(group_id * EPILOGUE_SUBTILE + slice_id) % 2]
tlx.async_descriptor_store_wait(1)
tlx.local_store(c_smem, c)
tlx.fence_async_shared()
tlx.async_descriptor_store(
Expand Down Expand Up @@ -976,12 +975,13 @@ def matmul_kernel_tma_ws_blackwell(
tlx.storage_kind.tmem,
)

# Allocate SMEM buffer for epilogue TMA store (one per MMA group)
# Allocate SMEM buffers for epilogue TMA store (at least 2 for multi-buffering)
NUM_EPILOGUE_SMEM_BUFFERS: tl.constexpr = NUM_MMA_GROUPS if NUM_MMA_GROUPS > 2 else 2
slice_size: tl.constexpr = BLOCK_SIZE_N // EPILOGUE_SUBTILE
c_smem_buffers = tlx.local_alloc(
(BLOCK_M_SPLIT, slice_size),
tlx.dtype_of(c_desc),
NUM_MMA_GROUPS,
NUM_EPILOGUE_SMEM_BUFFERS,
)

# CTA pairs are placed along M dim
Expand Down Expand Up @@ -1172,7 +1172,7 @@ def matmul_kernel_tma_ws_blackwell(
tile_id += NUM_SMS


def matmul(a, b, config=None, use_heuristic=True):
def matmul(a, b, config=None, use_heuristic=False):
"""Matrix multiplication using TLX GEMM kernel.

Args:
Expand Down
Loading