|
29 | 29 | #include "mlir/Dialect/Rock/utility/builderUtils.h" |
30 | 30 | #include "mlir/Dialect/Rock/utility/loweringUtils.h" |
31 | 31 | #include "mlir/Dialect/Rock/utility/tosaUtils.h" |
| 32 | +#include "mlir/Dialect/Rock/utility/transformMapUtils.h" |
32 | 33 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
33 | 34 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
34 | 35 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
|
52 | 53 | #include "llvm/Support/Debug.h" |
53 | 54 | #include "llvm/Support/LogicalResult.h" |
54 | 55 | #include "llvm/Support/raw_ostream.h" |
| 56 | +#include <tuple> |
55 | 57 | #include <utility> |
56 | 58 |
|
57 | 59 | #define DEBUG_TYPE "convert-tosa-to-rock" |
@@ -790,7 +792,7 @@ static Value insertBroadcast(Value inp, ArrayRef<int64_t> outShape, |
790 | 792 | return rock::TransformOp::create(b, loc, inp, broadcastDims.get()); |
791 | 793 | } |
792 | 794 |
|
793 | | -static FailureOr<Value> mulBroadcast(Value val); |
| 795 | +static FailureOr<Value> mulBroadcast(Value val, bool skipCollapseExpand = true); |
794 | 796 |
|
795 | 797 | static FailureOr<Value> getValueSkipping(Value val, |
796 | 798 | const DenseSet<StringRef> &opsToSkip) { |
@@ -820,9 +822,12 @@ getDefiningOpSkipping(Value val, const DenseSet<StringRef> &opsToSkip) { |
820 | 822 | return result; |
821 | 823 | } |
822 | 824 |
|
823 | | -static FailureOr<Value> mulBroadcast(Value val) { |
| 825 | +static FailureOr<Value> mulBroadcast(Value val, bool skipCollapseExpand) { |
824 | 826 | DenseSet<StringRef> opsToSkip{tensor::CollapseShapeOp::getOperationName(), |
825 | 827 | tensor::ExpandShapeOp::getOperationName()}; |
| 828 | + if (!skipCollapseExpand) |
| 829 | + opsToSkip.clear(); |
| 830 | + |
826 | 831 | auto maybeMul = getDefiningOpSkipping<tosa::MulOp>(val, opsToSkip); |
827 | 832 | if (succeeded(maybeMul)) { |
828 | 833 | auto mul = maybeMul.value(); |
@@ -2322,6 +2327,216 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> { |
2322 | 2327 | return mul.getOutput(); |
2323 | 2328 | } |
2324 | 2329 |
|
| 2330 | + FailureOr<std::pair<int64_t, int64_t>> getNumHeadsGQA(Value value, |
| 2331 | + bool isQ) const { |
| 2332 | + // this size is = batch*numHeads |
| 2333 | + auto collapse = value.getDefiningOp<tensor::CollapseShapeOp>(); |
| 2334 | + if (!collapse) |
| 2335 | + return failure(); |
| 2336 | + |
| 2337 | + auto reassociationIdx = collapse.getReassociationIndices(); |
| 2338 | + |
| 2339 | + // expected to reshape to three dimensions (input to tosa.matmul) |
| 2340 | + if (reassociationIdx.size() != 3) |
| 2341 | + return failure(); |
| 2342 | + size_t expectedGroupSize = isQ ? 2 : 3; |
| 2343 | + if (reassociationIdx[0].size() != expectedGroupSize || |
| 2344 | + reassociationIdx[1].size() != 1 || reassociationIdx[2].size() != 1) |
| 2345 | + return failure(); |
| 2346 | + |
| 2347 | + // group size must match groupSizeQ |
| 2348 | + int64_t count = 0; |
| 2349 | + for (const auto &reassociation : reassociationIdx) { |
| 2350 | + for (auto idx : reassociation) { |
| 2351 | + if (count != idx) |
| 2352 | + return failure(); |
| 2353 | + count++; |
| 2354 | + } |
| 2355 | + } |
| 2356 | + |
| 2357 | + auto reshapeInputShape = |
| 2358 | + cast<ShapedType>(collapse.getSrc().getType()).getShape(); |
| 2359 | + // we expect the input to be batch x num_heads x D x K (or K x D) |
| 2360 | + size_t expectedSize = isQ ? 4 : 5; |
| 2361 | + if (reshapeInputShape.size() != expectedSize) |
| 2362 | + return failure(); |
| 2363 | + |
| 2364 | + int64_t batch = reshapeInputShape[0]; |
| 2365 | + int64_t numHeads = reshapeInputShape[1]; |
| 2366 | + return std::make_pair(batch, numHeads); |
| 2367 | + } |
| 2368 | + |
| 2369 | + LogicalResult checkBroadcastGQA(Value value, int64_t expectedRepeat) const { |
| 2370 | + auto collapse = value.getDefiningOp<tensor::CollapseShapeOp>(); |
| 2371 | + if (!collapse) |
| 2372 | + return failure(); |
| 2373 | + Value collapseVal = collapse.getSrc(); |
| 2374 | + |
| 2375 | + auto maybeNonOne = mulBroadcast(collapseVal, /*skipCollapseExpand=*/false); |
| 2376 | + if (failed(maybeNonOne)) |
| 2377 | + return failure(); |
| 2378 | + |
| 2379 | + // we should be doing batch x num_heads x 1 x D x K -> batch x num_heads x |
| 2380 | + // REPEAT x D x K |
| 2381 | + Value nonOne = maybeNonOne.value(); |
| 2382 | + auto shapeBeforeBroadcast = cast<ShapedType>(nonOne.getType()).getShape(); |
| 2383 | + auto shapeAfterBroadcast = |
| 2384 | + cast<ShapedType>( |
| 2385 | + collapseVal.getDefiningOp<tosa::MulOp>().getOutput().getType()) |
| 2386 | + .getShape(); |
| 2387 | + if (shapeBeforeBroadcast.size() != shapeAfterBroadcast.size()) |
| 2388 | + return failure(); |
| 2389 | + |
| 2390 | + // we expect five dimensions |
| 2391 | + if (shapeBeforeBroadcast.size() != 5) |
| 2392 | + return failure(); |
| 2393 | + |
| 2394 | + // dimension we are broadcasting |
| 2395 | + if (shapeBeforeBroadcast[2] != 1 || |
| 2396 | + shapeAfterBroadcast[2] != expectedRepeat) |
| 2397 | + return failure(); |
| 2398 | + |
| 2399 | + // rest of dimensions must be the same |
| 2400 | + for (size_t idx = 0; idx < shapeBeforeBroadcast.size(); idx++) { |
| 2401 | + if (idx != 2 && shapeBeforeBroadcast[idx] != shapeAfterBroadcast[idx]) |
| 2402 | + return failure(); |
| 2403 | + } |
| 2404 | + |
| 2405 | + return success(); |
| 2406 | + } |
| 2407 | + |
| 2408 | + FailureOr<Value> sliceTensorGQA(PatternRewriter &rewriter, Value value, |
| 2409 | + int64_t batch, int64_t numHeads, |
| 2410 | + int64_t repeat) const { |
| 2411 | + Location loc = value.getLoc(); |
| 2412 | + ArrayRef<int64_t> shape = cast<ShapedType>(value.getType()).getShape(); |
| 2413 | + if (shape.size() != 3) |
| 2414 | + return failure(); |
| 2415 | + |
| 2416 | + if (shape[0] != (batch * numHeads * repeat)) |
| 2417 | + return failure(); |
| 2418 | + |
| 2419 | + // reshape group x D x K -> batch x num_heads x repeat x D x K |
| 2420 | + rock::BottomUpTMBuilder unmergeDims(rewriter, {"group", "dim0", "dim1"}, |
| 2421 | + shape, loc); |
| 2422 | + unmergeDims.unmerge({"batch", "num_heads", "repeat"}, {0, 1, 2}, "group", |
| 2423 | + {batch, numHeads, repeat}); |
| 2424 | + unmergeDims.passThrough({3, 4}, {1, 2}); |
| 2425 | + rock::TransformMapAttr unmergeDimsAttr = unmergeDims.get(); |
| 2426 | + |
| 2427 | + // slice repeat to 1 |
| 2428 | + auto sliceRepeat = |
| 2429 | + rock::BottomUpTMBuilder::above(unmergeDims, unmergeDimsAttr); |
| 2430 | + sliceRepeat.slice({"repeat"}, {"repeat"}, {0}, {1}); |
| 2431 | + sliceRepeat.passThrough({"batch", "num_heads", "dim0", "dim1"}); |
| 2432 | + rock::TransformMapAttr sliceRepeatAttr = sliceRepeat.get(); |
| 2433 | + |
| 2434 | + // reshape back to group/repeat x D x K |
| 2435 | + auto finalMerge = |
| 2436 | + rock::BottomUpTMBuilder::above(sliceRepeat, sliceRepeatAttr); |
| 2437 | + finalMerge.merge("group", 0, {"batch", "num_heads", "repeat"}); |
| 2438 | + finalMerge.passThrough({"dim0", "dim1"}, {1, 2}, {"dim0", "dim1"}); |
| 2439 | + rock::TransformMapAttr finalMergeAttr = finalMerge.get(); |
| 2440 | + |
| 2441 | + ArrayAttr transformsAttr = rewriter.getArrayAttr( |
| 2442 | + {finalMergeAttr, sliceRepeatAttr, unmergeDimsAttr}); |
| 2443 | + return rock::transform(rewriter, value, transformsAttr); |
| 2444 | + } |
| 2445 | + |
| 2446 | + /* |
| 2447 | + This tries to identify if GQA is used, and undoes the broadcast. The expected |
| 2448 | + IR is: |
| 2449 | +
|
| 2450 | + // clang-format off |
| 2451 | + ``` |
| 2452 | + %q = tensor.collapse %q [[0, 1], [2], [3]] : tensor<1x32x1x128xf16> into |
| 2453 | + tensor<32x1x128xf16> |
| 2454 | +
|
| 2455 | + // broadcast from numHeadsK, 1 -> numHeadsK, repeat where |
| 2456 | + numHeadsQ=numHeadsK*repeat %k = tosa.mul %k, constant=1, constant=0 : |
| 2457 | + (tensor<1x8x1x128x64xf16>, tensor<1x8x4x128x64xf16>, tensor<1xi8>) -> |
| 2458 | + tensor<1x8x4x128x64xf16> |
| 2459 | + // collapse batch, numHeadsK and repeat into group dimension, |
| 2460 | + group=batch*numHeadsK*repeat %k = tensor.collapse_shape %k [[0, 1, 2], [3], |
| 2461 | + [4]] : tensor<1x8x4x128x64xf16> into tensor<32x128x64xf16> |
| 2462 | +
|
| 2463 | + %v = same transforms as %k |
| 2464 | + rock.attention(%q, %k, %v) |
| 2465 | + ``` |
| 2466 | + // clang-format on |
| 2467 | +
|
| 2468 | + Note that if we identify the GQA pattern, we slice the K and V tensors |
| 2469 | + and pass numHeadsQ and numHeadsKV to rock.attention. Otherwise, K and V |
| 2470 | + tensors are left untouched and numHeadsQ=1, numHeadsKV=1. |
| 2471 | + */ |
| 2472 | + std::tuple<Value, Value, Value, IntegerAttr, IntegerAttr> |
| 2473 | + getGQAValues(PatternRewriter &rewriter, Value queries, Value keys, |
| 2474 | + Value values) const { |
| 2475 | + // default values in case GQA is not pattern matched |
| 2476 | + IntegerAttr numHeadsQAttr = rewriter.getI32IntegerAttr(1); |
| 2477 | + IntegerAttr numHeadsKVAttr = rewriter.getI32IntegerAttr(1); |
| 2478 | + auto defaultValues = |
| 2479 | + std::make_tuple(queries, keys, values, numHeadsQAttr, numHeadsKVAttr); |
| 2480 | + |
| 2481 | + FailureOr<std::pair<int64_t, int64_t>> reshapeQResults = |
| 2482 | + getNumHeadsGQA(queries, true); |
| 2483 | + if (failed(reshapeQResults)) |
| 2484 | + return defaultValues; |
| 2485 | + int64_t batchQ = reshapeQResults->first; |
| 2486 | + int64_t numHeadsQ = reshapeQResults->second; |
| 2487 | + |
| 2488 | + FailureOr<std::pair<int64_t, int64_t>> reshapeKResults = |
| 2489 | + getNumHeadsGQA(keys, false); |
| 2490 | + if (failed(reshapeKResults)) |
| 2491 | + return defaultValues; |
| 2492 | + int64_t batchK = reshapeKResults->first; |
| 2493 | + int64_t numHeadsK = reshapeKResults->second; |
| 2494 | + |
| 2495 | + FailureOr<std::pair<int64_t, int64_t>> reshapeVResults = |
| 2496 | + getNumHeadsGQA(values, false); |
| 2497 | + if (failed(reshapeVResults)) |
| 2498 | + return defaultValues; |
| 2499 | + int64_t batchV = reshapeVResults->first; |
| 2500 | + int64_t numHeadsV = reshapeVResults->second; |
| 2501 | + |
| 2502 | + // batch must be equal for all tensors |
| 2503 | + if (batchQ != batchK || batchQ != batchV) |
| 2504 | + return defaultValues; |
| 2505 | + |
| 2506 | + // num heads of K and V must be equal |
| 2507 | + if (numHeadsK != numHeadsV) |
| 2508 | + return defaultValues; |
| 2509 | + |
| 2510 | + // numHeadsQ must be divisible by numHeadsKV |
| 2511 | + if (numHeadsQ % numHeadsK != 0) |
| 2512 | + return defaultValues; |
| 2513 | + |
| 2514 | + int64_t expectedRepeat = numHeadsQ / numHeadsK; |
| 2515 | + // check we are doing the expected broadcast for K and V |
| 2516 | + LogicalResult kCorrect = checkBroadcastGQA(keys, expectedRepeat); |
| 2517 | + LogicalResult vCorrect = checkBroadcastGQA(values, expectedRepeat); |
| 2518 | + if (failed(kCorrect) || failed(vCorrect)) |
| 2519 | + return defaultValues; |
| 2520 | + |
| 2521 | + // update keys and values (slicing the repeats) |
| 2522 | + auto maybeKeys = |
| 2523 | + sliceTensorGQA(rewriter, keys, batchK, numHeadsK, expectedRepeat); |
| 2524 | + auto maybeValues = |
| 2525 | + sliceTensorGQA(rewriter, values, batchV, numHeadsV, expectedRepeat); |
| 2526 | + if (failed(maybeKeys) || failed(maybeValues)) |
| 2527 | + return defaultValues; |
| 2528 | + |
| 2529 | + keys = maybeKeys.value(); |
| 2530 | + values = maybeValues.value(); |
| 2531 | + |
| 2532 | + numHeadsQAttr = rewriter.getI32IntegerAttr(numHeadsQ); |
| 2533 | + numHeadsKVAttr = rewriter.getI32IntegerAttr(numHeadsK); |
| 2534 | + LLVM_DEBUG(llvm::dbgs() << "Found GQA pattern, numHeadsQ=" << numHeadsQ |
| 2535 | + << " numHeadsKV=" << numHeadsK << "\n"); |
| 2536 | + return std::make_tuple(queries, keys, values, numHeadsQAttr, |
| 2537 | + numHeadsKVAttr); |
| 2538 | + } |
| 2539 | + |
2325 | 2540 | FailureOr<AttentionMatcherValues> match(tosa::MatMulOp op) const { |
2326 | 2541 | Value softmaxOutput = op.getA(); |
2327 | 2542 | DenseSet<StringRef> expandAndCollapse{ |
@@ -2517,13 +2732,16 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> { |
2517 | 2732 | attentionMatcherValues.preSoftmaxElementwiseFinder; |
2518 | 2733 | int64_t firstGemmBlockIndex = elemwiseRegion.getFirstGemmBlockIndex(); |
2519 | 2734 |
|
2520 | | - // TODO: numHeadsQ and numHeadsKV migraphx integration |
| 2735 | + IntegerAttr numHeadsQ, numHeadsKV; |
| 2736 | + Value queries, keys, values; |
| 2737 | + std::tie(queries, keys, values, numHeadsQ, numHeadsKV) = getGQAValues( |
| 2738 | + rewriter, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB()); |
| 2739 | + |
2521 | 2740 | rock::AttentionOp attnOp = rock::AttentionOp::create( |
2522 | | - rewriter, loc, outputType, lseType, firstMatMulOp.getA(), |
2523 | | - firstMatMulOp.getB(), op.getB(), elementwiseOtherArgs, currentSeqLen, |
2524 | | - output, lseOut, |
2525 | | - /*numHeadsQ=*/rewriter.getI32IntegerAttr(1), |
2526 | | - /*numHeadsKV=*/rewriter.getI32IntegerAttr(1), |
| 2741 | + rewriter, loc, outputType, lseType, queries, keys, values, |
| 2742 | + elementwiseOtherArgs, currentSeqLen, output, lseOut, |
| 2743 | + /*numHeadsQ=*/numHeadsQ, |
| 2744 | + /*numHeadsKV=*/numHeadsKV, |
2527 | 2745 | /*qTransposed=*/nullptr, |
2528 | 2746 | /*kTransposed=*/nullptr, |
2529 | 2747 | /*vTransposed=*/nullptr, |
|
0 commit comments