Commit 2e0be81

GQA optimization migraphx integration (#2093)
Pattern match GQA from migraphx IR
1 parent e116a04

File tree

8 files changed: +826 -176 lines changed

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 226 additions & 8 deletions
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/Rock/utility/builderUtils.h"
 #include "mlir/Dialect/Rock/utility/loweringUtils.h"
 #include "mlir/Dialect/Rock/utility/tosaUtils.h"
+#include "mlir/Dialect/Rock/utility/transformMapUtils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
@@ -52,6 +53,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
+#include <tuple>
 #include <utility>
 
 #define DEBUG_TYPE "convert-tosa-to-rock"
@@ -790,7 +792,7 @@ static Value insertBroadcast(Value inp, ArrayRef<int64_t> outShape,
   return rock::TransformOp::create(b, loc, inp, broadcastDims.get());
 }
 
-static FailureOr<Value> mulBroadcast(Value val);
+static FailureOr<Value> mulBroadcast(Value val, bool skipCollapseExpand = true);
 
 static FailureOr<Value> getValueSkipping(Value val,
                                          const DenseSet<StringRef> &opsToSkip) {
@@ -820,9 +822,12 @@ getDefiningOpSkipping(Value val, const DenseSet<StringRef> &opsToSkip) {
   return result;
 }
 
-static FailureOr<Value> mulBroadcast(Value val) {
+static FailureOr<Value> mulBroadcast(Value val, bool skipCollapseExpand) {
   DenseSet<StringRef> opsToSkip{tensor::CollapseShapeOp::getOperationName(),
                                 tensor::ExpandShapeOp::getOperationName()};
+  if (!skipCollapseExpand)
+    opsToSkip.clear();
+
   auto maybeMul = getDefiningOpSkipping<tosa::MulOp>(val, opsToSkip);
   if (succeeded(maybeMul)) {
     auto mul = maybeMul.value();
@@ -2322,6 +2327,216 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
     return mul.getOutput();
   }
 
+  FailureOr<std::pair<int64_t, int64_t>> getNumHeadsGQA(Value value,
+                                                        bool isQ) const {
+    // the collapsed leading dimension has size batch * numHeads
+    auto collapse = value.getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapse)
+      return failure();
+
+    auto reassociationIdx = collapse.getReassociationIndices();
+
+    // expected to reshape to three dimensions (input to tosa.matmul)
+    if (reassociationIdx.size() != 3)
+      return failure();
+    size_t expectedGroupSize = isQ ? 2 : 3;
+    if (reassociationIdx[0].size() != expectedGroupSize ||
+        reassociationIdx[1].size() != 1 || reassociationIdx[2].size() != 1)
+      return failure();
+
+    // the reassociation indices must be consecutive, starting from zero
+    int64_t count = 0;
+    for (const auto &reassociation : reassociationIdx) {
+      for (auto idx : reassociation) {
+        if (count != idx)
+          return failure();
+        count++;
+      }
+    }
+
+    auto reshapeInputShape =
+        cast<ShapedType>(collapse.getSrc().getType()).getShape();
+    // we expect the input to be batch x num_heads x D x K (or K x D) for Q,
+    // with an extra repeat dimension after num_heads for K and V
+    size_t expectedSize = isQ ? 4 : 5;
+    if (reshapeInputShape.size() != expectedSize)
+      return failure();
+
+    int64_t batch = reshapeInputShape[0];
+    int64_t numHeads = reshapeInputShape[1];
+    return std::make_pair(batch, numHeads);
+  }
+
+  LogicalResult checkBroadcastGQA(Value value, int64_t expectedRepeat) const {
+    auto collapse = value.getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapse)
+      return failure();
+    Value collapseVal = collapse.getSrc();
+
+    auto maybeNonOne = mulBroadcast(collapseVal, /*skipCollapseExpand=*/false);
+    if (failed(maybeNonOne))
+      return failure();
+
+    // we should be doing batch x num_heads x 1 x D x K -> batch x num_heads x
+    // REPEAT x D x K
+    Value nonOne = maybeNonOne.value();
+    auto shapeBeforeBroadcast = cast<ShapedType>(nonOne.getType()).getShape();
+    auto shapeAfterBroadcast =
+        cast<ShapedType>(
+            collapseVal.getDefiningOp<tosa::MulOp>().getOutput().getType())
+            .getShape();
+    if (shapeBeforeBroadcast.size() != shapeAfterBroadcast.size())
+      return failure();
+
+    // we expect five dimensions
+    if (shapeBeforeBroadcast.size() != 5)
+      return failure();
+
+    // dimension we are broadcasting
+    if (shapeBeforeBroadcast[2] != 1 ||
+        shapeAfterBroadcast[2] != expectedRepeat)
+      return failure();
+
+    // the rest of the dimensions must be the same
+    for (size_t idx = 0; idx < shapeBeforeBroadcast.size(); idx++) {
+      if (idx != 2 && shapeBeforeBroadcast[idx] != shapeAfterBroadcast[idx])
+        return failure();
+    }
+
+    return success();
+  }
+
+  FailureOr<Value> sliceTensorGQA(PatternRewriter &rewriter, Value value,
+                                  int64_t batch, int64_t numHeads,
+                                  int64_t repeat) const {
+    Location loc = value.getLoc();
+    ArrayRef<int64_t> shape = cast<ShapedType>(value.getType()).getShape();
+    if (shape.size() != 3)
+      return failure();
+
+    if (shape[0] != (batch * numHeads * repeat))
+      return failure();
+
+    // reshape group x D x K -> batch x num_heads x repeat x D x K
+    rock::BottomUpTMBuilder unmergeDims(rewriter, {"group", "dim0", "dim1"},
+                                        shape, loc);
+    unmergeDims.unmerge({"batch", "num_heads", "repeat"}, {0, 1, 2}, "group",
+                        {batch, numHeads, repeat});
+    unmergeDims.passThrough({3, 4}, {1, 2});
+    rock::TransformMapAttr unmergeDimsAttr = unmergeDims.get();
+
+    // slice repeat to 1
+    auto sliceRepeat =
+        rock::BottomUpTMBuilder::above(unmergeDims, unmergeDimsAttr);
+    sliceRepeat.slice({"repeat"}, {"repeat"}, {0}, {1});
+    sliceRepeat.passThrough({"batch", "num_heads", "dim0", "dim1"});
+    rock::TransformMapAttr sliceRepeatAttr = sliceRepeat.get();
+
+    // reshape back to group/repeat x D x K
+    auto finalMerge =
+        rock::BottomUpTMBuilder::above(sliceRepeat, sliceRepeatAttr);
+    finalMerge.merge("group", 0, {"batch", "num_heads", "repeat"});
+    finalMerge.passThrough({"dim0", "dim1"}, {1, 2}, {"dim0", "dim1"});
+    rock::TransformMapAttr finalMergeAttr = finalMerge.get();
+
+    ArrayAttr transformsAttr = rewriter.getArrayAttr(
+        {finalMergeAttr, sliceRepeatAttr, unmergeDimsAttr});
+    return rock::transform(rewriter, value, transformsAttr);
+  }
+
+  /*
+  This tries to identify if GQA is used, and undoes the broadcast. The
+  expected IR is:
+
+  // clang-format off
+  ```
+  %q = tensor.collapse_shape %q [[0, 1], [2], [3]]
+      : tensor<1x32x1x128xf16> into tensor<32x1x128xf16>
+
+  // broadcast from numHeadsK x 1 -> numHeadsK x repeat, where
+  // numHeadsQ = numHeadsK * repeat
+  %k = tosa.mul %k, constant=1, constant=0
+      : (tensor<1x8x1x128x64xf16>, tensor<1x8x4x128x64xf16>, tensor<1xi8>)
+      -> tensor<1x8x4x128x64xf16>
+  // collapse batch, numHeadsK, and repeat into the group dimension,
+  // group = batch * numHeadsK * repeat
+  %k = tensor.collapse_shape %k [[0, 1, 2], [3], [4]]
+      : tensor<1x8x4x128x64xf16> into tensor<32x128x64xf16>
+
+  %v = same transforms as %k
+  rock.attention(%q, %k, %v)
+  ```
+  // clang-format on
+
+  Note that if we identify the GQA pattern, we slice the K and V tensors
+  and pass numHeadsQ and numHeadsKV to rock.attention. Otherwise, K and V
+  tensors are left untouched and numHeadsQ=1, numHeadsKV=1.
+  */
+  std::tuple<Value, Value, Value, IntegerAttr, IntegerAttr>
+  getGQAValues(PatternRewriter &rewriter, Value queries, Value keys,
+               Value values) const {
+    // default values in case GQA is not pattern matched
+    IntegerAttr numHeadsQAttr = rewriter.getI32IntegerAttr(1);
+    IntegerAttr numHeadsKVAttr = rewriter.getI32IntegerAttr(1);
+    auto defaultValues =
+        std::make_tuple(queries, keys, values, numHeadsQAttr, numHeadsKVAttr);
+
+    FailureOr<std::pair<int64_t, int64_t>> reshapeQResults =
+        getNumHeadsGQA(queries, true);
+    if (failed(reshapeQResults))
+      return defaultValues;
+    int64_t batchQ = reshapeQResults->first;
+    int64_t numHeadsQ = reshapeQResults->second;
+
+    FailureOr<std::pair<int64_t, int64_t>> reshapeKResults =
+        getNumHeadsGQA(keys, false);
+    if (failed(reshapeKResults))
+      return defaultValues;
+    int64_t batchK = reshapeKResults->first;
+    int64_t numHeadsK = reshapeKResults->second;
+
+    FailureOr<std::pair<int64_t, int64_t>> reshapeVResults =
+        getNumHeadsGQA(values, false);
+    if (failed(reshapeVResults))
+      return defaultValues;
+    int64_t batchV = reshapeVResults->first;
+    int64_t numHeadsV = reshapeVResults->second;
+
+    // batch must be equal for all tensors
+    if (batchQ != batchK || batchQ != batchV)
+      return defaultValues;
+
+    // num heads of K and V must be equal
+    if (numHeadsK != numHeadsV)
+      return defaultValues;
+
+    // numHeadsQ must be divisible by numHeadsKV
+    if (numHeadsQ % numHeadsK != 0)
+      return defaultValues;
+
+    int64_t expectedRepeat = numHeadsQ / numHeadsK;
+    // check we are doing the expected broadcast for K and V
+    LogicalResult kCorrect = checkBroadcastGQA(keys, expectedRepeat);
+    LogicalResult vCorrect = checkBroadcastGQA(values, expectedRepeat);
+    if (failed(kCorrect) || failed(vCorrect))
+      return defaultValues;
+
+    // update keys and values (slicing the repeats)
+    auto maybeKeys =
+        sliceTensorGQA(rewriter, keys, batchK, numHeadsK, expectedRepeat);
+    auto maybeValues =
+        sliceTensorGQA(rewriter, values, batchV, numHeadsV, expectedRepeat);
+    if (failed(maybeKeys) || failed(maybeValues))
+      return defaultValues;
+
+    keys = maybeKeys.value();
+    values = maybeValues.value();
+
+    numHeadsQAttr = rewriter.getI32IntegerAttr(numHeadsQ);
+    numHeadsKVAttr = rewriter.getI32IntegerAttr(numHeadsK);
+    LLVM_DEBUG(llvm::dbgs() << "Found GQA pattern, numHeadsQ=" << numHeadsQ
+                            << " numHeadsKV=" << numHeadsK << "\n");
+    return std::make_tuple(queries, keys, values, numHeadsQAttr,
+                           numHeadsKVAttr);
+  }
+
   FailureOr<AttentionMatcherValues> match(tosa::MatMulOp op) const {
     Value softmaxOutput = op.getA();
     DenseSet<StringRef> expandAndCollapse{
@@ -2517,13 +2732,16 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
         attentionMatcherValues.preSoftmaxElementwiseFinder;
     int64_t firstGemmBlockIndex = elemwiseRegion.getFirstGemmBlockIndex();
 
-    // TODO: numHeadsQ and numHeadsKV migraphx integration
+    IntegerAttr numHeadsQ, numHeadsKV;
+    Value queries, keys, values;
+    std::tie(queries, keys, values, numHeadsQ, numHeadsKV) = getGQAValues(
+        rewriter, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB());
+
     rock::AttentionOp attnOp = rock::AttentionOp::create(
-        rewriter, loc, outputType, lseType, firstMatMulOp.getA(),
-        firstMatMulOp.getB(), op.getB(), elementwiseOtherArgs, currentSeqLen,
-        output, lseOut,
-        /*numHeadsQ=*/rewriter.getI32IntegerAttr(1),
-        /*numHeadsKV=*/rewriter.getI32IntegerAttr(1),
+        rewriter, loc, outputType, lseType, queries, keys, values,
+        elementwiseOtherArgs, currentSeqLen, output, lseOut,
+        /*numHeadsQ=*/numHeadsQ,
+        /*numHeadsKV=*/numHeadsKV,
         /*qTransposed=*/nullptr,
         /*kTransposed=*/nullptr,
         /*vTransposed=*/nullptr,
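
Aside (not part of the commit): the unmerge, slice, and merge chain built in sliceTensorGQA encodes a pure index remapping that keeps one of the `repeat` broadcast copies per KV head. Below is a minimal standalone sketch of that arithmetic, assuming the shapes from the doc comment above (batch=1, numHeadsKV=8, repeat=4, hence numHeadsQ=32); the helper name `slicedGroupIndex` is illustrative and not from the source.

```
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative helper (not in the commit): maps an output (batch, head)
// coordinate to the input row of the collapsed, broadcast K/V tensor that
// survives the unmerge -> slice -> merge chain in sliceTensorGQA.
int64_t slicedGroupIndex(int64_t numHeadsKV, int64_t repeat, int64_t b,
                         int64_t h) {
  // After broadcasting, copy (b, h, r) lives at row
  // ((b * numHeadsKV) + h) * repeat + r of the group dimension;
  // slicing `repeat` to [0, 1) keeps only r == 0.
  return ((b * numHeadsKV) + h) * repeat + 0;
}

int main() {
  // Shapes from the doc comment: batch=1, numHeadsKV=8, repeat=4.
  const int64_t batch = 1, numHeadsKV = 8, repeat = 4;
  std::vector<int64_t> kept;
  for (int64_t b = 0; b < batch; ++b)
    for (int64_t h = 0; h < numHeadsKV; ++h)
      kept.push_back(slicedGroupIndex(numHeadsKV, repeat, b, h));
  // One of every `repeat` broadcast copies survives: rows 0, 4, ..., 28.
  assert(static_cast<int64_t>(kept.size()) == batch * numHeadsKV);
  for (size_t i = 0; i < kept.size(); ++i)
    assert(kept[i] == static_cast<int64_t>(i) * repeat);
  return 0;
}
```

Since rock.attention then reads K and V through these transform maps (which describe a view, as the use of rock::transform suggests), the broadcast inserted for GQA is undone without materializing a sliced tensor.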
