
Commit 4272074

zhanglx13 authored and loislo committed
[AMD][FA] Improve warp distribution for decode attention dot (triton-lang#5892)

This PR improves the warp-distribution logic for FA kernels:
1. Always choose warpsPerCTA = [numWarps, 1] for the 1st dot.
2. For the 2nd dot, distribute warps along dim0 first, then dim1.
This helps reduce register pressure for FA kernels with a large output head size.
1 parent 4c298b9 commit 4272074
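
To make the new rule concrete, here is a minimal standalone sketch in plain C++ (not the pass itself, which operates on MLIR ops); the helper names warpsForChainDot and divideCeil are hypothetical, and the tail-dot formula mirrors the warpsPerTile change in AccelerateAMDMatmul.cpp below:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llvm::divideCeil.
static int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Assumed simplification of the pass's chain-dot handling:
// head dot -> all warps along dim0;
// tail dot -> fill dim0 up to ceil(M / mDim), put the rest along dim1.
// mDim is the MFMA instruction size along dim0 (16 or 32).
static std::vector<unsigned> warpsForChainDot(bool isHeadDot, bool isTailDot,
                                              int64_t M, int64_t mDim,
                                              int numWarps) {
  if (isHeadDot)
    return {static_cast<unsigned>(numWarps), 1u};
  if (isTailDot) {
    unsigned w0 = static_cast<unsigned>(
        std::min<int64_t>(numWarps, divideCeil(M, mDim)));
    return {w0, static_cast<unsigned>(numWarps) / w0};
  }
  return {}; // regular dots keep the existing heuristic
}

int main() {
  // Reproduce the tail-dot expectations from the new lit test (numWarps = 4).
  for (int64_t M : {128, 64, 32, 16})
    for (int64_t mDim : {16, 32}) {
      auto w = warpsForChainDot(/*isHeadDot=*/false, /*isTailDot=*/true, M,
                                mDim, /*numWarps=*/4);
      std::printf("BLOCK_M=%3lld mfma%-2lld -> warpsPerCTA=[%u, %u]\n",
                  static_cast<long long>(M), static_cast<long long>(mDim),
                  w[0], w[1]);
    }
  return 0;
}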

File tree

2 files changed: +209, -50 lines changed

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx940 matrix-instruction-size=16' | FileCheck %s --check-prefixes MFMA16,CHECK
// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul='arch-generation-name=gfx940 matrix-instruction-size=32' | FileCheck %s --check-prefixes MFMA32,CHECK

// Check the warpsPerCTA parameter of #mma layout of the two dot's.
// The 1st dot always has warpsPerCTA = [4, 1].
// The warpsPerCTA for the 2nd dot depends on mfma instruction size and BLOCK_M size.


// BLOCK_M = 128
// warpsPerCTA = [4, 1] for mfma16 and mfma32
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}>
#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}>
// MFMA16{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
// MFMA32{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
// CHECK-LABEL: mfma_chain_dot_BM128
// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<128x16xf32, #mma>
// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<128x128xf32, #mma>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @mfma_chain_dot_BM128(
      %q: tensor<128x128xf16, #dotOp0>,
      %k: tensor<128x16xf16, #dotOp1>,
      %v: tensor<16x128xf16, #dotOp1>,
      %o_ptr: tensor<128x128x!tt.ptr<f32>, #blocked>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #blocked>
    %cst1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %qk = tt.dot %q, %k, %cst : tensor<128x128xf16, #dotOp0> * tensor<128x16xf16, #dotOp1> -> tensor<128x16xf32, #blocked>
    %qk_f16 = arith.truncf %qk : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked>
    %p = ttg.convert_layout %qk_f16 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #dotOp0>
    %o = tt.dot %p, %v, %cst1 : tensor<128x16xf16, #dotOp0> * tensor<16x128xf16, #dotOp1> -> tensor<128x128xf32, #blocked>
    tt.store %o_ptr, %o : tensor<128x128x!tt.ptr<f32>, #blocked>
    tt.return
  }
}


// -----

// BLOCK_M = 64
// warpsPerCTA = [4, 1] for mfma16
// warpsPerCTA = [2, 2] for mfma32
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}>
#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}>
// MFMA16{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
// MFMA32{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
// MFMA32{LITERAL}: #mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
// CHECK-LABEL: mfma_chain_dot_BM64
// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<64x16xf32, #mma>
// MFMA16: tt.dot {{.*}} : {{.*}} -> tensor<64x128xf32, #mma>
// MFMA32: tt.dot {{.*}} : {{.*}} -> tensor<64x128xf32, #mma1>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @mfma_chain_dot_BM64(
      %q: tensor<64x128xf16, #dotOp0>,
      %k: tensor<128x16xf16, #dotOp1>,
      %v: tensor<16x128xf16, #dotOp1>,
      %o_ptr: tensor<64x128x!tt.ptr<f32>, #blocked>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<64x16xf32, #blocked>
    %cst1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked>
    %qk = tt.dot %q, %k, %cst : tensor<64x128xf16, #dotOp0> * tensor<128x16xf16, #dotOp1> -> tensor<64x16xf32, #blocked>
    %qk_f16 = arith.truncf %qk : tensor<64x16xf32, #blocked> to tensor<64x16xf16, #blocked>
    %p = ttg.convert_layout %qk_f16 : tensor<64x16xf16, #blocked> -> tensor<64x16xf16, #dotOp0>
    %o = tt.dot %p, %v, %cst1 : tensor<64x16xf16, #dotOp0> * tensor<16x128xf16, #dotOp1> -> tensor<64x128xf32, #blocked>
    tt.store %o_ptr, %o : tensor<64x128x!tt.ptr<f32>, #blocked>
    tt.return
  }
}


// -----

// BLOCK_M = 32
// warpsPerCTA = [2, 2] for mfma16
// warpsPerCTA = [1, 4] for mfma32
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}>
#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}>
// MFMA16{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
// MFMA32{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
// MFMA16{LITERAL}: #mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}>
// MFMA32{LITERAL}: #mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
// CHECK-LABEL: mfma_chain_dot_BM32
// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<32x16xf32, #mma>
// MFMA16: tt.dot {{.*}} : {{.*}} -> tensor<32x128xf32, #mma1>
// MFMA32: tt.dot {{.*}} : {{.*}} -> tensor<32x128xf32, #mma1>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @mfma_chain_dot_BM32(
      %q: tensor<32x128xf16, #dotOp0>,
      %k: tensor<128x16xf16, #dotOp1>,
      %v: tensor<16x128xf16, #dotOp1>,
      %o_ptr: tensor<32x128x!tt.ptr<f32>, #blocked>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf32, #blocked>
    %cst1 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #blocked>
    %qk = tt.dot %q, %k, %cst : tensor<32x128xf16, #dotOp0> * tensor<128x16xf16, #dotOp1> -> tensor<32x16xf32, #blocked>
    %qk_f16 = arith.truncf %qk : tensor<32x16xf32, #blocked> to tensor<32x16xf16, #blocked>
    %p = ttg.convert_layout %qk_f16 : tensor<32x16xf16, #blocked> -> tensor<32x16xf16, #dotOp0>
    %o = tt.dot %p, %v, %cst1 : tensor<32x16xf16, #dotOp0> * tensor<16x128xf16, #dotOp1> -> tensor<32x128xf32, #blocked>
    tt.store %o_ptr, %o : tensor<32x128x!tt.ptr<f32>, #blocked>
    tt.return
  }
}


// -----

// BLOCK_M = 16, only check mfma16 since it's too small for mfma32
// warpsPerCTA = [1, 4] for mfma16
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}>
#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}>
// MFMA16{LITERAL}: #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
// MFMA16{LITERAL}: #mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 16], isTransposed = true}>
// CHECK-LABEL: mfma_chain_dot_BM16
// CHECK: tt.dot {{.*}} : {{.*}} -> tensor<16x16xf32, #mma>
// MFMA16: tt.dot {{.*}} : {{.*}} -> tensor<16x128xf32, #mma1>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @mfma_chain_dot_BM16(
      %q: tensor<16x128xf16, #dotOp0>,
      %k: tensor<128x16xf16, #dotOp1>,
      %v: tensor<16x128xf16, #dotOp1>,
      %o_ptr: tensor<16x128x!tt.ptr<f32>, #blocked>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked>
    %cst1 = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked>
    %qk = tt.dot %q, %k, %cst : tensor<16x128xf16, #dotOp0> * tensor<128x16xf16, #dotOp1> -> tensor<16x16xf32, #blocked>
    %qk_f16 = arith.truncf %qk : tensor<16x16xf32, #blocked> to tensor<16x16xf16, #blocked>
    %p = ttg.convert_layout %qk_f16 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #dotOp0>
    %o = tt.dot %p, %v, %cst1 : tensor<16x16xf16, #dotOp0> * tensor<16x128xf16, #dotOp1> -> tensor<16x128xf32, #blocked>
    tt.store %o_ptr, %o : tensor<16x128x!tt.ptr<f32>, #blocked>
    tt.return
  }
}
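
As a cross-check of the expectations above: with numWarps = 4, the 2nd dot gets warpsPerCTA[0] = min(numWarps, ceil(BLOCK_M / mDim)) and warpsPerCTA[1] = numWarps / warpsPerCTA[0], where mDim is 16 for mfma16 and 32 for mfma32:
BLOCK_M = 128: ceil(128/16) = 8 and ceil(128/32) = 4, both clamped to 4 -> [4, 1]
BLOCK_M = 64:  ceil(64/16) = 4 -> [4, 1];  ceil(64/32) = 2 -> [2, 2]
BLOCK_M = 32:  ceil(32/16) = 2 -> [2, 2];  ceil(32/32) = 1 -> [1, 4]
BLOCK_M = 16:  ceil(16/16) = 1 -> [1, 4]   (mfma32 is not exercised here)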

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 78 additions & 50 deletions
@@ -43,28 +43,88 @@ int getWmmaVersion(StringRef archGen) {
   return 0;
 }
 
-SmallVector<unsigned, 3>
-warpsPerTile(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
-             std::pair<int64_t, int64_t> shapePerWarp) {
-  auto rank = shape.size();
-  // Early exit for batched matmul
-  if (rank == 3)
-    return {(unsigned)numWarps, 1, 1};
-
-  auto filter = [dotOp](Operation *op) {
+// Check if the result of this tl.dot is used as opA of another tl.dot
+// in the same region
+bool isChainDotHead(tt::DotOpInterface dotOp) {
+  auto isInSameRegion = [&dotOp](Operation *op) {
     return op->getParentRegion() == dotOp->getParentRegion();
   };
   ForwardSliceOptions fwdOpt;
-  fwdOpt.filter = filter;
+  fwdOpt.filter = isInSameRegion;
+  SetVector<mlir::Operation *> fwdSlices;
+  getForwardSlice(dotOp, &fwdSlices, fwdOpt);
+  for (Operation *op : fwdSlices) {
+    if (auto dOp = dyn_cast<tt::DotOpInterface>(op)) {
+      assert(dOp != dotOp);
+      auto opA = dOp.getA().getDefiningOp();
+      if (opA && fwdSlices.contains(opA)) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Check if the opA of this tl.dot is the result of another tl.dot
+// in the same region
+bool isChainDotTail(tt::DotOpInterface dotOp) {
+  auto isInSameRegion = [&dotOp](Operation *op) {
+    return op->getParentRegion() == dotOp->getParentRegion();
+  };
   BackwardSliceOptions bwdOpt;
   bwdOpt.omitBlockArguments = true;
-  bwdOpt.filter = filter;
-  auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
-  for (Operation *op : slices) {
-    if (isa<mlir::triton::DotOpInterface>(op) && (op != dotOp))
-      return {(unsigned)numWarps, 1};
+  bwdOpt.filter = isInSameRegion;
+  SetVector<Operation *> bwdSlices;
+  Operation *opA = dotOp.getA().getDefiningOp();
+  if (!opA)
+    return false;
+  getBackwardSlice(opA, &bwdSlices, bwdOpt);
+  if (llvm::find_if(bwdSlices, [](Operation *op) {
+        return isa<tt::DotOpInterface>(op);
+      }) != bwdSlices.end())
+    return true;
+  return false;
+}
+
+SmallVector<unsigned, 3>
+warpsPerTile(Operation *dotOp, ArrayRef<int64_t> shape, int numWarps,
+             std::pair<int64_t, int64_t> shapePerWarp) {
+  auto rank = shape.size();
+  // Case 1: Early exit for batched matmul
+  if (rank == 3)
+    return {static_cast<unsigned>(numWarps), 1, 1};
+
+  // Case 2: For FA-like pattern, i.e. result of 1st tl.dot is used as the opA
+  // of the 2nd dot, we will set warpsPerCTA differently for 1st and 2nd dot
+  auto ttDotOp = cast<tt::DotOpInterface>(dotOp);
+  bool isHeadDot = isChainDotHead(ttDotOp);
+  bool isTailDot = isChainDotTail(ttDotOp);
+  // For the 1st dot in chain-dot, we always set warpsPerCTA={numWarps, 1}
+  // because this eliminates
+  // 1) inter-warp reduction in the softmax step.
+  // 2) layout conversion from #mma to #dot_op of the second dot.
+  if (isHeadDot)
+    return {static_cast<unsigned>(numWarps), 1};
+  // For the 2nd dot in chain-dot, we always distribute warp along dim0 first,
+  // then dim1. Because
+  // 1) This is how we distribute the warps for the 1st dot. Now the
+  //    warpsPerCTA for the 1st dot become the warp layout of the dotOperand
+  //    layout of the 2nd dot, which must match the warpsPerCTA of the 2nd dot.
+  // 2) When shape[0] is small, as in decode kernels, we don't want to
+  //    distribute more warps than shape[0] // mDim. If we do so, each warp
+  //    needs to hold more elements in the final output, which increases
+  //    register pressure, especially for large head dim (e.g. 512) attention
+  //    kernels.
+  if (isTailDot) {
+    SmallVector<unsigned, 3> ret = {1, 1};
+    ret[0] = static_cast<unsigned>(std::min(
+        static_cast<int64_t>(numWarps),
+        static_cast<int64_t>(llvm::divideCeil(shape[0], shapePerWarp.first))));
+    ret[1] = numWarps / ret[0];
+    return ret;
   }
 
+  // Case 3: Regular cases
   SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]};
   SmallVector<unsigned, 3> ret = {1, 1};
   do {
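
A rough illustration of the register-pressure point in the comment above (back-of-the-envelope numbers, assuming the accumulator is replicated across warps that exceed the available dim0 tiles): for a decode-style 2nd dot with BLOCK_M = 16, head dim 512, numWarps = 4, mfma16 and 64-lane warps, warpsPerCTA = [1, 4] gives each warp a 16x128 slice of the f32 output, i.e. 16*128/64 = 32 accumulator VGPRs per lane. Forcing [4, 1] cannot split dim0 beyond 16/16 = 1 tile, so each warp would cover the full 16x512 output, i.e. 16*512/64 = 128 accumulator VGPRs per lane, roughly four times the pressure.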
@@ -365,39 +425,6 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
       : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
         nonKDim(nonKDim), kPack(kPack) {}
 
-  bool isChainDot(tt::DotOp &dotOp) const {
-    auto filter = [&dotOp](Operation *op) {
-      return op->getParentRegion() == dotOp->getParentRegion();
-    };
-    ForwardSliceOptions fwdOpt;
-    fwdOpt.filter = filter;
-    BackwardSliceOptions bwdOpt;
-    bwdOpt.omitBlockArguments = true;
-    bwdOpt.filter = filter;
-    auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
-    for (Operation *op : slices) {
-      if (isa<tt::DotOp>(op) && (op != dotOp))
-        return true;
-    }
-    return false;
-  }
-
-  bool isSecondDot(tt::DotOp &dotOp) const {
-    auto filter = [&dotOp](Operation *op) {
-      return op->getParentRegion() == dotOp->getParentRegion();
-    };
-    BackwardSliceOptions bwdOpt;
-    bwdOpt.omitBlockArguments = true;
-    bwdOpt.filter = filter;
-    SetVector<Operation *> slices;
-    getBackwardSlice(dotOp.getResult(), &slices, bwdOpt);
-    if (llvm::find_if(slices, [](Operation *op) {
-          return isa<tt::DotOp>(op);
-        }) != slices.end())
-      return true;
-    return false;
-  }
-
   LogicalResult matchAndRewrite(tt::DotOp dotOp,
                                 PatternRewriter &rewriter) const override {
     RankedTensorType oldRetType = dotOp.getType();
@@ -439,7 +466,8 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr->aElementType;
     bool isFP8 = llvm::isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(aElemTy);
-    bool isTransposed = isChainDot(dotOp) || !isFP8;
+    bool isTransposed =
+        isChainDotHead(dotOp) || isChainDotTail(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
         /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
@@ -492,7 +520,7 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     // to increase ds_read vector size
     // However, in FA, the second dot can only use kWidth = kBase since it's
     // limited by the result of the first dot, which is of mfmaLayout.
-    if (!isSecondDot(dotOp))
+    if (!isChainDotTail(dotOp))
       kWidth *= kPack;
 
     auto newAEncoding =
