Commit 2903596

Group Query Attention (GQA) optimization (#1984)
Add GQA optimization: move num_heads_q to seq_len in some cases to reduce mfma/wmma padding
Parent: e174a10

33 files changed: +535 −180 lines
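In decode-phase GQA, Q has gemmG = batch * num_heads_q groups but seq_len_q = 1, so each mfma/wmma tile is almost entirely padding. The commit instead folds numRepeats = num_heads_q / num_heads_kv query heads into the seq_len_q dimension. Below is a minimal standalone sketch of that index remapping; it is not part of the commit, and the concrete sizes (num_heads_q=64, num_heads_kv=8, seq_len_q=1, tile height 32) and the assumption that a merge treats its first named dimension as slowest-varying are taken from the comments in GemmToGridwise.cpp further down.

// Standalone sketch of the remapping moveNumHeadsToSeqLenQ expresses via
// rock::TransformOp. Sizes are illustrative assumptions, not fixed by the code.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct Coord3 { int64_t g, d, s; };

// (gemmG, headDim, seqLen) -> (gemmG / numRepeats, headDim, seqLen * numRepeats)
static Coord3 moveHeadsToSeqLen(Coord3 c, int64_t numRepeats) {
  // unmerge: gemmG = gNew * numRepeats + r, then merge r into seqLen with
  // seqLen as the outer (slower-varying) dimension
  int64_t r = c.g % numRepeats;
  return {c.g / numRepeats, c.d, c.s * numRepeats + r};
}

int main() {
  const int64_t numHeadsQ = 64, numHeadsKV = 8, seqLenQ = 1, tileM = 32;
  const int64_t numRepeats = numHeadsQ / numHeadsKV; // 8

  // Decode phase: each of the 64 Q heads contributes one row; folding the 8
  // query heads that share a KV head turns seq_len_q from 1 into 8 per slice.
  Coord3 q{/*g=*/13, /*d=*/0, /*s=*/0};
  Coord3 t = moveHeadsToSeqLen(q, numRepeats);
  assert(t.g == 1 && t.s == 5); // 13 = 1*8 + 5

  // Padding of the M (seq_len_q) dimension per tile, before vs. after.
  int64_t padBefore = tileM - seqLenQ;             // 31 of 32 rows wasted
  int64_t padAfter = tileM - seqLenQ * numRepeats; // 24 of 32 rows wasted
  std::printf("padded rows per tile: %lld -> %lld\n", (long long)padBefore,
              (long long)padAfter);
  return 0;
}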

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 5 additions & 4 deletions
@@ -212,10 +212,10 @@ def Rock_AttentionOp
     Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
     Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
     TensorOrMemRefOf<[F32, F16, BF16]>:$out,
-    Optional<TensorOrMemRefOf<[F32, F16, BF16]>>:$lse,
-    UnitAttr:$qTransposed, UnitAttr:$kTransposed, UnitAttr:$vTransposed,
-    UnitAttr:$oTransposed, UnitAttr:$causal, I32Attr:$splitKV,
-    OptionalAttr<Rock_GemmFeaturesAttr>:$features,
+    Optional<TensorOrMemRefOf<[F32, F16, BF16]>>:$lse, I32Attr:$numHeadsQ,
+    I32Attr:$numHeadsKV, UnitAttr:$qTransposed, UnitAttr:$kTransposed,
+    UnitAttr:$vTransposed, UnitAttr:$oTransposed, UnitAttr:$causal,
+    I32Attr:$splitKV, OptionalAttr<Rock_GemmFeaturesAttr>:$features,
     StoreMethodAttr:$storeMethod, OptionalAttr<TypeAttr>:$softmaxType,
     OptionalAttr<RockTuningParamAttrInterface>:$params0,
     OptionalAttr<RockTuningParamAttrInterface>:$params1,
@@ -534,6 +534,7 @@ def Rock_GridwiseAttentionAccelOp
     StoreMethodAttr:$storeMethod, I32Attr:$blockSize, I32Attr:$gridSize,
     UnitAttr:$disableQBypassLDS, OptionalAttr<IndexAttr>:$prePadG0M,
     OptionalAttr<IndexAttr>:$prePadG0N,
+    OptionalAttr<IndexAttr>:$numRepeatsGQA,
     OptionalAttr<TypeAttr>:$softmaxType,
     RockAccelTuningParamAttrInterface:$params0,
     RockAccelTuningParamAttrInterface:$params1,

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 4 additions & 0 deletions
@@ -2136,9 +2136,13 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
     ElementwiseRegionFinder<tosa::MatMulOp> elemwiseRegion =
         attentionMatcherValues.preSoftmaxElementwiseFinder;
     int64_t firstGemmBlockIndex = elemwiseRegion.getFirstGemmBlockIndex();
+
+    // TODO: numHeadsQ and numHeadsKV migraphx integration
     rock::AttentionOp attnOp = rewriter.create<rock::AttentionOp>(
         loc, outputType, lseType, firstMatMulOp.getA(), firstMatMulOp.getB(),
         op.getB(), elementwiseOtherArgs, currentSeqLen, output, lseOut,
+        /*numHeadsQ=*/rewriter.getI32IntegerAttr(1),
+        /*numHeadsKV=*/rewriter.getI32IntegerAttr(1),
         /*qTransposed=*/nullptr,
         /*kTransposed=*/nullptr,
         /*vTransposed=*/nullptr,
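The TOSA path pins both new attributes to 1, which keeps the GQA rewrite disabled until the MIGraphX integration flagged in the TODO lands. As a hypothetical sketch, such an integration might derive the two attributes from collapsed (batch * heads, seq, head_dim) operands; the helper name, layout, and batch size below are assumptions, not code from this commit:

// Hypothetical helper: infer head counts from collapsed group dimensions.
#include <cassert>
#include <cstdint>
#include <utility>

static std::pair<int32_t, int32_t>
inferHeadCounts(int64_t qGroupDim, int64_t kGroupDim, int64_t batch) {
  assert(qGroupDim % batch == 0 && kGroupDim % batch == 0);
  int32_t numHeadsQ = static_cast<int32_t>(qGroupDim / batch);
  int32_t numHeadsKV = static_cast<int32_t>(kGroupDim / batch);
  assert(numHeadsQ % numHeadsKV == 0 && "verifier requires divisibility");
  return {numHeadsQ, numHeadsKV};
}

int main() {
  // e.g. Q: (1*64, 1, 128), K: (1*8, 4096, 128) in the decode phase
  auto [hq, hkv] = inferHeadCounts(64, 8, /*batch=*/1);
  assert(hq == 64 && hkv == 8);
  return 0;
}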

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 24 additions & 4 deletions
@@ -2332,7 +2332,21 @@ GemmGemmSize GemmElementwiseGemmOp::getGemmGemmSize() {
 }
 
 static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
-                                              Value currentSeqLen, Value lse) {
+                                              Value currentSeqLen, Value lse,
+                                              int32_t numHeadsQ,
+                                              int32_t numHeadsKV) {
+  // number of heads for Q and K, V
+  if (numHeadsQ <= 0) {
+    return op.emitError("numHeadsQ must be positive");
+  }
+  if (numHeadsKV <= 0) {
+    return op.emitError("numHeadsKV must be positive");
+  }
+  if (numHeadsQ % numHeadsKV != 0) {
+    return op.emitError("numHeadsQ is not divisible by numHeadsKV");
+  }
+  int64_t factorGQA = numHeadsQ / numHeadsKV;
+
   ShapedType qType = cast<ShapedType>(op.getAType());
   int64_t qBatchDim = qType.getShape().size() == 3 ? qType.getShape()[0] : 1;
   ArrayRef<int64_t> qLastDims = qType.getShape().slice(qType.getRank() - 2);
@@ -2342,13 +2356,15 @@ static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
 
   ShapedType kType = cast<ShapedType>(op.getBType());
   int64_t kBatchDim = kType.getShape().size() == 3 ? kType.getShape()[0] : 1;
+  kBatchDim *= factorGQA;
   ArrayRef<int64_t> kLastDims = kType.getShape().slice(kType.getRank() - 2);
   auto [keyK, keyN] = op.getTransposedB()
                           ? std::tuple{kLastDims[1], kLastDims[0]}
                           : std::tuple{kLastDims[0], kLastDims[1]};
 
   ShapedType vType = cast<ShapedType>(op.getCType());
   int64_t vBatchDim = vType.getShape().size() == 3 ? vType.getShape()[0] : 1;
+  vBatchDim *= factorGQA;
   ArrayRef<int64_t> vLastDims = vType.getShape().slice(vType.getRank() - 2);
   auto [valueK, valueN] = op.getTransposedC()
                               ? std::tuple{vLastDims[1], vLastDims[0]}
@@ -2419,12 +2435,14 @@ static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
       return op.emitError("SeqLenQ dimensions do not match (LSE and Q)");
     }
   }
+
   return success();
 }
 
 LogicalResult GemmElementwiseGemmOp::verify() {
   return verifyGemmPlusGemmLikeOp(*this, /*currentSeqLen=*/nullptr,
-                                  /*lse=*/nullptr);
+                                  /*lse=*/nullptr, /*numHeadsQ=*/1,
+                                  /*numHeadsKV=*/1);
 }
 
 void GemmElementwiseGemmOp::getEffects(
@@ -2520,7 +2538,8 @@ GemmGemmSize ConvElementwiseGemmOp::getGemmGemmSize() {
 
 LogicalResult ConvElementwiseGemmOp::verify() {
   return verifyGemmPlusGemmLikeOp(*this, /*currentSeqLen=*/nullptr,
-                                  /*lse=*/nullptr);
+                                  /*lse=*/nullptr, /*numHeadsQ=*/1,
+                                  /*numHeadsKV=*/1);
 }
 
 void ConvElementwiseGemmOp::getEffects(
@@ -2598,7 +2617,8 @@ LogicalResult AttentionOp::verify() {
   if (getStoreMethod() != StoreMethod::Set)
     return emitError("Only set store method is supported for attention.");
 
-  return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen(), getLse());
+  return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen(), getLse(),
+                                  getNumHeadsQ(), getNumHeadsKV());
 }
 
 void AttentionOp::getEffects(
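A minimal standalone mirror of the new verifier rules, assuming the batch-dimension comparison that happens in the unchanged remainder of verifyGemmPlusGemmLikeOp; this is an illustrative sketch, not library code:

// Mirror of the GQA head-count checks: both counts positive, numHeadsQ
// divisible by numHeadsKV, and K's gemmG scaled by factorGQA before
// comparison against Q's gemmG.
#include <cassert>
#include <cstdint>

static bool verifyGQAHeads(int32_t numHeadsQ, int32_t numHeadsKV,
                           int64_t qBatchDim, int64_t kBatchDim) {
  if (numHeadsQ <= 0 || numHeadsKV <= 0)
    return false; // "numHeadsQ/numHeadsKV must be positive"
  if (numHeadsQ % numHeadsKV != 0)
    return false; // "numHeadsQ is not divisible by numHeadsKV"
  int64_t factorGQA = numHeadsQ / numHeadsKV;
  // K carries numHeadsKV heads while Q carries numHeadsQ, so scaling K's
  // gemmG by factorGQA makes the two batch dimensions comparable.
  return kBatchDim * factorGQA == qBatchDim;
}

int main() {
  assert(verifyGQAHeads(64, 8, /*qBatchDim=*/64, /*kBatchDim=*/8));
  assert(!verifyGQAHeads(64, 7, 64, 8));  // not divisible
  assert(!verifyGQAHeads(64, 8, 64, 16)); // batch dims mismatch after scaling
  return 0;
}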

mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp

Lines changed: 167 additions & 5 deletions
@@ -44,7 +44,9 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/Errc.h"
 #include "llvm/Support/LogicalResult.h"
 #include <algorithm>
 #include <memory>
@@ -108,6 +110,154 @@ struct AttentionRewritePattern : public OpConversionPattern<AttentionOp> {
                 ConversionPatternRewriter &rw) const override;
 };
 
+// Move the num_heads dimension to the sequence-length dimension. This is
+// useful for the decoding phase: when batch=1, seq_len_q=1 and GQA is used
+// (example: num_heads_q=64, num_heads_kv=8), we can move
+// numRepeats = num_heads_q / num_heads_kv = 8 to the seq_len_q dimension and
+// use the tile size better (otherwise seq_len_q=1 and it would get padded to
+// 32). This reduces the number of workgroups by a factor of numRepeats.
+// However, the decoding phase will typically use split_kv anyway to increase
+// the number of workgroups.
+static Value moveNumHeadsToSeqLenQ(OpBuilder builder, Location loc,
+                                   Value inputTensor, int64_t numRepeats) {
+  ArrayRef<int64_t> inpShape =
+      cast<ShapedType>(inputTensor.getType()).getShape();
+
+  assert(inpShape.size() == 3 && "input must be 3D");
+  assert(inpShape[0] % numRepeats == 0 &&
+         "gemmG must be divisible by numRepeats");
+
+  int64_t newGemmG = inpShape[0] / numRepeats;
+  SmallVector<StringRef> startNames = {"gemmG", "headDim", "seqLen"};
+
+  // (gemmG, headDim, seqLen) -> (gemmG / numRepeats, headDim, seqLen,
+  // numRepeats)
+  rock::BottomUpTMBuilder unmerge(builder, startNames, inpShape);
+  unmerge.unmerge({"gemmG", "numRepeats"}, {0, 3}, "gemmG",
+                  {newGemmG, numRepeats});
+  unmerge.passThrough({"seqLen", "headDim"}, {2, 1}, {"seqLen", "headDim"});
+  auto unmergeAttr = unmerge.get();
+  Value matrixUnmerge =
+      builder.create<rock::TransformOp>(loc, inputTensor, unmergeAttr);
+
+  // (gemmG / numRepeats, headDim, seqLen, numRepeats) -> (gemmG / numRepeats,
+  // headDim, seqLen * numRepeats)
+  auto merger = rock::BottomUpTMBuilder::above(unmerge, unmergeAttr);
+  merger.merge("seqLen", 2, {"seqLen", "numRepeats"});
+  merger.passThrough(ArrayRef<uint32_t>{0, 1}, ArrayRef<uint32_t>{0, 1});
+  auto mergerAttr = merger.get();
+  return builder.create<rock::TransformOp>(loc, matrixUnmerge, mergerAttr);
+}
+
+// Same as moveNumHeadsToSeqLenQ() but for the currentSeqLen tensor (KV cache)
+static Value moveNumHeadsToSeqLenCurrSeqLen(OpBuilder builder, Location loc,
+                                            Value inputTensor,
+                                            int64_t numRepeats) {
+  ArrayRef<int64_t> inpShape =
+      cast<ShapedType>(inputTensor.getType()).getShape();
+
+  assert(inpShape.size() == 1 && "input must be 1D");
+  assert(inpShape[0] % numRepeats == 0 &&
+         "gemmG must be divisible by numRepeats");
+
+  int64_t newGemmG = inpShape[0] / numRepeats;
+  SmallVector<StringRef> startNames = {"gemmG"};
+
+  // (gemmG) -> (gemmG / numRepeats, numRepeats)
+  rock::BottomUpTMBuilder unmerge(builder, startNames, inpShape);
+  unmerge.unmerge({"gemmG", "numRepeats"}, {0, 1}, "gemmG",
+                  {newGemmG, numRepeats});
+  auto unmergeAttr = unmerge.get();
+  Value matrixUnmerge =
+      builder.create<rock::TransformOp>(loc, inputTensor, unmergeAttr);
+
+  // slice numRepeats down to 1
+  auto slicer = rock::BottomUpTMBuilder::above(unmerge, unmergeAttr);
+  slicer.slice({"numRepeats"}, {"numRepeats"}, {0}, {1});
+  slicer.passThrough(ArrayRef<uint32_t>{0}, ArrayRef<uint32_t>{0});
+  auto slicerAttr = slicer.get();
+  Value matrixSliced =
+      builder.create<rock::TransformOp>(loc, matrixUnmerge, slicerAttr);
+
+  // (gemmG / numRepeats, 1) -> (gemmG / numRepeats)
+  auto merger = rock::BottomUpTMBuilder::above(slicer, slicerAttr);
+  merger.merge("seqLen", 0, {"gemmG", "numRepeats"});
+  auto mergerAttr = merger.get();
+  return builder.create<rock::TransformOp>(loc, matrixSliced, mergerAttr);
+}
+
+// Same as moveNumHeadsToSeqLenQ() but for the output tensor
+static Value moveNumHeadsToSeqLenOut(OpBuilder builder, Location loc,
+                                     Value inputTensor, int64_t numRepeats,
+                                     int64_t splitKV) {
+  ArrayRef<int64_t> inpShape =
+      cast<ShapedType>(inputTensor.getType()).getShape();
+
+  assert((inpShape.size() == 2 || inpShape.size() == 3) &&
+         "input must be 2D or 3D");
+  assert(inpShape[0] % numRepeats == 0 &&
+         "gemmG must be divisible by numRepeats");
+  assert(inpShape[0] % splitKV == 0 && "gemmG must be divisible by splitKV");
+
+  int64_t newGemmG = inpShape[0] / (numRepeats * splitKV);
+  bool isLSE = inpShape.size() == 2;
+
+  SmallVector<StringRef> startNamesAll = {"gemmG", "seqLen", "headDim"};
+  ArrayRef<StringRef> startNames =
+      ArrayRef<StringRef>(startNamesAll).take_front(inpShape.size());
+
+  // Note that for LSE, there are only two dimensions (gemmG, seqLen)
+  // (gemmG, seqLen, headDim) -> (gemmG / (splitKV*numRepeats), splitKV, seqLen,
+  // numRepeats, headDim)
+  rock::BottomUpTMBuilder unmerge(builder, startNames, inpShape);
+  unmerge.unmerge({"gemmG", "numRepeats", "splitKV"}, {0, 3, 1}, "gemmG",
+                  {newGemmG, numRepeats, splitKV});
+  if (isLSE)
+    unmerge.passThrough({"seqLen"}, {2}, {"seqLen"});
+  else
+    unmerge.passThrough({"seqLen", "headDim"}, {2, 4}, {"seqLen", "headDim"});
+  auto unmergeAttr = unmerge.get();
+  Value matrixUnmerge =
+      builder.create<rock::TransformOp>(loc, inputTensor, unmergeAttr);
+
+  // (gemmG / (splitKV*numRepeats), splitKV, seqLen, numRepeats, headDim) ->
+  // (gemmG / numRepeats, seqLen * numRepeats, headDim)
+  auto merger = rock::BottomUpTMBuilder::above(unmerge, unmergeAttr);
+  merger.merge("seqLen", 1, {"seqLen", "numRepeats"});
+  merger.merge("gemmG", 0, {"gemmG", "splitKV"});
+  if (!isLSE)
+    merger.passThrough({"headDim"}, {2}, {"headDim"});
+  auto mergerAttr = merger.get();
+  return builder.create<rock::TransformOp>(loc, matrixUnmerge, mergerAttr);
+}
+
+// This function implements GQA, moving numRepeats = num_heads_q / num_heads_kv
+// to the seq_len_q dimension. See the moveNumHeadsToSeqLenQ() comment for more
+// details.
+static std::tuple<IntegerAttr, Value, Value, Value, Value, Value, Value>
+processGQA(ConversionPatternRewriter &rw, Location loc, Value queries,
+           Value keys, Value values, Value out, Value lse, Value currentSeqLen,
+           int64_t numHeadsQ, int64_t numHeadsKV, int64_t splitKV) {
+  assert(numHeadsQ % numHeadsKV == 0);
+  IntegerAttr numRepeatsAttr = nullptr;
+
+  if (numHeadsQ != numHeadsKV) {
+    int64_t numRepeats = numHeadsQ / numHeadsKV;
+
+    numRepeatsAttr = rw.getIndexAttr(numRepeats);
+    queries = moveNumHeadsToSeqLenQ(rw, loc, queries, numRepeats);
+    if (currentSeqLen)
+      currentSeqLen =
+          moveNumHeadsToSeqLenCurrSeqLen(rw, loc, currentSeqLen, numRepeats);
+    out = moveNumHeadsToSeqLenOut(rw, loc, out, numRepeats, splitKV);
+    if (lse)
+      lse = moveNumHeadsToSeqLenOut(rw, loc, lse, numRepeats, splitKV);
+  }
+
+  return std::make_tuple(numRepeatsAttr, queries, keys, values, out, lse,
+                         currentSeqLen);
+}
+
 template <typename Op>
 static LogicalResult
 computeGridSizeAttentionGemmElmtGemm(ConversionPatternRewriter &rw, Op op,
@@ -314,6 +464,7 @@ static LogicalResult commonAttentionGemmElmtGemm(
     Value b, Value c, Value out, Value lse, Value currentSeqLen,
     UnitAttr causal, IntegerAttr splitKV, ValueRange elementwiseInputs,
     Region &preSecondOpRegion, bool enableSoftmax, TypeAttr softmaxType,
+    int64_t numHeadsQ, int64_t numHeadsKV,
    std::optional<std::reference_wrapper<const BufferDependencyAnalysis>>
        bufferDeps) {
  Location loc = op->getLoc();
@@ -363,6 +514,15 @@ static LogicalResult commonAttentionGemmElmtGemm(
     std::tie(a, b, c, out) = maybeSplitk.value();
   }
 
+  int64_t splitKVNum = splitKV.getInt();
+
+  // Grouped-Query Attention (GQA)
+  IntegerAttr numRepeatsGQA = nullptr;
+  if (enableSoftmax)
+    std::tie(numRepeatsGQA, a, b, c, out, lse, currentSeqLen) =
+        processGQA(rw, op.getLoc(), a, b, c, out, lse, currentSeqLen, numHeadsQ,
+                   numHeadsKV, splitKVNum);
+
   // Note, matrix dimension correctness is handled in the verifier
   ArrayRef<int64_t> aShape = cast<MemRefType>(a.getType()).getShape();
   ArrayRef<int64_t> bShape = cast<MemRefType>(b.getType()).getShape();
@@ -374,7 +534,6 @@ static LogicalResult commonAttentionGemmElmtGemm(
   GemmSize gemm1Size(/*g=*/aShape[0], /*m=*/cShape[2],
                      /*k=*/cShape[1],
                      /*n=*/aShape[2]);
-  int64_t splitKVNum = splitKV.getInt();
   GemmSize gemm0ExtraPad = requiredPadding(params0, gemm0Size, 1, splitKVNum)
                                .value_or(GemmSize{0, 0, 0, 0});
   GemmSize gemm1ExtraPad = requiredPadding(params1, gemm1Size, splitKVNum)
@@ -417,8 +576,9 @@ static LogicalResult commonAttentionGemmElmtGemm(
       loc, a, b, c, elementwiseInputs, currentSeqLen, out, lse, causal, splitKV,
       op.getGemmFeaturesAttr(), op.getStoreMethodAttr(), blockSizeAttr,
       gridSizeAttr,
-      /*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, softmaxType,
-      params0, params1, rw.getDenseI64ArrayAttr(op.getFirstGemmIndices()),
+      /*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr,
+      numRepeatsGQA, softmaxType, params0, params1,
+      rw.getDenseI64ArrayAttr(op.getFirstGemmIndices()),
       rw.getBoolAttr(enableSoftmax));
   bool linalgOpFound = false;
   preSecondOpRegion.walk(
@@ -777,7 +937,8 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op,
       adaptor.getOut(), adaptor.getLse(), adaptor.getCurrentSeqLen(),
       adaptor.getCausalAttr(), adaptor.getSplitKVAttr(),
       adaptor.getPreSoftmaxElemWiseInputs(), op.getPreSoftmaxBody(),
-      /*enableSoftmax=*/true, op.getSoftmaxTypeAttr(),
+      /*enableSoftmax=*/true, op.getSoftmaxTypeAttr(), adaptor.getNumHeadsQ(),
+      adaptor.getNumHeadsKV(),
       /*bufferDeps=*/std::nullopt);
 }
 
@@ -790,7 +951,8 @@ LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite(
       /*lse=*/nullptr,
       /*currentSeqLen=*/nullptr, /*causal=*/nullptr, splitKV,
       adaptor.getElemwiseInputs(), op.getPreSecondGemmBody(),
-      /*enableSoftmax=*/false, /*softmaxType=*/nullptr, std::cref(bufferDeps));
+      /*enableSoftmax=*/false, /*softmaxType=*/nullptr, /*numHeadsQ=*/1,
+      /*numHeadsKV=*/1, std::cref(bufferDeps));
 }
 
 void RockGemmToGridwisePass::runOnOperation() {
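A worked shape example for processGQA under assumed decode-phase sizes (batch=1, num_heads_q=64, num_heads_kv=8, head_dim=128, seq_len_q=1, splitKV=2); this sketch is illustrative and not part of the commit:

// Shape arithmetic mirroring the three moveNumHeadsToSeqLen* helpers.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t numHeadsQ = 64, numHeadsKV = 8, seqLenQ = 1;
  const int64_t splitKV = 2;
  const int64_t numRepeats = numHeadsQ / numHeadsKV; // 8

  // Q: (gemmG, headDim, seqLen) = (64, 128, 1)
  //   moveNumHeadsToSeqLenQ -> (64/8, 128, 1*8) = (8, 128, 8)
  int64_t qG = numHeadsQ / numRepeats, qSeq = seqLenQ * numRepeats;
  assert(qG == 8 && qSeq == 8);

  // currentSeqLen: (64) -> unmerge to (8, 8), slice numRepeats to 1 -> (8)
  int64_t curG = numHeadsQ / numRepeats;
  assert(curG == 8);

  // out: (gemmG, seqLen, headDim) = (64*2, 1, 128) with splitKV folded in
  //   moveNumHeadsToSeqLenOut -> (16, 8, 128); the 2D LSE (128, 1) becomes
  //   (16, 8) the same way.
  int64_t outG = (numHeadsQ * splitKV) / numRepeats;
  int64_t outSeq = seqLenQ * numRepeats;
  assert(outG == 16 && outSeq == 8);

  // Net effect: seq_len_q grows 1 -> 8, so a tile of height 32 pads 24 rows
  // instead of 31, and the grid shrinks by a factor of numRepeats.
  return 0;
}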
