Commit e32c3b1
[AMD] Optimize to use 128-bit stores in epilogue for CDNA4 (triton-lang#6688)
Convert the mfmaLayout to a linear layout where each thread holds 8 consecutive elements to enable dwordx4 stores in the epilogue.
1 parent ca1ce1b commit e32c3b1
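For context: an f16/bf16 element is 16 bits, so a thread holding 8 consecutive elements can issue a single 128-bit store, i.e. a global_store_dwordx4 (4 × 32-bit dwords). A trivial standalone check of that arithmetic (not part of the commit):

```cpp
// Width arithmetic behind the commit message; compiles as-is.
static_assert(8 * 16 == 128, "8 x [b]f16 elements fill one 128-bit store");
static_assert(128 / 32 == 4, "128 bits == 4 dwords, hence dwordx4");
```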

4 files changed: +131 -5
include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 8 additions & 1 deletion
```diff
@@ -282,6 +282,13 @@ LinearLayout chooseScaledMfmaScaleLayout(
     MLIRContext *ctx, int dotOperandIdx,
     const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
     ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim);
-} // namespace mlir::triton::gpu
 
+// Create a LinearLayout similar to mfmaLayout, but changing each thread to hold
+// 8 elements. This layout is useful for emitting the widest 128-bit global
+// store instructions. Since it closely resembles mfmaLayout, conversion between
+// the two can be done using transferWithinWarp, without involving LDS.
+LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout,
+                                       ArrayRef<int64_t> shape);
+
+} // namespace mlir::triton::gpu
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
```
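A sketch of how this API is meant to be consumed, distilled from the OptimizeEpilogue.cpp change further below (the helper name here is hypothetical; every other symbol appears in this patch): wrap the returned layout in a LinearEncodingAttr and retype the store's tensors with it.

```cpp
// Sketch only, not code from the commit: retype a value with the wide-store
// layout. Mirrors convertMfmaLayoutForCDNA4 in OptimizeEpilogue.cpp below.
static Value retypeForWideStore(PatternRewriter &rewriter, Value val,
                                triton::gpu::AMDMfmaEncodingAttr mfmaLayout) {
  auto valType = cast<RankedTensorType>(val.getType());
  triton::LinearLayout mfma8Layout =
      chooseMfmaLikeStoreLayout(mfmaLayout, valType.getShape());
  Attribute newEncoding = triton::gpu::LinearEncodingAttr::get(
      mfmaLayout.getContext(), mfma8Layout);
  auto newType = RankedTensorType::get(valType.getShape(),
                                       valType.getElementType(), newEncoding);
  return rewriter.create<triton::gpu::ConvertLayoutOp>(val.getLoc(), newType,
                                                       val);
}
```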

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 33 additions & 0 deletions
```diff
@@ -1542,6 +1542,39 @@ LinearLayout chooseScaledMfmaScaleLayout(
   return newLL;
 }
 
+LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout,
+                                       ArrayRef<int64_t> shape) {
+  assert(shape.size() == 2 && mfmaLayout.getMDim() == 32 &&
+         mfmaLayout.getNDim() == 32 && mfmaLayout.getIsTransposed());
+
+  MLIRContext *ctx = mfmaLayout.getContext();
+  StringAttr kRegister = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+  StringAttr kBlock = S("block");
+
+  SmallVector<unsigned> order = getDefaultMmaOrder(mfmaLayout);
+  auto standardOutDims = standardOutDimNames(ctx, 2);
+  // We make each thread handle 8 consecutive elements to enable 128-bit
+  // global stores for [b]f16 types and keep the thread pattern in each lane
+  // similar to the canonical mfmaLayout.
+  LinearLayout mfma8Layout = LinearLayout::empty();
+  mfma8Layout =
+      LinearLayout({{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
+                    {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
+                    {kWarp, {}},
+                    {kBlock, {}}},
+                   {standardOutDims[order[0]], standardOutDims[order[1]]});
+
+  LinearLayout warpLayout =
+      identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order);
+  LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) *
+                           warpLayout.transposeOuts(standardOutDims);
+  mfma8Layout =
+      combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+  return mfma8Layout;
+}
+
 LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
                                            int numWarps) {
   assert(numWarps == 4 || numWarps == 8);
```
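To make the basis vectors concrete: a LinearLayout is linear over GF(2), so an input index maps to the XOR of the bases of its set bits. The standalone sketch below (plain C++, no Triton dependency; bases copied from the #linear attribute CHECK'd in the lit test that follows) evaluates the intra-warp part of the resulting layout.

```cpp
// Evaluates the intra-warp portion of the wide-store layout: first three
// register bits plus all six lane bits; warp/CTA tiling omitted. Each basis
// is a (dim0, dim1) vector; outputs combine by XOR, per LinearLayout.
#include <cstdio>

static const int kRegBases[3][2]  = {{0, 1}, {0, 2}, {0, 4}};
static const int kLaneBases[6][2] = {{1, 0}, {2, 0}, {4, 0},
                                     {8, 0}, {16, 0}, {0, 8}};

static void apply(const int bases[][2], int nbits, int idx, int out[2]) {
  for (int b = 0; b < nbits; ++b)
    if (idx & (1 << b)) {
      out[0] ^= bases[b][0];
      out[1] ^= bases[b][1];
    }
}

int main() {
  for (int lane : {0, 1, 32}) {
    std::printf("lane %2d:", lane);
    for (int reg = 0; reg < 8; ++reg) {
      int out[2] = {0, 0};
      apply(kRegBases, 3, reg, out);
      apply(kLaneBases, 6, lane, out);
      std::printf(" (%d,%d)", out[0], out[1]);
    }
    std::printf("\n");
  }
  // Prints: lane 0 -> (0,0)..(0,7), lane 1 -> (1,0)..(1,7),
  // lane 32 -> (0,8)..(0,15).
  return 0;
}
```

Each thread's 8 × 2-byte elements are contiguous along the minor dimension, so the epilogue can emit one 16-byte (dwordx4) store per lane; lanes 0-31 cover 32 rows at columns 0-7 and lanes 32-63 cover the next 8 columns.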

test/TritonGPU/amd/amd-optimize-epilogue.mlir

Lines changed: 23 additions & 0 deletions
```diff
@@ -40,3 +40,26 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
     tt.return
   }
 }
+
+// -----
+// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[32, 0], [64, 0]], block = []}>
+// CHECK-LABEL: store_dword
+// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
+// CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128x!tt.ptr<f16>, #mma> -> tensor<128x128x!tt.ptr<f16>, #linear>
+// CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #linear>
+// CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<128x128x!tt.ptr<f16>, #linear>
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @store_dword(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %cst_1 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %0 = tt.dot %cst_0, %cst_1, %cst : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
+    %1 = ttg.convert_layout %0 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
+    %2 = arith.truncf %1 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>
+    %3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked>
+    tt.store %3, %2 : tensor<128x128x!tt.ptr<f16>, #blocked>
+    tt.return
+  }
+}
```

third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp

Lines changed: 67 additions & 4 deletions
```diff
@@ -59,6 +59,59 @@ bool isOneOperandElementwiseOp(Operation *op) {
   return false;
 }
 
+static triton::StoreOp convertMfmaLayoutForCDNA4(PatternRewriter &rewriter,
+                                                 Value ptr, Value val,
+                                                 Value mask,
+                                                 triton::StoreOp oldStOp) {
+  auto ptrType = cast<RankedTensorType>(ptr.getType());
+  auto valType = cast<RankedTensorType>(val.getType());
+
+  auto mfmaLayout =
+      cast<triton::gpu::AMDMfmaEncodingAttr>(valType.getEncoding());
+
+  bool mfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
+
+  if (valType.getRank() != 2 ||
+      (!valType.getElementType().isF16() &&
+       !valType.getElementType().isBF16()) ||
+      mfmaLayout.getVersionMajor() != 4 || !mfmaLayout.getIsTransposed() ||
+      !mfma32) {
+    return rewriter.create<triton::StoreOp>(oldStOp.getLoc(), ptr, val, mask,
+                                            oldStOp.getCache(),
+                                            oldStOp.getEvict());
+  }
+
+  // Create a new layout where each thread holds 8 consecutive elements, in
+  // order to enable wide 128-bit global stores.
+  triton::LinearLayout mfma8Layout =
+      chooseMfmaLikeStoreLayout(mfmaLayout, valType.getShape());
+
+  Attribute newEncoding = triton::gpu::LinearEncodingAttr::get(
+      mfmaLayout.getContext(), mfma8Layout);
+  auto newPtrType = RankedTensorType::get(
+      ptrType.getShape(), ptrType.getElementType(), newEncoding);
+  Value newPtr = rewriter.create<triton::gpu::ConvertLayoutOp>(ptr.getLoc(),
+                                                               newPtrType, ptr);
+
+  auto newValType = RankedTensorType::get(
+      valType.getShape(), valType.getElementType(), newEncoding);
+  Value newVal = rewriter.create<triton::gpu::ConvertLayoutOp>(val.getLoc(),
+                                                               newValType, val);
+
+  Value newMask = mask;
+  if (mask) {
+    auto maskType = dyn_cast<RankedTensorType>(mask.getType());
+    auto newMaskType = RankedTensorType::get(
+        maskType.getShape(), maskType.getElementType(), newEncoding);
+    newMask = rewriter.create<triton::gpu::ConvertLayoutOp>(mask.getLoc(),
+                                                            newMaskType, mask);
+  }
+
+  return rewriter.create<triton::StoreOp>(oldStOp.getLoc(), newPtr, newVal,
+                                          newMask, oldStOp.getCache(),
+                                          oldStOp.getEvict());
+}
+
 // convert(val) : xmma -> blocked
 // elementWiseOp(val) : blocked
 // ...
```
```diff
@@ -126,19 +179,20 @@ class BypassEpilogueSMEM : public mlir::RewritePattern {
     auto newEncoding =
         cast<RankedTensorType>(cvtOp.getSrc().getType()).getEncoding();
 
-    auto newVal = cvtOp.getSrc();
-
     auto newPtrType = RankedTensorType::get(
         ptrType.getShape(), ptrType.getElementType(), newEncoding);
     Value newPtr = rewriter.create<triton::gpu::ConvertLayoutOp>(
         ptr.getLoc(), newPtrType, ptr);
 
+    auto newVal = cvtOp.getSrc();
+
     for (auto chainedOp : llvm::reverse(chainedOps)) {
       auto oldType =
           cast<mlir::RankedTensorType>(chainedOp->getResult(0).getType());
       chainedOp->setOperand(0, newVal);
       newVal = llvm::cast<mlir::TypedValue<RankedTensorType>>(
           chainedOp->getResult(0));
+
       auto newType = mlir::RankedTensorType::get(
           oldType.getShape(), oldType.getElementType(), newEncoding);
       newVal.setType(newType);
```
```diff
@@ -152,9 +206,18 @@ class BypassEpilogueSMEM : public mlir::RewritePattern {
       newMask = rewriter.create<triton::gpu::ConvertLayoutOp>(
           mask.getLoc(), newMaskType, mask);
     }
+    triton::StoreOp newStoreOp;
+    if (auto mfmaLayout =
+            dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(newEncoding)) {
+      newStoreOp =
+          convertMfmaLayoutForCDNA4(rewriter, newPtr, newVal, newMask, stOp);
+    } else {
+      newStoreOp = rewriter.create<triton::StoreOp>(
+          stOp.getLoc(), newPtr, newVal, newMask, stOp.getCache(),
+          stOp.getEvict());
+    }
 
-    rewriter.replaceOpWithNewOp<triton::StoreOp>(
-        stOp, newPtr, newVal, newMask, stOp.getCache(), stOp.getEvict());
+    rewriter.replaceOp(stOp, newStoreOp);
     return mlir::success();
   }
 };
```
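The gating condition in convertMfmaLayoutForCDNA4 reads more clearly when pulled out as a predicate. This is a hypothetical refactoring sketch with the same conditions as the patch, not code from the commit:

```cpp
// Hypothetical helper: the wide-store layout is only applied to 2-D f16/bf16
// results of transposed 32x32 mfma on CDNA4 (versionMajor == 4); everything
// else falls back to a plain store with the original layout.
static bool canUseWideStoreLayout(RankedTensorType valType,
                                  triton::gpu::AMDMfmaEncodingAttr mfmaLayout) {
  bool mfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
  bool fp16Like = valType.getElementType().isF16() ||
                  valType.getElementType().isBF16();
  return valType.getRank() == 2 && fp16Like &&
         mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() &&
         mfma32;
}
```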
