Skip to content

Commit 3251bb8

Browse files
authored
[AMD] Support Skinny Blocks for TDM on gfx1250 (#8479)
This PR changes the warp distribution so TDM can load/store skinny blocks such as 1x512.
1 parent 90666a8 commit 3251bb8

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ decodeTDMDescriptor(RewriterBase &rewriter, Location loc,
3535

3636
return {srcPtr, tensorShape, tensorStride};
3737
}
38+
39+
// Splits `numWarps` into a 2D warp distribution {dim0, dim1} over a 2D block
// so that skinny blocks (e.g. 1x512) can still use every warp: the dim-0 warp
// count is halved until it fits within blockShape[0], and the remaining warps
// are assigned along dim 1.
SmallVector<int> getWarpDistribution(ArrayRef<int64_t> blockShape,
                                     int numWarps) {
  int numWarpsDim0 = numWarps;
  // Halve the dim-0 warp count until it no longer exceeds the block's first
  // dimension; leftover warps fold into dim 1 below.
  for (; numWarpsDim0 > blockShape[0]; numWarpsDim0 /= 2)
    ;

  // Guard BEFORE dividing: if blockShape[0] == 0 the loop above drives
  // numWarpsDim0 to 0, and the division below would be undefined behavior.
  assert(numWarpsDim0 > 0 && "Can't distribute warps in TDM");

  int numWarpsDim1 = numWarps / numWarpsDim0;
  // Each dim-1 warp must own a whole slice of the second block dimension.
  assert(blockShape[1] % numWarpsDim1 == 0 && "Can't distribute warps in TDM");

  return {numWarpsDim0, numWarpsDim1};
}
3851
} // namespace
3952

4053
std::pair<SmallVector<Value>, SmallVector<Value>>
@@ -56,8 +69,10 @@ createTDMDescriptor(RewriterBase &rewriter, Location loc,
5669
tensorStride[0] = b.trunc(i32_ty, tensorStride[0]);
5770
tensorStride[1] = b.trunc(i32_ty, tensorStride[1]);
5871

59-
// For block shape [M, N], each warp will handle shape [M/numWarps, N].
60-
blockShape[0] = ceil(blockShape[0], int64_t(numWarps));
72+
// Distribute block among warps
73+
auto warps = getWarpDistribution(blockShape, numWarps);
74+
blockShape[0] = ceil(blockShape[0], int64_t(warps[0]));
75+
blockShape[1] = ceil(blockShape[1], int64_t(warps[1]));
6176

6277
// group0 (128 bits / 4 dwords) effective bit encoding:
6378
// [1:0]: pred (to be filled later)
@@ -122,19 +137,27 @@ void fillTDMDescriptor(RewriterBase &rewriter, Location loc,
122137
decodeTDMDescriptor(rewriter, loc, group0, group1);
123138

124139
auto warpId = getLaneAndWarpId(rewriter, loc).second;
125-
int outerBlockShapePerWarp = ceil(blockShape[0], int64_t(numWarps));
126-
int outerBlockStride = blockShape[1];
140+
auto warps = getWarpDistribution(blockShape, numWarps);
127141

128142
// Shift global pointer by offset
129-
Value outerOffset = b.mul(b.i32_val(outerBlockShapePerWarp), warpId);
130-
offset[0] = b.add(offset[0], outerOffset);
143+
Value warpDim0 = b.i32_val(warps[0]);
144+
SmallVector<Value, 2> warpCoord = {b.urem(warpId, warpDim0),
145+
b.udiv(warpId, warpDim0)};
146+
147+
SmallVector<Value, 2> globalOffset;
148+
for (int i = 0; i < 2; i++) {
149+
int64_t blockShapePerWarp = ceil(blockShape[i], int64_t(warps[i]));
150+
globalOffset.push_back(b.mul(b.i32_val(blockShapePerWarp), warpCoord[i]));
151+
offset[i] = b.add(offset[i], globalOffset[i]);
152+
}
131153

132154
Value baseOffset = b.add(b.mul(tensorStride[0], offset[0]),
133155
b.mul(tensorStride[1], offset[1]));
134156
srcPtr = b.gep(globalPtrTy, elementType, srcPtr, baseOffset);
135157

136158
// Shift shared pointer by offset
137-
Value dstOffset = b.mul(b.i32_val(outerBlockStride), outerOffset);
159+
Value dstOffset =
160+
b.add(b.mul(b.i32_val(blockShape[1]), globalOffset[0]), globalOffset[1]);
138161
if (padInterval > 0 && padAmount > 0) {
139162
Value iVal = b.i32_val(log2(padInterval));
140163
Value pVal = b.i32_val(log2(padAmount));

third_party/amd/python/test/test_gluon_gfx1250.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,9 @@ def torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K):
365365

366366
@gluon.jit
367367
def tensor_copy_kernel(a_ptr, b_ptr, M, N, #
368-
BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, NUM_BUFFERS: ttgl.constexpr):
368+
BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, NUM_BUFFERS: ttgl.constexpr,
369+
BLOCKED_LAYOUT: ttgl.constexpr):
369370
SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [BLOCK_M, BLOCK_N], [1, 0])
370-
BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])
371371

