Commit e57b468

[BACKEND] Functional fixes for layout conversion that uses stmatrix (#5407)
This PR:

1. Refactors the construction logic for `stmatrix` selection in `LinearLayoutConversions.cpp`. Note that the heuristic-based approach will be replaced with an LL-driven approach once we have `divideRight` and `divideLeft`.
2. Updates the `SharedLayout` class and adds a `has_leading_offset` attribute.
3. Adds comprehensive new test cases for MMA and shared layouts.
1 parent 8dfa7be commit e57b468
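For quick orientation, the sketch below (illustrative Python, not part of this commit) condenses the updated `SharedLayout` test helper from `python/test/unit/language/test_core.py` and shows how the new `has_leading_offset` flag surfaces in the printed TTGIR attribute; the `GPU_DIALECT = "ttg"` prefix is an assumption carried over from the test file.

# Condensed, hypothetical version of the SharedLayout test helper;
# only the __str__ behavior relevant to this commit is shown.
GPU_DIALECT = "ttg"  # assumed dialect prefix, as used in the test IR below


class SharedLayout:

    def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order,
                 has_leading_offset=False):
        self.vec, self.per_phase, self.max_phase = vec, per_phase, max_phase
        self.order, self.ctas_per_cga = order, ctas_per_cga
        self.cta_split_num, self.cta_order = cta_split_num, cta_order
        self.has_leading_offset = has_leading_offset

    def __str__(self):
        has_leading_offset_str = "true" if self.has_leading_offset else "false"
        return (f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, "
                f"maxPhase={self.max_phase}, order={self.order}, "
                f"CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, "
                f"CTAOrder={self.cta_order}, hasLeadingOffset={has_leading_offset_str}}}>")


# A swizzled layout eligible for stmatrix lowering on MMAv3:
print(SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1], has_leading_offset=True))
# #ttg.shared<{vec=8, perPhase=1, maxPhase=8, order=[1, 0], CTAsPerCGA=[1, 1],
#  CTASplitNum=[1, 1], CTAOrder=[0, 1], hasLeadingOffset=true}>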

2 files changed: +142 −49 lines changed

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 64 additions & 32 deletions
@@ -904,13 +904,6 @@ LinearLayout chooseStMatrixLayoutLeadingOffset(
     MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
     ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
     int swizzleByteSize) {
-  StringAttr kReg = S("register");
-  StringAttr kLane = S("lane");
-  StringAttr kWarp = S("warp");
-  StringAttr kCol = S("dim1");
-  StringAttr kRow = S("dim0");
-  StringAttr kOffset = S("offset");
-
   int perPhase;
   int maxPhase;
   if (swizzleByteSize == 32) {
@@ -930,45 +923,84 @@
   // stmatrix only supports 16-bit elements, and each vector has 8 elements
   int elemBitWidth = 16;
   int vecSize = 8;
-  int numRows = 16;
-  int numCols = 8 * swizzleByteSize / elemBitWidth;
+  int numRowsPerTile = 16;
+  int numColsPerChunk = 8 * swizzleByteSize / elemBitWidth;
 
   // Construct a single stmatrix.x4 (16x16) tile
   std::vector<std::vector<int>> basesReg = {{1, 0}, {2, 0}, {4, 0}};
   std::vector<std::vector<int>> basesLane;
-  for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) {
+  for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) {
     int row = 1 << logRow;
     basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row});
   }
   basesLane.push_back({8, 0});
 
-  // Expand the tile's register dimension to fit swizzleByteSize, which is a
-  // "chunk"
-  for (int logChunk = 0; logChunk < llvm::Log2_32(numCols / 16); logChunk++) {
-    int chunk = 1 << logChunk;
-    basesReg.push_back({16 * chunk, 0});
+  auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
+  assert(mma.getVersionMajor() >= 3 && "Only MMAv3 is supported");
+  int instrM = mma.getInstrShape()[0];
+  int instrN = mma.getInstrShape()[1];
+
+  // TODO(Keren): The following logic can be simplified by using the
+  // `divideLeft` function in `LinearLayout` once it's available.
+  // Construct the bases for a single chunk
+  // In theory the following situation is valid but it will be
+  // suboptimal. Swizzling should happen within a warp.
+  assert(instrN >= numColsPerChunk &&
+         "Each chunk is filled in with a single warp");
+  for (int logCol = 0; logCol < llvm::Log2_32(numColsPerChunk / 16); logCol++) {
+    int col = 1 << logCol;
+    basesReg.push_back({16 * col, 0});
   }
 
-  // Construct the layout for a single chunk
-  LinearLayout layout =
-      LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});
+  // Construct the bases for warpsPerCTA[0]
+  std::vector<std::vector<int>> basesWarp;
+  auto warpsPerCTA = mma.getWarpsPerCTA();
+  auto shape = tensorTy.getShape();
+  for (int logWarp = 0; logWarp < llvm::Log2_32(warpsPerCTA[0]); logWarp++) {
+    int warp = 1 << logWarp;
+    basesWarp.push_back({0, warp * instrM});
+  }
 
-  // Expand the `warp` dimension according to warpsPerCTA.
-  auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
-  layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
-                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  // Expand the `register` dimension so the size of columns matches `shape[1] /
+  // warpsPerCTA[1]`
+  auto numColsPerWarp = std::max<int>(instrN, shape[1] / warpsPerCTA[1]);
+  assert(warpsPerCTA[1] * instrN >= shape[1] &&
+         "There must be enough columns to use MMAv3");
+  auto logNumCols = llvm::Log2_32(numColsPerWarp / numColsPerChunk);
+  for (int logCol = 0; logCol < logNumCols; logCol++) {
+    int chunk = 1 << logCol;
+    int basis = chunk * shape[0];
+    basesReg.push_back({0, basis});
+  }
 
-  // Expand the `register` dimension so the size of columns matches `n`.
-  int n = mma.getInstrShape()[1];
-  int numWarpRows = layout.getOutDimSize(kRow);
-  layout = (layout.reshapeOuts({{kOffset, layout.getTotalOutDimSize()}}) *
-            LinearLayout::identity1D(n / numCols, kReg, kOffset))
-               .reshapeOuts({{kCol, n}, {kRow, numWarpRows}});
+  // Expand the `register` dimension so that the size of rows matches `shape[0]`
+  assert(warpsPerCTA[0] * instrM <= shape[0] &&
+         "There must be enough rows to use MMAv3");
+  auto logNumRows = llvm::Log2_32(shape[0] / (warpsPerCTA[0] * instrM));
+  for (int logRow = 0; logRow < logNumRows; logRow++) {
+    int chunk = 1 << logRow;
+    int basis = chunk * warpsPerCTA[0] * instrM;
+    basesReg.push_back({0, basis});
+  }
 
-  auto ret =
-      combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
-  return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
-      .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+  // Expand the `warp` dimension so that the size of cols matches `shape[1]`
+  for (int logWarp = 0; logWarp < llvm::Log2_32(warpsPerCTA[1]); logWarp++) {
+    int warp = 1 << logWarp;
+    if (warp * numColsPerWarp >= shape[1]) {
+      basesWarp.push_back({0, 0});
+    } else {
+      int basis = (warp * numColsPerWarp) / numColsPerChunk * shape[0];
+      basesWarp.push_back({0, basis});
+    }
+  }
+
+  auto layout = LinearLayout({{S("register"), basesReg},
+                              {S("lane"), basesLane},
+                              {S("warp"), basesWarp},
+                              {S("block"), {}}},
+                             {S("offset1"), S("offset0")});
+  return layout.reshapeOuts(
+      {{S("offset"), layout.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
 LinearLayout chooseStMatrixLayoutNoLeadingOffset(
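
To make the refactored construction above easier to follow, here is a short illustrative Python sketch (not Triton code) of the per-chunk basis computation. The perPhase = 1, maxPhase = 8 values for swizzleByteSize = 128 are an assumption; the diff elides the body of that if/else chain.

# Sketch of the chunk-level basis construction in
# chooseStMatrixLayoutLeadingOffset: every power-of-two hardware index bit
# (register, lane) maps to a (col, row) basis vector, and swizzling enters
# through the lane bases (bases combine by XOR in a linear layout).
from math import log2

elem_bitwidth = 16                      # stmatrix handles 16-bit elements only
vec_size = 8                            # 8 elements per stmatrix vector
swizzle_byte_size = 128
per_phase, max_phase = 1, 8             # assumed values for 128B swizzling
num_rows_per_tile = 16
num_cols_per_chunk = 8 * swizzle_byte_size // elem_bitwidth  # 64 columns

# One stmatrix.x4 (16x16) tile: the low register bits step within a vector.
bases_reg = [[1, 0], [2, 0], [4, 0]]

# Lane bits select the row; the column component carries the swizzle phase.
bases_lane = []
for log_row in range(int(log2(num_rows_per_tile))):
    row = 1 << log_row
    bases_lane.append([vec_size * ((row // per_phase) % max_phase), row])
bases_lane.append([8, 0])               # last lane bit: second 8-column half

# Expand registers so a single warp fills one whole chunk of columns.
for log_col in range(int(log2(num_cols_per_chunk // 16))):
    col = 1 << log_col
    bases_reg.append([16 * col, 0])

print(bases_lane)  # [[8, 1], [16, 2], [32, 4], [0, 8], [8, 0]]
print(bases_reg)   # [[1, 0], [2, 0], [4, 0], [16, 0], [32, 0]]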

python/test/unit/language/test_core.py

Lines changed: 78 additions & 17 deletions
@@ -179,17 +179,20 @@ def __str__(self):
 
 class SharedLayout:
 
-    def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order):
+    def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order,
+                 has_leading_offset=False):
         self.vec = vec
         self.per_phase = per_phase
         self.max_phase = max_phase
         self.order = order
         self.ctas_per_cga = ctas_per_cga
         self.cta_split_num = cta_split_num
         self.cta_order = cta_order
+        self.has_leading_offset = has_leading_offset
 
     def __str__(self):
-        return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+        has_leading_offset_str = "true" if self.has_leading_offset else "false"
+        return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, hasLeadingOffset={has_leading_offset_str}}}>"
 
 
 def is_layout_applicable(layout) -> bool:
@@ -5418,7 +5421,7 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
                      k_width=1),
 ]
 
-shared_layout_3d = [
+shared_layouts_3d = [
     SharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
     SharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
     SharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
@@ -5427,8 +5430,8 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
 
 
 @pytest.mark.parametrize("M, N, K", [[8, 16, 32]])
-@pytest.mark.parametrize("shared_layout", shared_layout_3d)
-@pytest.mark.parametrize("dist_layout", layouts_3d)
+@pytest.mark.parametrize("shared_layout", shared_layouts_3d)
+@pytest.mark.parametrize("dist_layout", filter_layouts(layouts_3d))
 def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: pathlib.Path):
     layouts = f"""
     #dist = {dist_layout}
@@ -5500,6 +5503,72 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
     assert torch.equal(z, x)
 
 
+mma_layouts = [
+    MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
+    MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]),  # simple 4 warps case
+    MmaLayout((3, 0), [8, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]),  # simple 8 warps case
+    MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 128, 16]),  # multiple warps on the row
+    MmaLayout((3, 0), [4, 2], [1, 1], [1, 1], [0, 1], [16, 64, 16]),  # small instrN
+    MmaLayout((3, 0), [8, 4], [1, 1], [1, 1], [0, 1], [16, 64, 16]),  # large number of warps
+]
+
+shared_layouts = [
+    SharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SharedLayout(8, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1], has_leading_offset=True),  # small contiguous bytes
+    SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1], has_leading_offset=True),  # maximum contiguous bytes
+]
+
+
+@pytest.mark.parametrize("M, N", [[128, 128]])
+@pytest.mark.parametrize("mma_layout", filter_layouts(mma_layouts))
+@pytest.mark.parametrize("shared_layout", shared_layouts)
+def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path: pathlib.Path):
+    num_warps = np.prod(mma_layout.warps_per_cta)
+
+    layouts = f"""
+    #dist = {mma_layout}
+    #shared = {shared_layout}
+    #smem = #ttg.shared_memory
+    """
+    ir = layouts + f"""
+    module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
+      tt.func public @kernel(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
+        %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist>
+        %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
+        %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>>
+        %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #dist>
+        %3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #dist>
+        %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<{M}x1xi32, #dist>
+        %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #dist>
+        %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> -> tensor<1x{N}xi32, #dist>
+        %7 = tt.broadcast %6 : tensor<1x{N}xi32, #dist> -> tensor<{M}x{N}xi32, #dist>
+        %8 = tt.broadcast %5 : tensor<{M}x1xi32, #dist> -> tensor<{M}x{N}xi32, #dist>
+        %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #dist>
+        %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #dist>, tensor<{M}x{N}xi32, #dist>
+        %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<f16>, #dist>
+        %12 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #dist>) -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem>
+        %13 = ttg.local_load %12 : !ttg.memdesc<{M}x{N}xf16, #shared, #smem> -> tensor<{M}x{N}xf16, #dist>
+        %14 = tt.addptr %3, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #dist>, tensor<{M}x{N}xi32, #dist>
+        tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<f16>, #dist>
+        tt.return
+      }}
+    }}
+    """
+
+    x = torch.arange(0, M * N, device=device, dtype=torch.float16).reshape(M, N)
+    z = torch.empty_like(x, device=device)
+
+    temp_file = tmp_path / "test_local_load_store_mma.ttgir"
+    temp_file.write_text(ir)
+    kernel = triton.compile(str(temp_file))
+
+    kernel[(1, 1, 1)](x, z)
+    assert torch.equal(z, x)
+
+    if shared_layout.has_leading_offset and mma_layout.version[0] >= 3:
+        assert "stmatrix" in kernel.asm["ptx"]
+
+
 mma_pairs = [
     [
         MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]),
@@ -5546,18 +5615,10 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
 
 @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
 @pytest.mark.parametrize("dtype", ['float16'])
-@pytest.mark.parametrize("mma_pair", mma_pairs)
-def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
-    if is_hip():
-        pytest.skip("test_mma2mma is not supported in HIP")
-
+@pytest.mark.parametrize("mma_pair", filter_layouts(mma_pairs))
+def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
     src_layout, _ = mma_pair
-    if is_cuda():
-        cc = torch.cuda.get_device_capability()
-        if cc[0] < 9 and src_layout.version[0] >= 3:
-            pytest.skip("Skip testing MMAv3 on devices with CC < 9")
-
-    num_warps = np.cumprod(src_layout.warps_per_cta)[-1]
+    num_warps = np.prod(src_layout.warps_per_cta)
 
     def do_test(src_layout, dst_layout):
         layouts = f"""
@@ -5593,7 +5654,7 @@ def do_test(src_layout, dst_layout):
         x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device)
         z = torch.empty_like(x)
 
-        temp_file = tmp_path / "test_convertmma2mma.ttgir"
+        temp_file = tmp_path / "test_convert_mma2mma.ttgir"
         temp_file.write_text(ir)
         kernel = triton.compile(str(temp_file))
 
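To run the new coverage locally, an invocation along the lines of `python -m pytest python/test/unit/language/test_core.py -k "local_load_store_mma or convert_mma2mma"` (hypothetical; adjust the path to your checkout) should select the new and renamed tests. Note that the explicit HIP and compute-capability skips are replaced by `filter_layouts(...)`, which presumably restricts parametrization to layouts applicable to the current target, so MMAv3 cases are filtered out rather than skipped on unsupported GPUs.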