From e1460b1954c6267e4c5c3f318aa288a45c7431e9 Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Tue, 22 Oct 2024 10:09:10 +0000
Subject: [PATCH 1/3] Revert "Revert "[BACKEND] Small fixes for dot operand
 properties (#4895)""

This reverts commit 25a7cbad82beb0d81283263d3d3885174bd290b3.
---
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  57 ++++---
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 152 ++++++++++--------
 .../SharedToDotOperandMFMA.cpp                |   8 +-
 .../SharedToDotOperandWMMA.cpp                |   4 +-
 .../TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp   |   6 +-
 .../TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp   |   4 +-
 .../AccelerateAMDMatmul.cpp                   |  20 ++-
 .../SharedToDotOperandMMAv2.cpp               |   4 +-
 .../DotOpToLLVM/MMAv2.cpp                     |   6 +-
 9 files changed, 147 insertions(+), 114 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index 67da5b30fe..45777d40dd 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -781,22 +781,24 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
     InterfaceMethod<"Return shape per CTA.",
                     "SmallVector<unsigned>",
-                    "getShapePerCTATileForDotOperands",
+                    "getShapePerCTATileForOperand",
                     (ins "ArrayRef<int64_t>":$tensorShape,
-                         "unsigned":$opIdx)>,
+                         "int":$kWidth,
+                         "int":$opIdx)>,

     InterfaceMethod<"Return total element size per thread for dot operands.",
                     "unsigned",
-                    "getTotalElemsPerThreadForOperands",
+                    "getTotalElemsPerThreadForOperand",
                     (ins "ArrayRef<int64_t>":$tensorShape,
                          "Type":$eltTy,
-                         "unsigned":$kWidth,
-                         "unsigned":$opIdx)>,
+                         "int":$kWidth,
+                         "int":$opIdx)>,

     InterfaceMethod<"Return size per thread for dot operands.",
                     "SmallVector<unsigned>",
-                    "getSizePerThreadForOperands",
-                    (ins "unsigned":$opIdx)>,
+                    "getSizePerThreadForOperand",
+                    (ins "int":$opIdx,
+                         "int":$kWidth)>,

     InterfaceMethod<"Return element sizes per thread for dot operands.",
                     "SmallVector<unsigned>",
                     "getElemsPerThreadForOperands",
                     (ins "ArrayRef<int64_t>":$tensorShape,
@@ -919,11 +921,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
     bool supportReduction() const {
       return true;
     }
-    SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
-    unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
-    SmallVector<int64_t> getMFMAInstrShapeForOperands(int kWidth, int opIdx) const;
-    SmallVector<int64_t> getMFMARepForOperands(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
+    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
+    unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
+    SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
+    SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;

     SmallVector<unsigned> getContigPerThread() {
       auto rank = getWarpsPerCTA().size();
@@ -1030,12 +1032,12 @@ Row |  warp 0  warp 2
     bool supportReduction() const {
       return true;
     }
-    SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
-    unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
+    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
+    unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
     SmallVector<int64_t> getElemsPerInstrForOperands() const;
-    SmallVector<int64_t> getRepForOperands(ArrayRef<int64_t> operandShape,
-                                           Type elemType, int kWidth, int opIdx) const;
+    SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
+                                          Type elemType, int kWidth, int opIdx) const;
     static SmallVector<unsigned> getMNKDimPerInstr();

     SmallVector<unsigned> getContigPerThread() {
@@ -1235,8 +1237,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     SmallVector<int> getMMAv1Rep(int opIdx) const;
     SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
     int getMMAv1Vec(int opIdx) const;
-    SmallVector<int64_t> getMMAv2Rep(ArrayRef<int64_t> shape,
-                                     int bitwidth, int opIdx) const;
+    SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
+                                               int bitwidth, int kWidth, int opIdx) const;

     bool supportReduction() const {
       if (isAmpere() || isHopper()) {
@@ -1244,9 +1246,9 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
      }
      return false;
    };
-    SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
-    unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
+    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
+    unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;

     SmallVector<unsigned> getContigPerThread() {
       assert(isVolta() || isAmpere() || isHopper());
@@ -1361,7 +1363,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
   let genVerifyDecl = 1;
   let extraClassDeclaration = extraDistributedDeclaration # [{
     SmallVector<unsigned> getContigPerThread() {
-      return getSizePerThread();
+      auto rank = getWarpsPerCTA().size();
+      assert(rank == 2 || rank == 3);
+      SmallVector<unsigned> contigPerThread(rank, 1);
+      auto kWidth = getKWidth();
+      assert(kWidth != 0 && "Do not support kWidth=0");
+      if (getOpIdx() == 0)
+        contigPerThread[rank - 1] = kWidth;
+      else
+        contigPerThread[rank - 2] = kWidth;
+      return contigPerThread;
     };
   }];
 }
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 3574329c28..e5feec2b2c 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -252,11 +252,20 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.erase(it);
       order.insert(order.begin(), 0);
     }
+  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
+    // opIdx=0: [/*dim0*/batch, /*dim1=*/m, /*dim2=*/k] -> order=[1, 2, 0]
+    // opIdx=1: [/*dim0*/batch, /*dim1=*/k, /*dim2=*/n] -> order=[2, 1, 0]
+    std::iota(order.rbegin(), order.rend(), 0);
+    if (dotOpLayout.getOpIdx() == 0) {
+      std::swap(order[0], order[1]);
+    }
   }
   return order;
 }

 SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
+  assert((rank == 2 || rank == 3) &&
+         "Invalid rank for dot operand order computation");
   SmallVector<unsigned> order(rank);
   // The 'order' field typically represents a descending sorted array of
   // dimensions based on contiguity. For instance, in axisInfo utilities that
   // retrieve tensor contiguity, it's assumed that the dimension with the
   // highest contiguity corresponds to order[0].
   //
   // The relation between contiguity and order is only relevant if the layout
   // interfaces with HBM, as is the case when we load tensor from HBM to
-  // registers in the dot layout to bypass LDS. When bypassing LDS, we make the
-  // following assumptions about tensor layouts:
+  // registers in the dot layout to bypass LDS. When bypassing LDS, we make
+  // the following assumptions about tensor layouts:
   // - Tensor A (opIdx == 0) is considered to be row-major.
   // - Tensor B (opIdx == 1) is considered to be column-major.
   //
   // Based on these assumptions, we define the following orders:
-  // - For opIdx == 0, we assume an order of [1, 0].
-  // - For opIdx == 1, we assume an order of [0, 1].
+  // - For opIdx == 0, batch=dim0, m=dim1, and k=dim2, we assume an order of [2,
+  // 1, 0] for 3D tensors.
+  // - For opIdx == 1, batch=dim0, k=dim1, and n=dim2, we assume an order of [1,
+  // 2, 0] for 3D tensors.
   std::iota(order.rbegin(), order.rend(), 0);
   if (opIdx == 1) {
     std::swap(order[0], order[1]);
@@ -293,13 +304,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
-    SmallVector<unsigned> order(rank);
-    if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
-    } else {
-      std::iota(order.rbegin(), order.rend(), 0);
-    }
-    return order;
+    return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
     SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
@@ -344,8 +349,6 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
   ArrayRef<unsigned> ref;
   if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
     return distributedLayout.getCTAsPerCGA();
-  else if (mlir::isa<DotOperandEncodingAttr>(layout))
-    return {1, 1};
   else if (auto sharedLayout = mlir::dyn_cast<SharedEncodingAttr>(layout))
     ref = sharedLayout.getCTALayout().getCTAsPerCGA();
   else
@@ -358,9 +361,6 @@ SmallVector<unsigned> getCTASplitNum(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
     return distributedLayout.getCTASplitNum();
-  } else if (mlir::isa<DotOperandEncodingAttr>(layout)) {
-    res.resize(2);
-    res[0] = res[1] = 1;
   } else if (auto sharedLayout = mlir::dyn_cast<SharedEncodingAttr>(layout)) {
     res.assign(sharedLayout.getCTALayout().getCTASplitNum().begin(),
                sharedLayout.getCTALayout().getCTASplitNum().end());
@@ -375,8 +375,6 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
     res = distributedLayout.getCTAOrder();
-  } else if (mlir::isa<DotOperandEncodingAttr>(layout)) {
-    return {0, 1};
   } else if (auto sharedLayout = mlir::dyn_cast<SharedEncodingAttr>(layout)) {
     res = SmallVector<unsigned>(sharedLayout.getCTALayout().getCTAOrder());
   } else {
@@ -400,9 +398,9 @@ SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
 SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape) {
   if (auto sharedLayout = mlir::dyn_cast<SharedEncodingAttr>(layout)) {
     // Special logic for pipeline pass, where shape is 3D and CTALayout is 2D.
-    // The first dim of shape is numStages. This is a work around, otherwise too
-    // many places would have to be modified in pipeline pass. Maybe we need to
-    // refactor this logic in the future.
+    // The first dim of shape is numStages. This is a work around, otherwise
+    // too many places would have to be modified in pipeline pass. Maybe we
+    // need to refactor this logic in the future.
     auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
     if (shape.size() == CTASplitNum.size() + 1) {
       auto res = getShapePerCTA(CTASplitNum, shape.drop_front());
@@ -425,7 +423,8 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
   else if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout))
     return getNumWarpsPerCTA(sliceLayout.getParent());
   else if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
-    // Use the distributed layout interface to get the number of warps per CTA.
+    // Use the distributed layout interface to get the number of warps per
+    // CTA.
     auto distributedLayout = cast<DistributedEncodingTrait>(layout);
     warpsPerCTA = distributedLayout.getWarpsPerCTA();
   } else if (auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout))
@@ -459,9 +458,9 @@ bool hasDotOperandEncoding(Value value) {
 }

 bool isExpensiveCat(CatOp cat, Attribute targetEncoding) {
-  // If the new elements per thread is less than the old one, we will need to do
-  // convert encoding that goes through shared memory anyway. So we consider it
-  // as expensive.
+  // If the new elements per thread is less than the old one, we will need to
+  // do convert encoding that goes through shared memory anyway. So we
+  // consider it as expensive.
   RankedTensorType tensorTy = cat.getType();
   auto totalElemsPerThread = gpu::getTotalElemsPerThread(tensorTy);
   auto shape = tensorTy.getShape();
@@ -982,7 +981,7 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
     SmallVector<unsigned> elemsPerThread(rank);
     auto kWidth = getKWidth();

-    auto rep = parent.getMFMARepForOperands(shape, kWidth, idx);
+    auto rep = parent.getRepForOperand(shape, kWidth, idx);
     if (rank == 3)
       elemsPerThread[0] = rep[0];
@@ -1002,8 +1001,8 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
                                                         Type eltTy) const {
   if (auto mmaParent = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
-    return mmaParent.getTotalElemsPerThreadForOperands(shape, eltTy,
-                                                       getKWidth(), getOpIdx());
+    return mmaParent.getTotalElemsPerThreadForOperand(shape, eltTy, getKWidth(),
+                                                      getOpIdx());
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
@@ -1067,8 +1066,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
   auto parentLayout = getParent();
   assert(parentLayout && "DotOperandEncodingAttr must have a parent");
   if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
-    return parentMmaLayout.getShapePerCTATileForDotOperands(tensorShape,
-                                                            getOpIdx());
+    return parentMmaLayout.getShapePerCTATileForOperand(
+        tensorShape, getKWidth(), getOpIdx());
   } else {
     llvm::report_fatal_error(
         "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
@@ -1668,7 +1667,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getSizePerThread() const {
 }

 SmallVector<int64_t>
-AMDMfmaEncodingAttr::getMFMAInstrShapeForOperands(int kWidth, int opIdx) const {
+AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
   unsigned mDim = getMDim();
   unsigned nDim = getNDim();
   assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) ||
@@ -1688,9 +1687,9 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
 }

 SmallVector<int64_t>
-AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef<int64_t> operandShape,
-                                           int kWidth, int opIdx) const {
-  auto operandTileShape = getMFMAInstrShapeForOperands(kWidth, opIdx);
+AMDMfmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
+                                      int kWidth, int opIdx) const {
+  auto operandTileShape = getInstrShapeForOperand(kWidth, opIdx);
   auto rank = operandShape.size();
   auto warpsPerCTA = getWarpsPerCTA();
   int numRepBatch =
@@ -1711,27 +1710,31 @@ AMDMfmaEncodingAttr::getMFMARepForOperands(ArrayRef<int64_t> operandShape,
   }
 }

-unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperands(
+unsigned AMDMfmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
-  auto rep = getMFMARepForOperands(shape, kWidth, opIdx);
+  auto rep = getRepForOperand(shape, kWidth, opIdx);
   return product(rep) * kWidth;
 }

 SmallVector<unsigned>
-AMDMfmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
+AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
+  auto rank = getWarpsPerCTA().size();
+  auto sizePerThread = SmallVector<unsigned>(rank, 1);
   if (opIdx == 0) {
-    return {4, 1};
+    sizePerThread[rank - 2] = 1;
+    sizePerThread[rank - 1] = kWidth;
   } else if (opIdx == 1) {
-    return {1, 4};
+    sizePerThread[rank - 2] = kWidth;
+    sizePerThread[rank - 1] = 1;
   } else {
     llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
+  return sizePerThread;
 }

 SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
-                                                      int opIdx) const {
+AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
+                                                  int kWidth, int opIdx) const {
   assert(getMDim() == 32 || getMDim() == 16);
   auto parentShapePerCTATile = getShapePerCTATile(shape);
   auto rank = parentShapePerCTATile.size();
@@ -1801,7 +1804,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getSizePerThread() const {
   return sizePerThread;
 }
 SmallVector<unsigned>
-AMDWmmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
+AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   auto rank = getWarpsPerCTA().size();
   SmallVector<unsigned> sizePerThread(rank, 1);
   auto numReplicated = getVersion() == 1 ? 2 : 1;
@@ -1820,8 +1823,8 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
 }

 SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
-                                                      int opIdx) const {
+AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
+                                                  int kWidth, int opIdx) const {
   auto parentShapePerCTA = getShapePerCTATile(shape);
   auto rank = shape.size();
   assert(rank == 2);
@@ -1834,9 +1837,9 @@ AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
   }
 }

-unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperands(
+unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
-  auto rep = getRepForOperands(shape, eltTy, kWidth, opIdx);
+  auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
   return product(rep) * kWidth;
 }

 SmallVector<int64_t> AMDWmmaEncodingAttr::getElemsPerInstrForOperands() const {
@@ -1845,9 +1848,9 @@ SmallVector<int64_t> AMDWmmaEncodingAttr::getElemsPerInstrForOperands() const {
 }

 SmallVector<int64_t>
-AMDWmmaEncodingAttr::getRepForOperands(ArrayRef<int64_t> operandShape,
-                                       Type elemType, int kWidth,
-                                       int opIdx) const {
+AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
+                                      Type elemType, int kWidth,
+                                      int opIdx) const {
   auto operandTileShape = getElemsPerInstrForOperands();
   assert(operandTileShape.size() == 2);
   auto warpsPerCTA = getWarpsPerCTA();
@@ -2038,9 +2041,8 @@ SmallVector<int> NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const {
 int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const {
   return 2 * getMMAv1Rep(opIdx)[opIdx];
 }
-SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
-                                                        int bitwidth,
-                                                        int opIdx) const {
+SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
+    ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
   SmallVector<int64_t> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
@@ -2063,7 +2065,7 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
                                   warpsPerCTA[rank - 1]))};
   }
 }
-unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands(
+unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto shapePerCTA = getShapePerCTA(*this, shape);
   int warpsPerCTAM = getWarpsPerCTA()[0];
@@ -2074,7 +2076,8 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands(
   }
   // A100
   if (isAmpere()) {
-    auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx);
+    auto rep = getMMAv2RepForOperand(shapePerCTA, eltTy.getIntOrFloatBitWidth(),
+                                     kWidth, opIdx);
     if (opIdx == 0)
       return 4 * rep[0] * rep[1] * rep[2];
     if (opIdx == 1)
@@ -2142,43 +2145,58 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands(
   }
   llvm_unreachable("unknown mma layout");
 }
-SmallVector<unsigned>
-NvidiaMmaEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
-                                                        int opIdx) const {
+SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
+    ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
   assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
   auto parentShapePerCTATile = getShapePerCTATile(shape);
   auto rank = parentShapePerCTATile.size();
+  // 4 threads * 2 subtiles
+  unsigned kWidthTile = kWidth * 2 * 4;
   if (opIdx == 0) {
     if (rank == 2)
-      return {parentShapePerCTATile[rank - 2], 16};
+      return {parentShapePerCTATile[rank - 2], kWidthTile};
     else
-      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 16};
+      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
+              kWidthTile};
   } else if (opIdx == 1) {
     if (rank == 2)
-      return {16, parentShapePerCTATile[rank - 1]};
+      return {kWidthTile, parentShapePerCTATile[rank - 1]};
     else
-      return {parentShapePerCTATile[0], 16, parentShapePerCTATile[rank - 1]};
+      return {parentShapePerCTATile[0], kWidthTile,
+              parentShapePerCTATile[rank - 1]};
   } else {
     llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
 }
 SmallVector<unsigned>
-NvidiaMmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
+NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
+  auto rank = getWarpsPerCTA().size();
+  auto sizePerThread = SmallVector<unsigned>(rank, 1);
   if (opIdx == 0) {
-    return {2, 4};
+    sizePerThread[rank - 2] = 2;
+    sizePerThread[rank - 1] = 2 * kWidth;
   } else if (opIdx == 1) {
-    return {4, 1};
+    sizePerThread[rank - 2] = 2 * kWidth;
+    sizePerThread[rank - 1] = 1;
   } else {
     llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-    return {};
   }
+  return sizePerThread;
 }

 //===----------------------------------------------------------------------===//
 // DotOperand Encoding
 //===----------------------------------------------------------------------===//
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadsPerWarp() const {
+  auto parent = getParent();
+  if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
+    auto threadsPerWarp = mma.getThreadsPerWarp();
+    auto rank = threadsPerWarp.size();
+    if (getOpIdx() == 1)
+      std::swap(threadsPerWarp[rank - 2], threadsPerWarp[rank - 1]);
+    return threadsPerWarp;
+  }
   llvm::report_fatal_error(
       "getThreadsPerWarp not implemented for DotOperandEncodingAttr");
 }
@@ -2186,7 +2204,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getSizePerThread() const {
   auto parentLayout = getParent();
   assert(parentLayout && "DotOperandEncodingAttr must have a parent");
   if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
-    return parentMmaLayout.getSizePerThreadForOperands(getOpIdx());
+    return parentMmaLayout.getSizePerThreadForOperand(getKWidth(), getOpIdx());
   } else {
     llvm::report_fatal_error(
         "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
index 7c05ffede2..c8df2ac993 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
@@ -217,12 +217,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   auto elemTy = aTensorTy.getElementType();
   auto kWidth = encoding.getKWidth();

-  auto elemsPerInstr = mfmaLayout.getMFMAInstrShapeForOperands(kWidth, opIdx);
+  auto elemsPerInstr = mfmaLayout.getInstrShapeForOperand(kWidth, opIdx);
   int64_t mfmaInstrNonK;
   int64_t mfmaInstrK;
   // TODO(Lixun): make it simpler
-  // getMFMAInstrShapeForOperands always returns a 2D vector
+  // getInstrShapeForOperand always returns a 2D vector
   if (rank == 3) {
     mfmaInstrNonK = elemsPerInstr[nonKDimIdx - 1];
     mfmaInstrK = elemsPerInstr[kDimIdx - 1];
@@ -231,12 +231,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
     mfmaInstrK = elemsPerInstr[kDimIdx];
   }

-  auto numReps = mfmaLayout.getMFMARepForOperands(shape, kWidth, opIdx);
+  auto numReps = mfmaLayout.getRepForOperand(shape, kWidth, opIdx);
   auto numRepNonK = numReps[nonKDimIdx];
   auto numRepK = numReps[kDimIdx];
   auto repB = numReps[0];
   // TODO(Lixun): make it simpler
-  // getMFMARepForOperands always returns a 3D vector
+  // getRepForOperand always returns a 3D vector
   if (rank == 2) {
     numRepNonK = numReps[nonKDimIdx + 1];
     numRepK = numReps[kDimIdx + 1];
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp
index 6cdcddad39..b60c86e1a3 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandWMMA.cpp
@@ -165,7 +165,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   auto wmmaInstrNonK = elemsPerInstr[opIdx == 0 ? 0 : 1];
   assert(wmmaInstrNonK == 16);

-  auto numReps = wmmaLayout.getRepForOperands(shape, elemTy, kWidth, opIdx);
+  auto numReps = wmmaLayout.getRepForOperand(shape, elemTy, kWidth, opIdx);
   auto numRepNonK = numReps[opIdx == 0 ? 1 : 2];
   auto numRepK = numReps[opIdx == 0 ? 2 : 1];
   auto repB = numReps[0];
@@ -176,7 +176,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   Value linearWaveId = udiv(thread, waveSize);

   unsigned numElemsPerThreadPerRep =
-      wmmaLayout.getSizePerThreadForOperands(opIdx)[kDimIdx];
+      wmmaLayout.getSizePerThreadForOperand(kWidth, opIdx)[kDimIdx];
   Value lane = urem(thread, waveSize);
   unsigned int maxNumWarps = shape[nonKDimIdx] / wmmaInstrNonK;
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
index c190711f10..204d54894d 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -194,10 +194,8 @@ struct DotOpMFMAConversionHelper {
     int kWidth = aEncoding.getKWidth();
     auto rank = aTensorTy.getShape().size();

-    auto repA =
-        mfmaLayout.getMFMARepForOperands(aTensorTy.getShape(), kWidth, 0);
-    auto repB =
-        mfmaLayout.getMFMARepForOperands(bTensorTy.getShape(), kWidth, 1);
+    auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), kWidth, 0);
+    auto repB = mfmaLayout.getRepForOperand(bTensorTy.getShape(), kWidth, 1);

     assert(repA[2] == repB[1]);
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
index 9f575be082..5a003f7688 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
@@ -263,9 +263,9 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
   int kWidth = aEncoding.getKWidth();

   auto repA =
-      wmmaLayout.getRepForOperands(aTensorTy.getShape(), elemTy, kWidth, 0);
+      wmmaLayout.getRepForOperand(aTensorTy.getShape(), elemTy, kWidth, 0);
   auto repB =
-      wmmaLayout.getRepForOperands(bTensorTy.getShape(), elemTy, kWidth, 1);
+      wmmaLayout.getRepForOperand(bTensorTy.getShape(), elemTy, kWidth, 1);

   assert(repA[2] == repB[1]);
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index 895ddf67e2..6f93bfee99 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -637,14 +637,18 @@ class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
     auto newAcc =
         convertAndCastTensor(rewriter, oldAcc, wmmaEnc, operandTypes[2]);

-    auto newAType = RankedTensorType::get(
-        aShape, operandTypes[0],
-        ttg::DotOperandEncodingAttr::get(
-            ctx, 0, wmmaEnc, wmmaEnc.getSizePerThreadForOperands(0)[rank - 1]));
-    auto newBType = RankedTensorType::get(
-        bShape, operandTypes[1],
-        ttg::DotOperandEncodingAttr::get(
-            ctx, 1, wmmaEnc, wmmaEnc.getSizePerThreadForOperands(1)[rank - 2]));
+    auto newAType =
+        RankedTensorType::get(aShape, operandTypes[0],
+                              ttg::DotOperandEncodingAttr::get(
+                                  ctx, 0, wmmaEnc,
+                                  wmmaEnc.getSizePerThreadForOperand(
+                                      /*kWidth=*/0, /*opIdx=*/0)[rank - 1]));
+    auto newBType =
+        RankedTensorType::get(bShape, operandTypes[1],
+                              ttg::DotOperandEncodingAttr::get(
+                                  ctx, 1, wmmaEnc,
+                                  wmmaEnc.getSizePerThreadForOperand(
+                                      /*kWidth=*/0, /*opIdx=*/1)[rank - 2]));

     Value castedA =
         convertAndCastTensor(rewriter, a, newAType.getEncoding(), operandTypes[0]);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
index 1abb0c5216..21c2bee584 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp
@@ -603,9 +603,9 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
   int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
   int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;

-  auto numRep =
-      mmaLayout.getMMAv2Rep(shapePerCTA, bitwidth, encoding.getOpIdx());
   int kWidth = encoding.getKWidth();
+  auto numRep = mmaLayout.getMMAv2RepForOperand(shapePerCTA, bitwidth, kWidth,
+                                                encoding.getOpIdx());

   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
   auto order = triton::gpu::getOrder(mmaLayout);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index 258b8fc261..05c2e7fd10 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -392,10 +392,12 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
   int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
   auto dotOpA = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
   auto repA = cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
-                  .getMMAv2Rep(aShapePerCTA, bitwidth, dotOpA.getOpIdx());
+                  .getMMAv2RepForOperand(aShapePerCTA, bitwidth,
+                                         dotOpA.getKWidth(), dotOpA.getOpIdx());
   auto dotOpB = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
   auto repB = cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
-                  .getMMAv2Rep(bShapePerCTA, bitwidth, dotOpB.getOpIdx());
+                  .getMMAv2RepForOperand(bShapePerCTA, bitwidth,
+                                         dotOpB.getKWidth(), dotOpB.getOpIdx());

   assert(repA[2] == repB[1]);
   assert(repA[0] == repB[0]);

From 4b7942a96c73b83b04225286a34a6aa95ca62011 Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Tue, 22 Oct 2024 10:15:19 +0000
Subject: [PATCH 2/3] [intel] Small fixes for dot operand properties

Signed-off-by: Whitney Tsang
---
 .../TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td     |  6 +++---
 .../intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 10 +++++-----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index c4eeb496ae..b6d1fdd109 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -82,10 +82,10 @@ along the row (resp. col) dimension.
     SmallVector<int64_t> getShapeB() const;
     SmallVector<int64_t> getShapeC() const;
     SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
-    SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
+    SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, unsigned opIdx) const;
     SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
-    unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
+    unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;

     bool supportReduction() const {
       return true;
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index 28e374932b..af81751dda 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -247,7 +247,7 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
                                warpsPerCTA[rank - 1]))};
 }

-unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
+unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
   auto shapePerCTA = getShapePerCTA(*this, shape);
   auto rep = getDPASRepetitions(shapePerCTA, opIdx);
@@ -298,8 +298,8 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
 }

 SmallVector<unsigned>
-DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
-                                                   int opIdx) const {
+DpasEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
+                                               int kWidth, int opIdx) const {
   auto parentShapePerCTATile = getShapePerCTATile(shape);
   size_t rank = parentShapePerCTATile.size();
   assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
@@ -325,7 +325,7 @@ DpasEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
 }

 SmallVector<unsigned>
-DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
+DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
   ArrayRef<unsigned> repCluster = getRepCluster();
   size_t rank = repCluster.size();
   assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
@@ -367,7 +367,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {

 SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
     ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
-  SmallVector<unsigned> sizePerThread = getSizePerThreadForOperands(opIdx);
+  SmallVector<unsigned> sizePerThread = getSizePerThreadForOperand(0, opIdx);
   SmallVector<int64_t> repetitions = getDPASRepetitions(shape, opIdx);

   size_t rank = shape.size();

From c637c07c975d557b2213742ce5d5a7f4c29ea4c5 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Mon, 21 Oct 2024 21:45:51 +0000
Subject: [PATCH 3/3] fix order for DPAS

Signed-off-by: Anatoly Myachev
---
 lib/Dialect/TritonGPU/IR/Dialect.cpp | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index e5feec2b2c..3a1fd05ab8 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -304,6 +304,16 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
+    // FIXME: delete if branch for `DpasEncodingAttr` and provide more
+    // general solution to make `getOrderForDotOperand` function compatible
+    // with Intel layouts.
+    // More details:
+    // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
+    if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
+      SmallVector<unsigned> order(rank);
+      std::iota(order.rbegin(), order.rend(), 0);
+      return order;
+    }
     return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
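
Note on the recurring API change in this series (editorial addition, not part of the patches): the per-operand layout queries now thread kWidth — the number of contiguous elements each thread holds along the K dimension — through to the layout attributes instead of hard-coding per-format constants. The sketch below is a minimal standalone re-implementation of the Ampere getSizePerThreadForOperand logic from patch 1/3, using std::vector instead of llvm::SmallVector so it compiles without MLIR; the function name and assertions are ours, not Triton's.

#include <cassert>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Mirrors NvidiaMmaEncodingAttr::getSizePerThreadForOperand: operand A
// (opIdx == 0) owns a 2 x (2*kWidth) per-thread tile, operand B (opIdx == 1)
// owns a (2*kWidth) x 1 tile; a leading batch dimension (rank == 3) stays 1.
std::vector<unsigned> sizePerThreadForOperand(unsigned rank, int kWidth,
                                              int opIdx) {
  assert((rank == 2 || rank == 3) && "dot operands are 2D or batched 3D");
  assert(kWidth > 0 && "kWidth == 0 is not supported");
  std::vector<unsigned> sizePerThread(rank, 1);
  if (opIdx == 0) {
    sizePerThread[rank - 2] = 2;
    sizePerThread[rank - 1] = 2 * kWidth;
  } else if (opIdx == 1) {
    sizePerThread[rank - 2] = 2 * kWidth;
    sizePerThread[rank - 1] = 1;
  } else {
    throw std::invalid_argument("opIdx must be 0 or 1");
  }
  return sizePerThread;
}

int main() {
  // With kWidth == 2 this reproduces the constants the patch deletes:
  // {2, 4} for operand A and {4, 1} for operand B.
  for (int opIdx : {0, 1}) {
    auto s = sizePerThreadForOperand(/*rank=*/2, /*kWidth=*/2, opIdx);
    std::printf("opIdx=%d -> [%u, %u]\n", opIdx, s[0], s[1]);
  }
  return 0;
}

The same pattern appears in the MFMA variant (1 x kWidth for A, kWidth x 1 for B) and in getShapePerCTATileForOperand, where the old hard-coded K-tile extent of 16 becomes kWidth * 2 * 4, i.e. 4 threads times 2 subtiles along K.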