
Commit e6b9efd

antiagainst and yiqian1 authored
[AMD] Use composition to swap columns for mfma like store layout (#6844)
This commit improves how we create the mfma-like layout for optimizing global stores by using linear layout composition. Along the way it fixes a few implementation issues.

Co-authored-by: Yi Qian <[email protected]>
1 parent 49a72f5 commit e6b9efd
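For intuition on the approach: a linear layout maps the bits of a hardware index (register, lane, warp) to tensor coordinates by XOR-ing one basis vector per set bit, so composing with a layout whose bases are permuted permutes coordinates. The standalone C++ sketch below (illustrative only, not Triton's LinearLayout API) shows why swapping the 2nd and 3rd bases of an identity layout exchanges columns 4-7 with columns 8-11 in every 16-column block, which is the swap this commit composes onto the MFMA layout:

// A minimal standalone sketch (not Triton's actual LinearLayout class) of
// how a 1-D linear layout acts: an index maps to the XOR of the bases
// selected by its set bits, so swapping two bases permutes coordinates.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

static int32_t applyLayout(const std::vector<int32_t> &bases, int32_t idx) {
  int32_t out = 0;
  for (size_t bit = 0; bit < bases.size(); ++bit)
    if (idx & (1 << bit))
      out ^= bases[bit];
  return out;
}

int main() {
  constexpr int nLog2 = 4; // 16 columns are enough to show the pattern
  std::vector<int32_t> bases(nLog2);
  for (int i = 0; i < nLog2; ++i)
    bases[i] = 1 << i;           // identity layout: column i -> i
  std::swap(bases[2], bases[3]); // the swap composed in this commit

  for (int col = 0; col < 16; ++col)
    std::printf("%2d -> %2d\n", col, applyLayout(bases, col));
  // Columns 4-7 map to 8-11 and vice versa; 0-3 and 12-15 stay fixed.
  return 0;
}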

File tree

4 files changed (+65, -65)

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 22 additions & 27 deletions
@@ -1535,40 +1535,35 @@ std::optional<LinearLayout>
 chooseMfmaLikeStoreLayout(RankedTensorType valType) {
   auto mfmaLayout = cast<AMDMfmaEncodingAttr>(valType.getEncoding());
-  // Currently support transposed [B]F16 MFMA32x32 on CDNA4
+  // We currently only support transposed [B]F16 MFMA32x32 on CDNA4.
   bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
   Type elemType = valType.getElementType();
   if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) &&
         mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
         isMfma32))
     return {};

-  MLIRContext *ctx = mfmaLayout.getContext();
-  StringAttr kRegister = S("register");
-  StringAttr kLane = S("lane");
-  StringAttr kWarp = S("warp");
-  StringAttr kBlock = S("block");
-
-  SmallVector<unsigned> order = getDefaultMmaOrder(mfmaLayout);
-  auto standardOutDims = standardOutDimNames(ctx, 2);
-  // We make each thread handle 8 consecutive elements to enable 128-bit
-  // global stores for [b]f16 types and keep the thread pattern in each lane
-  // similar to the canonical mfmaLayout.
-  LinearLayout mfma8Layout = LinearLayout::empty();
-  mfma8Layout =
-      LinearLayout({{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
-                    {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
-                    {kWarp, {}},
-                    {kBlock, {}}},
-                   {standardOutDims[order[0]], standardOutDims[order[1]]});
-
-  LinearLayout warpLayout =
-      identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order);
-  LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) *
-                           warpLayout.transposeOuts(standardOutDims);
-  mfma8Layout = combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(),
-                                       valType.getShape());
-  return mfma8Layout;
+  auto valShape = valType.getShape();
+  LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape);
+  auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames());
+  StringAttr dimM = mfmaOutDims[0];
+  StringAttr dimN = mfmaOutDims[1];
+
+  auto swapLL = LinearLayout::empty();
+  // The rows are kept as-is with an identity linear layout.
+  swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM);
+  // In the transposed mfma32 layout, each thread holds 4 consecutive values
+  // along the N dim. We want to exchange columns 4-7 (owned by threads 32-63)
+  // and columns 8-11 (owned by threads 0-31) every 16 columns so that each
+  // thread holds 8 elements. This amounts to exchanging the 2nd and 3rd basis
+  // vectors of an identity linear layout.
+  std::vector<std::vector<int32_t>> dimNBases(mfmaLL.getOutDimSizeLog2(dimN));
+  std::generate(dimNBases.begin(), dimNBases.end(),
+                [i = 0]() mutable { return std::vector<int32_t>{1 << i++}; });
+  std::swap(dimNBases[2], dimNBases[3]);
+  swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN});
+
+  return mfmaLL.compose(swapLL);
 }

 LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
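Because both maps are linear over GF(2), composing with swapLL amounts to rewriting the N component of every basis vector of the MFMA layout through the swap. A small hypothetical helper (mine, not part of the patch) makes the bit-level action explicit:

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Exchanging the 2nd and 3rd identity bases is the same as swapping bits 2
// and 3 of the N coordinate.
static int32_t swapBits23(int32_t n) {
  int32_t b2 = (n >> 2) & 1;
  int32_t b3 = (n >> 3) & 1;
  return (n & ~0b1100) | (b3 << 2) | (b2 << 3);
}

int main() {
  // An N basis of 4 becomes 8 and vice versa; 1, 2, and 16 are untouched.
  for (int32_t basis : {1, 2, 4, 8, 16})
    std::printf("%2d -> %2d\n", basis, swapBits23(basis));
  return 0;
}

Assuming the canonical transposed mfma32 bases, this trades the lane basis [0, 4] for [0, 8] and the register basis [0, 8] for [0, 4]; the resulting register bases [0, 1], [0, 2], [0, 4] give each thread 8 contiguous [b]f16 values, exactly one 128-bit store. The #linear attributes checked in the test below match this.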

test/TritonGPU/amd/amd-optimize-epilogue.mlir

Lines changed: 25 additions & 2 deletions
@@ -43,15 +43,15 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}

 // -----
 // CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[32, 0], [64, 0]], block = []}>
-// CHECK-LABEL: store_dword
+// CHECK-LABEL: store_dword_128x128
 // CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
 // CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128x!tt.ptr<f16>, #mma> -> tensor<128x128x!tt.ptr<f16>, #linear>
 // CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #linear>
 // CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<128x128x!tt.ptr<f16>, #linear>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}>
 #mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @store_dword(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+  tt.func public @store_dword_128x128(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
     %cst_0 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
     %cst_1 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
@@ -63,3 +63,26 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32}
     tt.return
   }
 }
+
+// -----
+// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 128], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 32], [0, 64], [32, 0]], block = []}>
+// CHECK-LABEL: store_dword_256x256
+// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+// CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256x!tt.ptr<f16>, #mma> -> tensor<256x256x!tt.ptr<f16>, #linear>
+// CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256xf16, #mma> -> tensor<256x256xf16, #linear>
+// CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<256x256x!tt.ptr<f16>, #linear>
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @store_dword_256x256(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %cst_1 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %0 = tt.dot %cst_0, %cst_1, %cst : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+    %1 = ttg.convert_layout %0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+    %2 = arith.truncf %1 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+    %3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x256x!tt.ptr<f16>, #blocked>
+    tt.store %3, %2 : tensor<256x256x!tt.ptr<f16>, #blocked>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 3 additions & 14 deletions
@@ -189,21 +189,10 @@ static bool matchMFMAAndLinearLayoutCase(RankedTensorType srcTy,
   if (!mfmaLayout || !linearLayout)
     return false;

-  std::optional<LinearLayout> srcLL =
+  std::optional<LinearLayout> storeLL =
       mlir::triton::gpu::chooseMfmaLikeStoreLayout(srcTy);
-  if (!srcLL)
-    return false;
-
-  MLIRContext *ctx = linearLayout.getContext();
-  StringAttr kLane = StringAttr::get(ctx, "lane");
-  StringAttr kRegister = StringAttr::get(ctx, "register");
-  auto srcBase = srcLL.value().getBases();
-  auto srcReg = srcBase.lookup(kRegister);
-  auto srcLane = srcBase.lookup(kLane);
-  auto dstBases = linearLayout.getLinearLayout().getBases();
-  auto dstReg = dstBases.lookup(kRegister);
-  auto dstLane = dstBases.lookup(kLane);
-  return dstReg == srcReg && dstLane == srcLane;
+  return linearLayout.getLinearLayout() ==
+         storeLL.value_or(LinearLayout::empty());
 };

 struct ConvertLayoutOpMFMAToLinearConversion
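Two things make the one-line comparison above work: LinearLayout supports full equality (so warp and block bases are now compared too, not just register and lane), and the empty layout acts as a never-matching sentinel when chooseMfmaLikeStoreLayout bails out. The same value_or idiom in miniature, with std::string standing in for LinearLayout (a sketch, not code from the patch):

#include <cassert>
#include <optional>
#include <string>

int main() {
  std::string actual = "layout";
  std::optional<std::string> candidate; // e.g. a chooser that bailed out
  // No candidate: value_or yields the sentinel, so the match fails...
  assert(actual != candidate.value_or(std::string{}));
  candidate = "layout";
  // ...with a candidate, it is a plain equality check.
  assert(actual == candidate.value_or(std::string{}));
  return 0;
}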

third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp

Lines changed: 15 additions & 22 deletions
@@ -59,27 +59,23 @@ bool isOneOperandElementwiseOp(Operation *op) {
   return false;
 }

-static triton::StoreOp convertMfmaLayoutForCDNA4(PatternRewriter &rewriter,
-                                                 Value ptr, Value val,
-                                                 Value mask,
-                                                 triton::StoreOp oldStOp) {
+// Tries to optimize oldStoreOp with v_permlane*_swap instructions when possible.
+// Returns a null store op if not suitable.
+static triton::StoreOp
+usePermlaneSwapToOptimizeStore(PatternRewriter &rewriter, Value ptr, Value val,
+                               Value mask, triton::StoreOp oldStoreOp) {
   auto ptrType = cast<RankedTensorType>(ptr.getType());
   auto valType = cast<RankedTensorType>(val.getType());

-  auto mfmaLayout =
-      cast<triton::gpu::AMDMfmaEncodingAttr>(valType.getEncoding());
-
   // Create a new layout where each thread holds 8 consecutive elements, in
   // order to enable wide 128-bit global stores.
-  std::optional<triton::LinearLayout> mfma8Layout =
+  std::optional<triton::LinearLayout> storeLL =
       triton::gpu::chooseMfmaLikeStoreLayout(valType);
+  if (!storeLL)
+    return nullptr;

-  if (!mfma8Layout)
-    return rewriter.create<triton::StoreOp>(oldStOp.getLoc(), ptr, val, mask,
-                                            oldStOp.getCache(),
-                                            oldStOp.getEvict());
   Attribute newEncoding = triton::gpu::LinearEncodingAttr::get(
-      mfmaLayout.getContext(), mfma8Layout.value());
+      oldStoreOp.getContext(), storeLL.value());
   auto newPtrType = RankedTensorType::get(
       ptrType.getShape(), ptrType.getElementType(), newEncoding);
   Value newPtr = rewriter.create<triton::gpu::ConvertLayoutOp>(ptr.getLoc(),
@@ -99,9 +95,9 @@ static triton::StoreOp convertMfmaLayoutForCDNA4(PatternRewriter &rewriter,
                                                             newMaskType, mask);
   }

-  return rewriter.create<triton::StoreOp>(oldStOp.getLoc(), newPtr, newVal,
-                                          newMask, oldStOp.getCache(),
-                                          oldStOp.getEvict());
+  return rewriter.create<triton::StoreOp>(oldStoreOp.getLoc(), newPtr, newVal,
+                                          newMask, oldStoreOp.getCache(),
+                                          oldStoreOp.getEvict());
 }

 // convert(val) : xmma -> blocked
@@ -195,12 +191,9 @@ class BypassEpilogueSMEM : public mlir::OpRewritePattern<triton::StoreOp> {
       newMask = rewriter.create<triton::gpu::ConvertLayoutOp>(
          mask.getLoc(), newMaskType, mask);
    }
-    triton::StoreOp newStoreOp;
-    if (auto mfmaLayout =
-            dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(newEncoding)) {
-      newStoreOp =
-          convertMfmaLayoutForCDNA4(rewriter, newPtr, newVal, newMask, stOp);
-    } else {
+    triton::StoreOp newStoreOp =
+        usePermlaneSwapToOptimizeStore(rewriter, newPtr, newVal, newMask, stOp);
+    if (!newStoreOp) {
      newStoreOp = rewriter.create<triton::StoreOp>(
          stOp.getLoc(), newPtr, newVal, newMask, stOp.getCache(),
          stOp.getEvict());
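A trivial sanity check (mine, not from the patch) of the width arithmetic behind the comment in usePermlaneSwapToOptimizeStore: eight consecutive 16-bit elements per thread fill exactly one 128-bit (dwordx4) global store.

// Eight 16-bit elements per thread make one 128-bit global store.
static_assert(8 * 16 == 128, "eight [b]f16 values fill one 128-bit store");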
