Commit 65a416a

[AMD] Use 128-bit stores in epilogue for mfma16 on CDNA4 (#6787)
Similar to triton-lang/triton#6688, this commit makes each thread own 8 elements in the epilogue to enable dwordx4 stores for mfma16x16.
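To illustrate the motivation (a conceptual sketch, not code from this commit or from Triton): once a thread owns 8 contiguous f16 elements destined for one address, the epilogue can write them with a single 16-byte access, which can lower to one dwordx4 global store instead of two separate 64-bit stores. The function names below are made up for the sketch.

#include <cstdint>
#include <cstring>

using fp16_bits = std::uint16_t; // stand-in for the raw bits of an f16 value

// Before the swap: a thread owns 4 contiguous f16 from each of two tiles,
// which land at two unrelated addresses -> two 64-bit stores.
void store_two_tiles(fp16_bits *dst_tile0, fp16_bits *dst_tile1,
                     const fp16_bits *vals) {
  std::memcpy(dst_tile0, vals, 8);     // 4 x f16 from tile-0
  std::memcpy(dst_tile1, vals + 4, 8); // 4 x f16 from tile-1
}

// After the swap: 8 contiguous f16 -> one 128-bit (dwordx4-sized) store.
void store_one_vector(fp16_bits *dst, const fp16_bits *vals) {
  std::memcpy(dst, vals, 16);
}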
1 parent 334cd33 commit 65a416a

3 files changed: +182 -13 lines changed


lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 111 additions & 10 deletions
@@ -1511,32 +1511,133 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) {
     return {};
   auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding());
 
-  // We currently only support transposed [B]F16 MFMA32x32 on CDNA4.
+  // We currently only support transposed [B]F16 MFMA32x32 and MFMA16x16 on
+  // CDNA4.
   bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
+  bool isMfma16 = mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16;
+
+  auto valShape = valType.getShape();
+  // For mfma16x16, to use the in-wavefront swap, all tiles involved must be in
+  // one wavefront when there are multiple tiles, which means
+  // warpsPerCTA = [numWarps, 1] and at least two tiles along the N dim. For
+  // now this is only possible for FA-like kernels, since during mfma
+  // generation the warpsPerCTA of the head dot in the chain is reshaped to
+  // [numWarps, 1].
+  // TODO: Support gemm-like kernels, where this transformation cannot be
+  // applied yet.
+  bool validForMfma16 = isMfma16 && valShape.back() >= 16 * 2 &&
+                        mfmaLayout.getWarpsPerCTA().back() == 1;
+
   Type elemType = valType.getElementType();
   if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
         mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
-        isMfma32))
+        (isMfma32 || validForMfma16)))
     return {};
 
-  auto valShape = valType.getShape();
   LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape);
   auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames());
   StringAttr dimM = mfmaOutDims[0];
   StringAttr dimN = mfmaOutDims[1];
-
   auto swapLL = LinearLayout::empty();
   // The rows are kept as is with an identity linear layout.
   swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM);
-  // In transposed mfma32 layout, each thread holds 4 consecutive values along N
-  // dim. We want to exchange column 4-7 (owned by thread 32-63) and column 8-11
-  // (owned by thread 0-31) every 16 columns to make each thread holds 8
-  // elements. This would mean exchange the 2nd and 3rd basis vector from an
-  // identity linear layout.
+  /*
+  clang-format off
+  In the transposed mfma32 layout, each thread holds 4 consecutive values along
+  the N dim. We want to exchange columns 4-7 (owned by threads 32-63, BLK0) and
+  columns 8-11 (owned by threads 0-31, BLK1) every 16 columns so that each
+  thread holds 8 elements. This means exchanging the 2nd and 3rd basis vectors
+  of an identity linear layout on the tensor elements.
+
+  Correspondingly, for the transposed mfma16 layout, the output of transposed
+  mfma16x16 is:
+
+               N/register
+  M/Lane         v0       v1       v2       v3       v4       v5       v6       v7
+              -------------------------------------------------------------------------
+  row0:  0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  row1: 16-31 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  row3: 48-63 | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  which means the columns v0 to v3 come from one output of mfma16x16 and the
+  columns v4 to v7 come from another output of mfma16x16.
+
+  The following graph is the same as the one above, except the tile number is
+  replaced with coordinates in the tensor:
+
+                 N/register
+          -----------------------------------------------
+  M/lane  |(0, 0)   ... (0, 3)   | (0, 16)  ... (0, 19)  |
+          |....                  | sub-tensor-0          |
+          |(15, 0)  ... (15, 3)  | (15, 16) ... (15, 19) |
+          -----------------------------------------------
+          |(0, 4)   ... (0, 7)   | (0, 20)  ... (0, 23)  |
+          |sub-tensor-1          | ....                  |
+          |(15, 4)  ... (15, 7)  | (15, 20) ... (15, 23) |
+          -----------------------------------------------
+          |(0, 8)   ... (0, 11)  | (0, 24)  ... (0, 27)  |
+          |....                  | sub-tensor-2          |
+          |(15, 8)  ... (15, 11) | (15, 24) ... (15, 27) |
+          -----------------------------------------------
+          |(0, 12)  ... (0, 15)  | (0, 28)  ... (0, 31)  |
+          |sub-tensor-3          | ....                  |
+          |(15, 12) ... (15, 15) | (15, 28) ... (15, 31) |
+          -----------------------------------------------
+  The basis vectors for lane and register are:
+  Register = {{0, 1}, {0, 2}}
+  Lane = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}
+  With this layout, only 4xfp16 can be packed into the final global store.
+
+  To use 128-bit global stores, we need to pack 8 elements, which means the
+  layout should look like:
+
+               N/register
+  M/Lane         v0       v1       v2       v3       v4       v5       v6       v7
+              -------------------------------------------------------------------------
+  row0:  0-15 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
+              -------------------------------------------------------------------------
+  row1: 16-31 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+  row2: 32-47 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
+              -------------------------------------------------------------------------
+  row3: 48-63 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
+              -------------------------------------------------------------------------
+
+  The following graph is the same as the one above, except the tile number is
+  replaced with coordinates in the tensor:
+
+                 N/register
+          -----------------------------------------------
+          |(0, 0)   ... (0, 3)   | (0, 4)   ... (0, 7)   |
+          |....                  | sub-tensor-1          |
+          |(15, 0)  ... (15, 3)  | (15, 4)  ... (15, 7)  |
+          -----------------------------------------------
+          |(0, 16)  ... (0, 19)  | (0, 20)  ... (0, 23)  |
+          |sub-tensor-0          | ....                  |
+          |(15, 16) ... (15, 19) | (15, 20) ... (15, 23) |
+          -----------------------------------------------
+          |(0, 8)   ... (0, 11)  | (0, 12)  ... (0, 15)  |
+          |....                  | sub-tensor-3          |
+          |(15, 8)  ... (15, 11) | (15, 12) ... (15, 15) |
+          -----------------------------------------------
+          |(0, 24)  ... (0, 27)  | (0, 28)  ... (0, 31)  |
+          |sub-tensor-2          | ....                  |
+          |(15, 24) ... (15, 27) | (15, 28) ... (15, 31) |
+          -----------------------------------------------
+  which means we need to exchange sub-tensor-0 with sub-tensor-1 and
+  sub-tensor-2 with sub-tensor-3.
+  And the basis vectors for lane and register are:
+  Register = {{0, 1}, {0, 2}, {0, 4}}
+  Lane = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}, {0, 8}}
+
+  The steps to get this layout are: first we check that the last dim of
+  warpsPerCTA is 1, so we can use v_permlane16_swap. Then we exchange the 2nd
+  and 4th elements in the basis vectors of an identity linear layout, which is
+  then composed with the original mfma16 LL.
+  clang-format on
+  */
+  auto destIdxInBases = isMfma32 ? 3 : 4;
   std::vector<std::vector<int32_t>> dimNBases(mfmaLL.getOutDimSizeLog2(dimN));
   std::generate(dimNBases.begin(), dimNBases.end(),
                 [i = 0]() mutable { return std::vector<int32_t>{1 << i++}; });
-  std::swap(dimNBases[2], dimNBases[3]);
+  std::swap(dimNBases[2], dimNBases[destIdxInBases]);
   swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN});
 
   return mfmaLL.compose(swapLL);
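The basis-vector swap above can be sanity-checked outside Triton. The snippet below is a standalone toy, not commit code; it assumes (as Triton's LinearLayout does) that a layout maps an index by XOR-ing the basis vectors of its set bits. It builds identity bases for a 32-column N dimension, swaps basis 2 with basis 4 as in the mfma16 path, and prints which original columns the indices 0-7 now map to.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Apply a 1-D linear layout: XOR together the bases of the set bits of idx
// (assumption: this mirrors how Triton's LinearLayout applies its bases).
static std::int32_t apply(const std::vector<std::int32_t> &bases,
                          std::int32_t idx) {
  std::int32_t out = 0;
  for (std::size_t b = 0; b < bases.size(); ++b)
    if (idx & (1 << b))
      out ^= bases[b];
  return out;
}

int main() {
  constexpr int log2N = 5; // 32 columns along N in this toy example
  std::vector<std::int32_t> bases(log2N);
  for (int i = 0; i < log2N; ++i)
    bases[i] = 1 << i; // identity bases {1, 2, 4, 8, 16}

  const bool isMfma32 = false;          // pretend we are in the mfma16 path
  const int destIdx = isMfma32 ? 3 : 4; // mirrors destIdxInBases above
  std::swap(bases[2], bases[destIdx]);

  for (int c = 0; c < 8; ++c)
    std::printf("%d ", apply(bases, c)); // prints: 0 1 2 3 16 17 18 19
  std::printf("\n");
  return 0;
}

With destIdx = 3 (the mfma32 path) the loop prints 0 1 2 3 8 9 10 11, matching the column 4-7 / 8-11 exchange described in the code comment; with destIdx = 4, columns 4-7 trade places with columns 16-19 (and 12-15 with 24-27), i.e. sub-tensor-0/1 and sub-tensor-2/3 swap as in the diagrams above.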

test/TritonGPU/amd/amd-optimize-epilogue.mlir

Lines changed: 63 additions & 0 deletions
@@ -86,3 +86,66 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
     tt.return
   }
 }
+
+// -----
+// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 32], [0, 64], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 16], [0, 8]], warp = [[16, 0], [32, 0]], block = []}>
+// CHECK-LABEL: store_dword_16x16
+// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
+// CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128x!tt.ptr<f16>, #mma> -> tensor<128x128x!tt.ptr<f16>, #linear>
+// CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #linear>
+// CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<128x128x!tt.ptr<f16>, #linear>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @store_dword_16x16(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %cst_1 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %0 = tt.dot %cst_0, %cst_1, %cst : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
+    %1 = ttg.convert_layout %0 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
+    %2 = arith.truncf %1 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>
+    %3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked>
+    tt.store %3, %2 : tensor<128x128x!tt.ptr<f16>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+// Verify that no linear layout is created when warpsPerCTA is not as expected.
+// CHECK-LABEL: store_dword_16x16
+// CHECK-NOT: #linear
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @store_dword_16x16(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %cst_1 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %0 = tt.dot %cst_0, %cst_1, %cst : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
+    %1 = ttg.convert_layout %0 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
+    %2 = arith.truncf %1 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>
+    %3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked>
+    tt.store %3, %2 : tensor<128x128x!tt.ptr<f16>, #blocked>
+    tt.return
+  }
+}
+
+// -----
+// Verify that no linear layout is created when N of the input shape is not at
+// least 16*2.
+// CHECK-LABEL: store_dword_16x16
+// CHECK-NOT: #linear
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @store_dword_16x16(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %cst_1 = arith.constant dense<1.230000e+02> : tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %0 = tt.dot %cst_0, %cst_1, %cst : tensor<16x16xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma>
+    %1 = ttg.convert_layout %0 : tensor<16x16xf32, #mma> -> tensor<16x16xf32, #blocked>
+    %2 = arith.truncf %1 : tensor<16x16xf32, #blocked> to tensor<16x16xf16, #blocked>
+    %3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<16x16x!tt.ptr<f16>, #blocked>
+    tt.store %3, %2 : tensor<16x16x!tt.ptr<f16>, #blocked>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 8 additions & 3 deletions
@@ -223,7 +223,10 @@ struct ConvertLayoutOpMFMAToLinearConversion
       return failure();
 
     auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcType.getEncoding());
-    assert(mfmaLayout.getMDim() == 32 && "Expected MFMA size 32");
+    auto mDim = mfmaLayout.getMDim();
+    auto nDim = mfmaLayout.getNDim();
+    assert((mDim == 32 || mDim == 16) && mDim == nDim &&
+           "Expected MFMA size 32 or 16");
     assert(triton::gpu::lookupThreadsPerWarp(rewriter) == 64 &&
            "Expected warp size 64 for MFMA");
 
@@ -233,6 +236,8 @@ struct ConvertLayoutOpMFMAToLinearConversion
     SmallVector<Value> outVals;
     auto idx0 = b.i32_val(0);
     auto idx1 = b.i32_val(1);
+    auto intrinsicName = mDim == 32 ? "llvm.amdgcn.permlane32.swap"
+                                    : "llvm.amdgcn.permlane16.swap";
     // Convert MFMA layout to a MFMA-like linear layout where each thread
     // holds 8 consecutive elements
     for (size_t idx = 0; idx < inVals.size(); idx += 8) {
@@ -252,7 +257,7 @@ struct ConvertLayoutOpMFMAToLinearConversion
       Value falseVal = b.false_val();
       Value perm =
          LLVM::createLLVMIntrinsicCallOp(
-              rewriter, loc, "llvm.amdgcn.permlane32.swap", retType,
+              rewriter, loc, intrinsicName, retType,
              ValueRange{b.bitcast(inVecs[0], i32_ty),
                         b.bitcast(inVecs[2], i32_ty), falseVal, falseVal})
              ->getResult(0);
@@ -261,7 +266,7 @@ struct ConvertLayoutOpMFMAToLinearConversion
 
      // Swap the row 2 and 3 of vec1 and the row 0 and 1 of vec3
      perm = LLVM::createLLVMIntrinsicCallOp(
-                 rewriter, loc, "llvm.amdgcn.permlane32.swap", retType,
+                 rewriter, loc, intrinsicName, retType,
                 ValueRange{b.bitcast(inVecs[1], i32_ty),
                            b.bitcast(inVecs[3], i32_ty), falseVal, falseVal})
                 ->getResult(0);
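For intuition on why the intrinsic name depends on mDim, here is a CPU-side toy model of the lane exchanges. It is an assumption-based sketch, not the backend code: it assumes permlane32.swap exchanges the upper 32 lanes of its first operand with the lower 32 lanes of its second, and permlane16.swap exchanges the odd 16-lane rows of its first operand with the even 16-lane rows of its second.

#include <array>
#include <cstdint>
#include <utility>

constexpr int kWaveSize = 64;
using Wave = std::array<std::uint32_t, kWaveSize>; // one 32-bit value per lane

// Assumed model of llvm.amdgcn.permlane32.swap:
// lanes 32-63 of `a` trade values with lanes 0-31 of `b`.
void permlane32SwapModel(Wave &a, Wave &b) {
  for (int lane = 0; lane < 32; ++lane)
    std::swap(a[lane + 32], b[lane]);
}

// Assumed model of llvm.amdgcn.permlane16.swap:
// row 1 of `a` trades with row 0 of `b`, row 3 of `a` with row 2 of `b`,
// where a "row" is a group of 16 lanes (4 rows per 64-lane wavefront).
void permlane16SwapModel(Wave &a, Wave &b) {
  for (int row = 0; row < 4; row += 2)
    for (int lane = 0; lane < 16; ++lane)
      std::swap(a[(row + 1) * 16 + lane], b[row * 16 + lane]);
}

This is only meant to show the granularity difference (32-lane halves for mfma32 vs 16-lane rows for mfma16); the actual data movement is done by the intrinsics emitted above.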
