Commit cf399b4

[Hopper][Warp Spec] Enable persistent matmul test (#7642)
This also just works.
1 parent 45a7da6 commit cf399b4

File tree

2 files changed: 16 additions, 20 deletions


python/test/unit/language/test_warp_specialization.py

Lines changed: 15 additions & 5 deletions
@@ -304,6 +304,7 @@ def matmul_tma_persistent_ws_kernel( #
         GROUP_SIZE_M: tl.constexpr,  #
         NUM_SMS: tl.constexpr,  #
         USE_FP8: tl.constexpr,  #
+        FLATTEN: tl.constexpr,  #
 ):
     a_desc = tl.make_tensor_descriptor(a_ptr, shape=[M, K], strides=[a_stride0, a_stride1],
                                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K])
@@ -318,7 +319,8 @@ def matmul_tma_persistent_ws_kernel( #
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n

-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True, num_stages=num_stages):
+    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=True,
+                            num_stages=num_stages):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_n, num_pid_m, GROUP_SIZE_M)

         off_am = pid_m * BLOCK_SIZE_M
@@ -342,7 +344,7 @@ def matmul_tma_persistent_ws_kernel( #
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
 def test_warp_specialize_tma_matmul_persistent(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps,
                                                use_fp8):
     if exceeds_smem_capacity(num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, use_fp8):
@@ -371,10 +373,18 @@ def grid(META):

     kernel = matmul_tma_persistent_ws_kernel[grid](A, B, C, *A.stride(), *B.stride(), *C.stride(), M, N, K, num_stages,
                                                    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, NUM_SMS,
-                                                   num_warps=num_warps, USE_FP8=use_fp8)
+                                                   num_warps=num_warps, USE_FP8=use_fp8, FLATTEN=is_blackwell())
     ttgir = kernel.asm["ttgir"]
-    assert "ttng.tc_gen5_mma" in ttgir
-    assert "ttg.warp_specialize" in ttgir
+    if is_blackwell():
+        assert "ttng.tc_gen5_mma" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    else:
+        assert "ttng.warp_group_dot" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    if is_hopper() and num_warps == 8:
+        assert "ttg.warp_specialize" not in ttgir
+    else:
+        assert "ttg.warp_specialize" in ttgir

     ref_out = torch.empty((M, N), dtype=dtype, device=device)
     cublas.matmul(A, B, ref_out)
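
For readers tracing the dispatch above: the test now gates both the FLATTEN constexpr and the ttgir assertions on the detected GPU, so loop flattening is only requested on Blackwell while Hopper exercises the warp-group MMA path. A minimal sketch of how helpers such as is_hopper(), is_blackwell(), and is_hopper_or_blackwell() could be written follows; the real helpers live in the Triton test utilities, and the compute-capability mapping here is an assumption.

    import torch

    def _cc_major():
        # Major CUDA compute capability of the current device.
        return torch.cuda.get_device_capability()[0]

    def is_hopper():
        # Assumed mapping: Hopper is sm_90.
        return _cc_major() == 9

    def is_blackwell():
        # Assumed mapping: Blackwell datacenter parts are sm_100-class.
        return _cc_major() == 10

    def is_hopper_or_blackwell():
        return is_hopper() or is_blackwell()

With helpers of this shape, FLATTEN=is_blackwell() keeps the flattened persistent loop a Blackwell-only feature in this test, while the Hopper path is covered by the warp_group_dot and TMA-copy assertions.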

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp

Lines changed: 1 addition & 15 deletions
@@ -24,14 +24,6 @@ namespace mlir {
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

-static Value createThreadIdOp(OpBuilder &builder, Location loc) {
-  Value threadId = builder.create<::mlir::gpu::ThreadIdOp>(
-      loc, builder.getIndexType(), ::mlir::gpu::Dimension::x);
-  auto cast = builder.create<UnrealizedConversionCastOp>(
-      loc, TypeRange{builder.getIntegerType(32)}, ValueRange{threadId});
-  return cast.getResult(0);
-}
-
 // Lower to use GetCanonicalWarpIdOp.
 // In Hopper, each task is a warpgroup consisting of 4 warps.
 static const int WARPS_PER_TASK = 4;
@@ -73,15 +65,9 @@ void processProducerCommitOp(OpBuilder &builder, ttnvws::ProducerCommitOp op,
   ttng::ArriveBarrierOp arriveOp;

   if (loadType == ttnvws::TokenLoadType::TMALoadOp) {
-    // Only thread 0 arrives for TMA load.
-    Value _0 = builder.create<arith::ConstantIntOp>(loc, 0, 32);
-    Value threadId = createThreadIdOp(builder, loc);
-    Value pred = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                               threadId, _0);
     // Get the count from the barriers: trace the local_alloc for the barrier
     // then find the count from init_barrier
-    arriveOp =
-        builder.create<ttng::ArriveBarrierOp>(loc, bufferFull, fullCnt, pred);
+    arriveOp = builder.create<ttng::ArriveBarrierOp>(loc, bufferFull, fullCnt);
   } else {
     assert(false);
   }

0 commit comments
