@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <cstdint>
 #include <limits>
 #include <optional>
 
@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() { |
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check whether the scale inputs of a scaled_mfma are also used by other
+/// scaled_mfma ops. If none of those uses has been packed yet, pack the
+/// scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a non-zero opsel, packing has already been
+    // done.
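+    // Operand 3 holds the A scales and operand 4 the B scales.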
+    auto checkIfUnpackable = [&](OpOperand &use) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(use.getOwner())) {
+        switch (use.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+        default:
+          return true;
+        }
+      }
+      // Users other than scaled_mfma have no opsel to adjust, so treat them
+      // as unpackable.
+      return true;
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain the flat index from the extract's static offsets and the shape
+    // of its source (row-major).
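+    // e.g. the element at [1, 2] of a 3x4 source has flat index
+    // 1 * 4 + 2 = 6.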
+    auto getIdxFromExtract = [](vector::ExtractOp op) {
+      ShapedType ty = cast<ShapedType>(op.getOperand(0).getType());
+      int64_t cumul = 1;
+      int64_t idx = 0;
+      for (auto [offset, size] :
+           reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
+        idx += offset * cumul;
+        cumul *= size;
+      }
+      return idx;
+    };
+
+    // Obtain the offsets into the new shape from a flat index.
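+    // e.g. flat index 6 in shape 8x4 yields offsets [1, 2].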
+    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+      SmallVector<int64_t> res;
+      ShapedType shapedTy = cast<ShapedType>(ty);
+      int64_t numElements = shapedTy.getNumElements();
+      for (int64_t size : shapedTy.getShape()) {
+        // After this division, numElements is the stride of the current
+        // dimension.
+        numElements /= size;
+        res.push_back(idx / numElements);
+        idx %= numElements;
+      }
+      return res;
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale matches the
+    // following pattern:
+    //
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
+    //     vector<?x?x?xf8E8M0FNU>
+    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite it to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to
+    //     vector<?x4xf8E8M0FNU>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from
+    //     vector<?x4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates a duplicate shape_cast for every use, but those are
+    // removed by CSE.
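+    // For example, extracting the scale at [0, 6] from a
+    // vector<2x16xf8E8M0FNU> source gives flat index 6. Viewed as
+    // vector<8x4xf8E8M0FNU>, that scale sits in row 6 / 4 = 1 at lane
+    // 6 % 4 = 2, so row 1 becomes the new operand and the opsel becomes 2.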
+    // First check that both scale operands can be packed, without creating
+    // or mutating any IR, so that the op is left untouched whenever the
+    // pattern fails.
+    auto canPackScale = [&](int64_t opIdx) -> LogicalResult {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp)
+        return failure();
+      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable))
+        return failure();
+
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      if (!extractOp)
+        return failure();
+
+      auto stype = dyn_cast<ShapedType>(extractOp.getOperand(0).getType());
+      if (!stype)
+        return failure();
+      // We do not handle dynamic dims yet; assume that the input has been
+      // padded to a static shape.
+      if (!stype.hasStaticShape())
+        return failure();
+
+      // The source must divide evenly into more than one group of four
+      // scales, otherwise there is nothing to pack.
+      int64_t numElements = stype.getNumElements();
+      if (numElements <= 4 || numElements % 4 != 0)
+        return failure();
+      return success();
+    };
+    if (failed(canPackScale(3)) || failed(canPackScale(4)))
+      return failure();
+
+    // Now rewrite both scale operands.
+    for (int64_t opIdx : {3, 4}) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      Value scaleSrc = extractOp.getOperand(0);
+      auto stype = cast<ShapedType>(scaleSrc.getType());
+      int64_t numElements = stype.getNumElements();
+
+      // View the source as rows of four scales, extract the row holding
+      // this scale, and use its position within that row as the opsel.
+      Type newSrcType = VectorType::get({numElements / 4, 4},
+                                        stype.getElementType());
+      Value newScaleSrc =
+          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+      int64_t idx = getIdxFromExtract(extractOp);
+      SmallVector<int64_t> offsets = getOffsetsFromIdx(idx, newSrcType);
+      auto scaleTy = VectorType::get({4}, stype.getElementType());
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
+          SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
+      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
+      rewriter.modifyOpInPlace(op, [&] {
+        op->setOperand(opIdx, scale);
+        setOpsel(opIdx, offsets[1]);
+      });
+    }
+    return success();
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
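For reference, here is a hand-written sketch of what the canonicalization does on a concrete shape. It is not part of the patch; the SSA names, the `2x16` source shape, and the abbreviated `amdgpu.scaled_mfma` user are illustrative:

```mlir
// Before: the scale reaches scaled_mfma as lane 0 of a size-4 vector.
%unit = vector.extract %src[0, 6] : f8E8M0FNU from vector<2x16xf8E8M0FNU>
%scale = vector.insert %unit, %cst [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
// amdgpu.scaled_mfma(%scale[0] * ...

// After: flat index 6 in the vector<8x4xf8E8M0FNU> view is row 1, lane 2,
// so the whole row is passed along and the opsel selects lane 2.
%reshaped = vector.shape_cast %src : vector<2x16xf8E8M0FNU> to vector<8x4xf8E8M0FNU>
%slice = vector.extract_strided_slice %reshaped
    {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]}
    : vector<8x4xf8E8M0FNU> to vector<1x4xf8E8M0FNU>
%packed = vector.shape_cast %slice : vector<1x4xf8E8M0FNU> to vector<4xf8E8M0FNU>
// amdgpu.scaled_mfma(%packed[2] * ...
```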