|
10 | 10 | #include "triton/Dialect/Triton/IR/Dialect.h" |
11 | 11 | #include "triton/Dialect/TritonGPU/IR/Dialect.h" |
12 | 12 | #include "triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h" |
| 13 | +#include "triton/Dialect/TritonGPU/Transforms/LayoutPropagationUtility.h" |
13 | 14 | #include "triton/Tools/LayoutUtils.h" |
| 15 | +#include "triton/Tools/LinearLayout.h" |
14 | 16 | #include "llvm/ADT/TypeSwitch.h" |
15 | 17 |
|
16 | 18 | namespace tt = mlir::triton; |
17 | 19 | namespace ttg = mlir::triton::gpu; |
18 | 20 | using ::mlir::LLVM::AMD::isChainDotHead; |
19 | 21 | using ::mlir::LLVM::AMD::isChainDotTail; |
20 | | -using ::mlir::LLVM::AMD::scaleDotElemTypeToMLIRType; |
21 | | -using mlir::triton::gpu::chooseScaledMfmaScaleLayout; |
| 22 | + |
| 23 | +#undef DEBUG_TYPE |
| 24 | +#define DEBUG_TYPE "tritonamd-accelerate-matmul" |
22 | 25 |
|
23 | 26 | namespace mlir { |
24 | 27 |
|
@@ -217,6 +220,8 @@ FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, |
217 | 220 |
|
218 | 221 | FailureOr<MfmaIntrinsic> chooseMfmaInstruction(tt::DotScaledOp dot, |
219 | 222 | int mfmaVersion, int nonKDim) { |
| 223 | + using ::mlir::LLVM::AMD::scaleDotElemTypeToMLIRType; |
| 224 | + |
220 | 225 | auto ctx = dot.getContext(); |
221 | 226 | int64_t inputKDim = dot.getA().getType().getShape().back(); |
222 | 227 | if (dot.getAElemType() == ScaleDotElemType::E2M1 && dot.getLhsKPack()) { |
@@ -779,55 +784,72 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> { |
779 | 784 | } |
780 | 785 | }; |
781 | 786 |
|
782 | | -template <typename Op> Op getDefOpBeforeConvertLayout(Value op) { |
783 | | - while (auto cvtOp = op.getDefiningOp<ttg::ConvertLayoutOp>()) { |
784 | | - op = cvtOp.getSrc(); |
785 | | - } |
786 | | - return op.getDefiningOp<Op>(); |
787 | | -} |
788 | | - |
789 | | -bool isScaleShuffled(Value scale) { |
| 787 | +// Figure out a best tilesPerWarp parameter that gives largest vector size for |
| 788 | +// global load for the given |scale| tensor feeding into dot_scaled op. Returns |
| 789 | +// the largest vector size and writes the choice to |result|. |
| 790 | +int deduceTilesPerWarp(TypedValue<RankedTensorType> scale, unsigned opIdx, |
| 791 | + unsigned nonKDim, ArrayRef<unsigned> warpsPerCTA, |
| 792 | + SmallVectorImpl<unsigned> *result) { |
| 793 | + std::array<unsigned, 2> chosen{1, 1}; |
| 794 | + int vecSize = 1; |
790 | 795 | if (!scale) { |
791 | | - return false; |
| 796 | + result->assign(chosen.begin(), chosen.end()); |
| 797 | + return vecSize; |
792 | 798 | } |
793 | 799 |
|
794 | | - auto shape = cast<RankedTensorType>(scale.getType()).getShape(); |
795 | | - |
796 | | - int rank = shape.size(); |
797 | | - int blockNonK = shape[rank - 2]; |
798 | | - // 1 scale always scales 32 elements along K dim |
799 | | - int blockK = shape[rank - 1] * 32; |
800 | | - |
801 | | - auto reshapeOp2D = getDefOpBeforeConvertLayout<triton::ReshapeOp>(scale); |
802 | | - if (!reshapeOp2D || reshapeOp2D.getType().getShape() != shape) { |
803 | | - return false; |
804 | | - } |
805 | | - |
806 | | - const std::array<int, 7> transposeOrder{0, 5, 3, 1, 4, 2, 6}; |
807 | | - auto transOp = |
808 | | - getDefOpBeforeConvertLayout<triton::TransOp>(reshapeOp2D.getSrc()); |
809 | | - if (!transOp || transOp.getOrder() != ArrayRef<int>(transposeOrder)) { |
810 | | - return false; |
811 | | - } |
812 | | - |
813 | | - const std::array<int64_t, 7> reshape7DShape{ |
814 | | - blockNonK / 32, blockK / 32 / 8, 4, 16, 2, 2, 1}; |
815 | | - auto reshapeOp7D = |
816 | | - getDefOpBeforeConvertLayout<triton::ReshapeOp>(transOp.getSrc()); |
817 | | - |
818 | | - if (!reshapeOp7D || |
819 | | - reshapeOp7D.getType().getShape() != ArrayRef<int64_t>(reshape7DShape)) { |
820 | | - return false; |
821 | | - } |
822 | | - |
823 | | - return true; |
824 | | -} |
825 | | - |
826 | | -SmallVector<unsigned, 2> getTilesPerWarp(Value aScale, Value bScale) { |
827 | | - if (isScaleShuffled(aScale) || isScaleShuffled(bScale)) { |
828 | | - return {2, 2}; |
| 800 | + // Source code have flexibility to preshuffle scale tensor to achieve better |
| 801 | + // global load vectorization. That preshuffle scheme is conveyed via some |
| 802 | + // tl.reshape and tl.trans op combinations. Instead of hardcoding one case or |
| 803 | + // pattern match the op chain here, we try certain scale tensor layouts and |
| 804 | + // see which one gives us better vectorization when pushed upwards to the |
| 805 | + // global load. |
| 806 | + // |
| 807 | + // For 16x16x128 scaled MFMA intrinsic, each thread only reads one i8 value. |
| 808 | + // For better vectorization, we prefer to stick 2x2 such intrinsic together so |
| 809 | + // each thread can read 4xi8 values. |
| 810 | + SmallVector<std::array<unsigned, 2>, 2> choices{{2, 2}, {1, 1}}; |
| 811 | + for (const auto &choice : choices) { |
| 812 | + LLVM_DEBUG(llvm::dbgs() |
| 813 | + << "choice: [" << choice[0] << ", " << choice[1] << "]\n"); |
| 814 | + LinearLayout layout = ttg::chooseScaledMfmaScaleLayout( |
| 815 | + scale.getContext(), opIdx, scale.getType().getShape(), nonKDim, choice, |
| 816 | + warpsPerCTA); |
| 817 | + LLVM_DEBUG(llvm::dbgs() << "trying scale layout: " << layout << "\n"); |
| 818 | + |
| 819 | + // Infer source layout used for global load using the current scale layout. |
| 820 | + auto loadLayoutPair = |
| 821 | + ttg::inferSourceLoadLayout(layout, scale.getDefiningOp()); |
| 822 | + if (!loadLayoutPair) |
| 823 | + continue; |
| 824 | + tt::LoadOp loadOp = loadLayoutPair->first; |
| 825 | + const LinearLayout &inferredLayout = loadLayoutPair->second; |
| 826 | + LLVM_DEBUG(llvm::dbgs() |
| 827 | + << "inferred load layout: " << inferredLayout << "\n"); |
| 828 | + |
| 829 | + auto loadType = cast<RankedTensorType>(loadOp.getType()); |
| 830 | + auto loadOrder = ttg::getOrder(loadType); |
| 831 | + auto loadCTALayout = ttg::getCTALayout(loadType.getEncoding()); |
| 832 | + |
| 833 | + // Reuse existing shared memory vectorization utilities by constructing a |
| 834 | + // pass through layout that does linear element mapping. |
| 835 | + MLIRContext *context = scale.getContext(); |
| 836 | + auto passThruShared = ttg::SwizzledSharedEncodingAttr::get( |
| 837 | + context, 1, 1, 1, loadOrder, loadCTALayout); |
| 838 | + auto sharedLL = |
| 839 | + triton::gpu::toLinearLayout(loadType.getShape(), passThruShared); |
| 840 | + auto composedLL = inferredLayout.invertAndCompose(sharedLL).flattenOuts(); |
| 841 | + auto [v, _] = |
| 842 | + largestVectorisation(context, composedLL, /*bitwidth=*/8, std::nullopt); |
| 843 | + |
| 844 | + if (v > vecSize) { |
| 845 | + LLVM_DEBUG(llvm::dbgs() << "found vector size: " << v << "\n"); |
| 846 | + chosen = choice; |
| 847 | + vecSize = v; |
| 848 | + break; |
| 849 | + } |
829 | 850 | } |
830 | | - return {1, 1}; |
| 851 | + result->assign(chosen.begin(), chosen.end()); |
| 852 | + return vecSize; |
831 | 853 | } |
832 | 854 |
|
833 | 855 | class DecomposeAMDScaledBlocked final : public ttg::DecomposeScaledBlocked { |
@@ -968,34 +990,18 @@ class ScaledBlockedToScaledMFMAF8F6F4 final |
968 | 990 | auto warpsPerTile = |
969 | 991 | warpsPerTileMFMA(dotOp, oldShape, numWarps, {mDim, nDim}); |
970 | 992 |
|
971 | | - // For scale tensor preshuffling, the minimum block size is 32x32x256. |
972 | | - // When using MFMA16 instructions, each warp should compute two MFMA ops |
973 | | - // along the non-K dimension. To support this, we must set tilesPerWarp to |
974 | | - // {2, 2}. Failing to do so won't break correctness, but it will prevent |
975 | | - // vectorized local_loads, as the data each thread needs won't be contiguous |
976 | | - // due to the shuffle pattern. This requirement doesn’t apply to MFMA32 |
977 | | - // instructions, since only one MFMA op spans the non-K dimension at the |
978 | | - // minimal shuffling size. |
979 | | - SmallVector<unsigned> tilesPerWarp = getTilesPerWarp(aScale, bScale); |
980 | | - |
981 | | - if (rank == 3) { |
982 | | - tilesPerWarp.insert(tilesPerWarp.begin(), 1); |
983 | | - } |
| 993 | + SmallVector<unsigned, 2> tilesA{1, 1}, tilesB{1, 1}, tilesPerWarp; |
| 994 | + int vecA = deduceTilesPerWarp(aScale, 0, mDim, warpsPerTile, &tilesA); |
| 995 | + int vecB = deduceTilesPerWarp(bScale, 1, mDim, warpsPerTile, &tilesB); |
| 996 | + tilesPerWarp = vecA > vecB ? tilesA : tilesB; |
| 997 | + LLVM_DEBUG(llvm::dbgs() << "chosen tilesPerWarp: [" << tilesPerWarp[0] |
| 998 | + << ", " << tilesPerWarp[1] << "]\n"); |
984 | 999 |
|
985 | 1000 | // Always use transposed mfma layout. This enables larger vectorization |
986 | 1001 | // for global store instructions. |
987 | | - mlir::Attribute mfmaEnc; |
988 | | - if (llvm::any_of(tilesPerWarp, [](int x) { return x != 1; })) { |
989 | | - mfmaEnc = ttg::AMDMfmaEncodingAttr::get( |
990 | | - ctx, /*verison=*/mfmaVersion, warpsPerTile, tilesPerWarp, |
991 | | - /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout, |
992 | | - oldRetType.getElementType()); |
993 | | - } else { |
994 | | - mfmaEnc = ttg::AMDMfmaEncodingAttr::get( |
995 | | - ctx, /*verison=*/mfmaVersion, warpsPerTile, |
996 | | - /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout, |
997 | | - oldRetType.getElementType()); |
998 | | - } |
| 1002 | + mlir::Attribute mfmaEnc = ttg::AMDMfmaEncodingAttr::get( |
| 1003 | + ctx, /*version=*/mfmaVersion, warpsPerTile, tilesPerWarp, mDim, nDim,
| 1004 | + /*isTransposed=*/true, ctaLayout, oldRetType.getElementType()); |
999 | 1005 |
|
1000 | 1006 | auto newRetType = |
1001 | 1007 | RankedTensorType::get(oldShape, oldRetType.getElementType(), mfmaEnc); |
@@ -1097,7 +1103,7 @@ class ScaledBlockedToScaledMFMAF8F6F4 final |
1097 | 1103 | shape = llvm::to_vector(scale.getType().getShape()); |
1098 | 1104 | } |
1099 | 1105 |
|
1100 | | - LinearLayout newLL = chooseScaledMfmaScaleLayout( |
| 1106 | + LinearLayout newLL = ttg::chooseScaledMfmaScaleLayout( |
1101 | 1107 | ctx, idx, shape, mDim, tilesPerWarp, warpsPerTile); |
1102 | 1108 |
|
1103 | 1109 | Attribute newScaleEncoding = ttg::LinearEncodingAttr::get(ctx, newLL); |
@@ -1515,7 +1521,6 @@ struct TritonAMDGPUAccelerateMatmulPass |
1515 | 1521 | using Base::Base; |
1516 | 1522 |
|
1517 | 1523 | void runOnOperation() override { |
1518 | | - |
1519 | 1524 | MLIRContext *context = &getContext(); |
1520 | 1525 | ModuleOp m = getOperation(); |
1521 | 1526 |
|
|
0 commit comments