372372
pid = ttgl.program_id(axis=0)
373373
num_pid_m = ttgl.cdiv(M, BLOCK_M)
@@ -400,31 +400,38 @@ def tensor_copy_kernel(a_ptr, b_ptr, M, N, #
400400
@pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)])
401401
@pytest.mark.parametrize("NUM_BUFFERS", [1, 2])
402402
def test_compile_tensor_copy(BLOCK_M, BLOCK_N, NUM_BUFFERS):
403+
BLOCKED_LAYOUT = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])
403404
k = triton.compile(
404405
gluon._runtime.GluonASTSource(
405406
fn=tensor_copy_kernel, signature={
406407
"a_ptr": "*fp16", "b_ptr": "*fp16", "M": "i32", "N": "i32", #
407-
"BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "NUM_BUFFERS": "constexpr"
408-
}, constexprs={"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "NUM_BUFFERS": NUM_BUFFERS}),
409-
target=GPUTarget("hip", 'gfx1250', 32))
408+
"BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "NUM_BUFFERS": "constexpr", #
409+
"BLOCKED_LAYOUT": "constexpr"
410+
}, constexprs={
411+
"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "NUM_BUFFERS": NUM_BUFFERS, "BLOCKED_LAYOUT": BLOCKED_LAYOUT
412+
}), target=GPUTarget("hip", 'gfx1250', 32))
410413

411414
amdgcn = k.asm["amdgcn"]
412415
for pattern in ("tensor_load_to_lds", "s_wait_tensorcnt 0x0"):
413416
assert re.search(pattern, amdgcn)
414417

415418

416-
@pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)])
419+
@pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64), (1, 512), (256, 2)])
417420
@pytest.mark.parametrize("NUM_BUFFERS", [1, 2])
421+
@pytest.mark.parametrize("NUM_WARPS", [4, 8])
418422
@pytest.mark.parametrize("M,N", [(1024, 1024), (1000, 1000)])
419-
def test_runtime_tensor_copy(M, N, BLOCK_M, BLOCK_N, NUM_BUFFERS):
423+
def test_runtime_tensor_copy(M, N, BLOCK_M, BLOCK_N, NUM_BUFFERS, NUM_WARPS):
424+
blocked_layout = ttgl.BlockedLayout([1, 8], [4, 8], [NUM_WARPS, 1], [1, 0])
425+
420426
torch.manual_seed(42)
421427
a = torch.randint(0x0, 0xFFFF, (M, N), dtype=torch.uint16)
422428
b = torch.zeros_like(a)
423429

424430
a_device = a.cuda()
425431
b_device = b.cuda()
426432
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N * NUM_BUFFERS), 1)
427-
tensor_copy_kernel[grid](a_device, b_device, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, NUM_BUFFERS=NUM_BUFFERS)
433+
tensor_copy_kernel[grid](a_device, b_device, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, NUM_BUFFERS=NUM_BUFFERS,
434+
BLOCKED_LAYOUT=blocked_layout, num_warps=NUM_WARPS)
428435

429436
b_triton = b_device.cpu()
430437
assert torch.equal(b_triton, a)

0 commit comments

Comments
 (0)