From d36e57998e7aaf7495c5d916b51948b2a42ae5f3 Mon Sep 17 00:00:00 2001
From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Date: Wed, 20 Nov 2024 22:20:20 +0000
Subject: [PATCH 1/3] [LAYOUTS] Unify the implementation of getShapePerCTATile
 (#5183)

We unify the implementation across layouts and simplify its API (it was
taking a `shape` parameter that went unused). While doing this, we found
that the previous implementation was incorrect at least for
`AMDWmmaEncodingAttr`, as that layout did use the shape parameter.
Interestingly enough, the doc for this function in the header file noted
that the function is indeed independent of the tensor shape, even though
the function did take a shape as an input!

https://github.com/triton-lang/triton/blob/0bd30a2f3192204c5a50d5ffde27ad8493f6c026/include/triton/Dialect/TritonGPU/IR/Dialect.h#L113-L114
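For intuition, the unified implementation below computes the tile shape as the
elementwise product `sizePerThread * threadsPerWarp * warpsPerCTA`. A minimal
standalone sketch of that arithmetic (plain C++; the layout values below are
hypothetical, and the real code reads these vectors off the layout attribute):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Elementwise product of the three per-dimension factors; mirrors the
// unified getShapePerCTATile introduced in this patch.
std::vector<unsigned>
shapePerCTATile(const std::vector<unsigned> &sizePerThread,
                const std::vector<unsigned> &threadsPerWarp,
                const std::vector<unsigned> &warpsPerCTA) {
  assert(sizePerThread.size() == threadsPerWarp.size() &&
         sizePerThread.size() == warpsPerCTA.size());
  std::vector<unsigned> shape;
  for (size_t d = 0; d < sizePerThread.size(); ++d)
    shape.push_back(sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]);
  return shape;
}

int main() {
  // Hypothetical 2-D blocked layout: 1x8 elements per thread, 8x4 threads
  // per warp, 4x1 warps per CTA -> a 32x32 tile, independent of tensor shape.
  auto tile = shapePerCTATile({1, 8}, {8, 4}, {4, 1});
  std::printf("%u x %u\n", tile[0], tile[1]); // prints: 32 x 32
}
```

Note that the computation never reads the tensor shape, which is what makes
dropping the `tensorShape` parameter safe everywhere except the (previously
buggy) `AMDWmmaEncodingAttr` operand path mentioned above.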
---
 include/triton/Dialect/TritonGPU/IR/Dialect.h |   4 +-
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  16 --
 lib/Analysis/Allocation.cpp                   |   6 +-
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp |   4 +-
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp        |   2 +-
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 138 +++---------------
 .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp   |   3 +-
 .../ExtractSliceOpToLLVM.cpp                  |   2 +-
 .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp  |   2 +-
 .../ConvertLayoutOpToLLVM.cpp                 |   1 -
 .../LoadStoreOpToLLVM.cpp                     |   2 +-
 .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp  |   2 -
 12 files changed, 29 insertions(+), 153 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index a9b49448c1..f2715043d7 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -116,9 +116,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
  * (3) In the implementation of emitIndices, ShapePerCTATile will
  *     be replicated or wrapped to fit ShapePerCTA.
  */
-SmallVector<unsigned>
-getShapePerCTATile(Attribute layout,
-                   ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
+SmallVector<unsigned> getShapePerCTATile(Attribute layout);
 
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape);
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index 07514126d2..e6be2f8332 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -502,11 +502,6 @@ We call each individual tile "rep".
                     "SmallVector<unsigned>",
                     "getCTASplitNum">,
 
-    InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA",
-                    "SmallVector<unsigned>",
-                    "getShapePerCTATile",
-                    (ins "ArrayRef<int64_t>":$tensorShape)>,
-
     InterfaceMethod<"Gets the number of contiguous elements per thread.",
                     "SmallVector<unsigned>",
                     "getContigPerThread">,
@@ -565,7 +560,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
     SmallVector<unsigned> getThreadOrder() const;
 
     SmallVector<unsigned> getSizePerThread() const;
-    SmallVector<unsigned> getShapePerCTATile(ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>()) const;
 
     std::optional<LinearLayout> toLinearLayout(ArrayRef<int64_t> shape) const;
   }];
@@ -765,13 +759,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
                     "bool",
                     "supportReduction">,
 
-    InterfaceMethod<"Return shape per CTA.",
-                    "SmallVector<unsigned>",
-                    "getShapePerCTATileForOperand",
-                    (ins "ArrayRef<int64_t>":$tensorShape,
-                         "int":$kWidth,
-                         "int":$opIdx)>,
-
     InterfaceMethod<"Return size per thread for dot operands.",
                     "SmallVector<unsigned>",
                     "getSizePerThreadForOperand",
@@ -900,7 +887,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
       return true;
     }
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
     unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
     SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
@@ -1008,7 +994,6 @@ Row | warp 0 warp 2
       return true;
     }
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
     unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
     SmallVector<int64_t> getElemsPerInstrForOperands() const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
@@ -1140,7 +1125,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
       return false;
     };
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
-    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
 
     SmallVector<unsigned> getContigPerThread() {
       assert(isAmpere() || isHopper());
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 02269c9aac..53897578aa 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -41,10 +41,8 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
   auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
-  auto srcShapePerCTATile =
-      gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
-  auto dstShapePerCTATile =
-      gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
+  auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout);
+  auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout);
 
   assert(srcTy.getRank() == dstTy.getRank() &&
          "src and dst must have the same rank");
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index a802d62ace..e48cfca441 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -174,8 +174,8 @@ struct ConvertLayoutOpConversion
     SmallVector<unsigned> outNumCTAsEachRep(rank);
     SmallVector<unsigned> inNumCTAs(rank);
     SmallVector<unsigned> outNumCTAs(rank);
-    auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
-    auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
+    auto srcShapePerCTATile = getShapePerCTATile(srcLayout);
+    auto dstShapePerCTATile = getShapePerCTATile(dstLayout);
     auto shapePerCTA = getShapePerCTA(srcLayout, shape);
 
     for (unsigned d = 0; d < rank; ++d) {
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 26dc8a5379..088dbd9976 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -421,7 +421,7 @@ struct ReduceOpConversion
     auto resultIndices = emitIndices(loc, rewriter, targetInfo, resultLayout,
                                      resultTy, true);
     auto resultShape = resultTy.getShape();
-    auto resultCTATile = getShapePerCTATile(resultLayout, resultShape);
+    auto resultCTATile = getShapePerCTATile(resultLayout);
     assert(resultIndices.size() == resultElems);
 
     SmallVector<Value> resultVals(resultElems);
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index dce2a6034f..721c8dd10c 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -201,12 +201,25 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
   }
   return ret;
 }
-
-SmallVector<unsigned> getShapePerCTATile(Attribute layout,
-                                         ArrayRef<int64_t> tensorShape) {
+SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
-    return distributedLayout.getShapePerCTATile(tensorShape);
+    auto sizePerThread = distributedLayout.getSizePerThread();
+    auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
+    // ThreadsPerWarp does not align with this function for slice layout
+    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
+      threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
+      threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
+    }
+    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
+    assert(sizePerThread.size() == threadsPerWarp.size() &&
+           sizePerThread.size() == warpsPerCTA.size());
+    SmallVector<unsigned> shape;
+    for (auto [size, thread, warp] :
+         llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
+      shape.push_back(size * thread * warp);
+    }
+    return shape;
   } else {
     llvm::report_fatal_error("getShapePerCTATile not implemented");
     return SmallVector<unsigned>();
@@ -678,14 +691,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
 SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
   return SmallVector<unsigned>(getSizePerThread__());
 }
-SmallVector<unsigned>
-BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape;
-  for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
-    shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
-                    getWarpsPerCTA()[d]);
-  return shape;
-}
 
 template <class T>
 SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -787,12 +792,6 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   sizePerThread.erase(sizePerThread.begin() + getDim());
   return sizePerThread;
 }
-SmallVector<unsigned>
-SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
-  shape.erase(shape.begin() + getDim());
-  return shape;
-}
 
 //
@@ -979,9 +978,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
-    auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
+    auto shapePerCTATile = getShapePerCTATile(blockedLayout);
     auto order = blockedLayout.getOrder();
-    auto sizePerThread = ::getSizePerThread(blockedLayout);
+    auto sizePerThread = blockedLayout.getSizePerThread();
     int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
     int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
@@ -1043,19 +1042,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
                                /*kMajor*/ true);
 }
-SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
-    ArrayRef<int64_t> tensorShape) const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
-    return parentMmaLayout.getShapePerCTATileForOperand(
-        tensorShape, getKWidth(), getOpIdx());
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
-        "supported yet");
-  }
-}
 
 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1562,16 +1548,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 
 // TODO: there is a lot of common code with MmaEncoding here
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-  shapePerCTATile[rank - 1] *= getMDim();
-  shapePerCTATile[rank - 2] *= getNDim();
-  return shapePerCTATile;
-}
-
 SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -1715,43 +1691,10 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  assert(getMDim() == 32 || getMDim() == 16);
-  auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto rank = parentShapePerCTATile.size();
-  if (opIdx == 0) {
-    if (rank == 2)
-      return {parentShapePerCTATile[rank - 2], 32};
-    else
-      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
-  } else if (opIdx == 1) {
-    if (rank == 2)
-      return {32, parentShapePerCTATile[rank - 1]};
-    else
-      return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-  llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
-}
-
 //===----------------------------------------------------------------------===//
 // Wmma encoding
 //===----------------------------------------------------------------------===//
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-
-  auto mnkDim = getMNKDimPerInstr();
-  shapePerCTATile[rank - 2] *= mnkDim[0];
-  shapePerCTATile[rank - 1] *= mnkDim[1];
-  return shapePerCTATile;
-}
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
   auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
 }
@@ -1816,21 +1759,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  auto parentShapePerCTA = getShapePerCTATile(shape);
-  auto rank = shape.size();
-  assert(rank == 2);
-  if (opIdx == 0) {
-    return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
-  } else if (opIdx == 1) {
-    return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-}
-
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
@@ -1949,24 +1877,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
   llvm_unreachable("Unexpected mma version");
 }
-SmallVector<unsigned>
-NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  if (isAmpere()) {
-    auto warpsPerCTA = getWarpsPerCTA();
-    auto rank = warpsPerCTA.size();
-    SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
-                                          warpsPerCTA.end());
-    shapePerCTATile[rank - 1] *= 8;
-    shapePerCTATile[rank - 2] *= 16;
-    return shapePerCTATile;
-  }
-  if (isHopper()) {
-    auto instrShape = getInstrShape();
-    return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
-  }
-  llvm::report_fatal_error("Unexpected MMA layout version found");
-}
-
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
@@ -2007,16 +1917,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
   }
 }
-SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
-    ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
-  assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
-  auto shapePerCTATile = getShapePerCTATile(shape);
-  auto rank = shapePerCTATile.size();
-  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
-  // 4 threads * 2 subtiles
-  shapePerCTATile[kDim] = kWidth * 2 * 4;
-  return shapePerCTATile;
-}
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   auto rank = getWarpsPerCTA().size();
diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
index 7c2473dbe5..0e2a9304eb 100644
--- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
+++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
@@ -78,8 +78,7 @@ LogicalResult ExtractSliceOp::verify() {
   }
 
   auto srcShape = srcTy.getShape();
-  auto shapePerCTATile =
-      mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
+  auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout);
   shapePerCTATile[0] =
       std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
   shapePerCTATile[1] =
diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp
index c0100812f2..ad56bd2d41 100644
--- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp
@@ -70,7 +70,7 @@ struct ExtractSliceOpConversion
     auto order = triton::gpu::getOrder(srcLayout);
 
     // Calculate valid total number of workers in each dimension
-    auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
+    auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout);
     shapePerCTATile[0] =
         std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
     shapePerCTATile[1] =
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
index 2e28dec802..825697e0e9 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -48,7 +48,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   } else {
     warpOrder = triton::gpu::getWarpOrder(layout);
   }
-  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
+  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
   Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
   Value laneId = urem(tid, warpSize);
   Value warpId = udiv(tid, warpSize);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 185b653f72..1324511aeb 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -17,7 +17,6 @@ using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getShapePerCTATile;
 using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
index a439b89270..cc52507121 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -46,7 +46,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   } else {
     warpOrder = triton::gpu::getWarpOrder(layout);
   }
-  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
+  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
   Value warpSize = i32_val(32);
   Value laneId = urem(tid, warpSize);
   Value warpId = udiv(tid, warpSize);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
index d1cef15a35..76b5653654 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -12,8 +12,6 @@ using namespace mlir;
 
 using mlir::LLVM::getWrappedMultiDimOffset;
 using ::mlir::LLVM::linearize;
-using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getShapePerCTATile;
 
 namespace {
 // declare vprintf(i8*, i8*) as external function
 LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) {
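An aside on the one non-obvious case in the unified implementation above: for
`SliceEncodingAttr`, `getThreadsPerWarp` does not line up with this
computation, so the patch re-queries the parent layout and erases the sliced
dimension. A standalone sketch of just that index manipulation (plain C++;
the parent values are hypothetical):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical parent layout with an 8x4 thread arrangement per warp.
  std::vector<unsigned> parentThreadsPerWarp = {8, 4};
  unsigned sliceDim = 0; // the dimension removed by the slice

  // Mirror of the slice special case: drop the sliced dimension so the
  // per-dimension product is taken over the slice's (smaller) rank.
  auto threadsPerWarp = parentThreadsPerWarp;
  threadsPerWarp.erase(threadsPerWarp.begin() + sliceDim);

  std::printf("slice threadsPerWarp: %u\n", threadsPerWarp[0]); // prints: 4
}
```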
From 251ec88dea0f0c7ab45592c6287fda4b8585440f Mon Sep 17 00:00:00 2001
From: chengjunlu
Date: Thu, 21 Nov 2024 06:53:58 +0800
Subject: [PATCH 2/3] [BACKEND] Use the LL API to replace the use of the
 legacy layout attribute API. (#5196)

The util function `getDistributedLayoutStr` used the `DistributedLayout`
attribute interface, which is not flexible enough for third-party
extensions. Use `LinearLayout::getInDimSize` instead: the legacy layout has
already been converted to a `LinearLayout` at this point, so the sizes can
be queried directly from it.

There is no new test case since this is only a change in API usage.
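Conceptually, the change replaces four attribute-interface queries with reads
of the linear layout's input-dimension sizes. A rough standalone sketch of the
idea (the `LinearLayoutInfo` stand-in below is hypothetical; the real code
calls `getInDimSize` on the `LinearLayout` already held by
`getDistributedLayoutStr`):

```cpp
#include <cstdio>
#include <map>
#include <string>

// Stand-in for the in-dimension table of a LinearLayout: each input
// dimension ("register", "lane", "warp", "block") has a fixed size.
struct LinearLayoutInfo {
  std::map<std::string, unsigned> inDimSize;
  unsigned getInDimSize(const std::string &dim) const {
    return inDimSize.at(dim);
  }
};

int main() {
  // Hypothetical layout: 8 registers per thread, 32 lanes, 4 warps, 1 block.
  LinearLayoutInfo ll{
      {{"register", 8}, {"lane", 32}, {"warp", 4}, {"block", 1}}};

  // Mirrors the patched getDistributedLayoutStr: all loop bounds now come
  // from the linear layout itself rather than the attribute interface.
  unsigned threadsPerWarp = ll.getInDimSize("lane");
  unsigned numWarpsPerCTA = ll.getInDimSize("warp");
  unsigned numBlocks = ll.getInDimSize("block");
  unsigned numElementsPerThreads = ll.getInDimSize("register");
  std::printf("%u threads/warp, %u warps, %u blocks, %u elems/thread\n",
              threadsPerWarp, numWarpsPerCTA, numBlocks,
              numElementsPerThreads);
}
```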
---
 lib/Dialect/TritonGPU/IR/Dialect.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 721c8dd10c..c785271808 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -3058,10 +3058,6 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
   if (!layout)
     return "";
 
-  unsigned threadsPerWarp = getWarpSize(layout);
-  unsigned numWarpsPerCTA = getNumWarpsPerCTA(layout);
-  unsigned numBlocks = getNumCTAs(layout);
-  int numElementsPerThreads = getTotalElemsPerThread(tensorType);
   StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register");
   StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane");
   StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp");
@@ -3074,6 +3070,10 @@ std::string getDistributedLayoutStr(RankedTensorType tensorType,
   int64_t tensorSize = product(tensorType.getShape());
   std::vector<std::string> elementMapping(tensorSize);
   std::vector<std::string> threadMapping;
+  unsigned threadsPerWarp = ll->getInDimSize(kLane);
+  unsigned numWarpsPerCTA = ll->getInDimSize(kWarp);
+  unsigned numBlocks = ll->getInDimSize(kBlock);
+  int numElementsPerThreads = ll->getInDimSize(kRegister);
   for (int blockId = 0; blockId < numBlocks; ++blockId) {
     for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) {
       for (int tid = 0; tid < threadsPerWarp; ++tid) {

From 016ee0055ccfc5a1340134747a2ad973699b9816 Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Thu, 21 Nov 2024 03:22:42 +0000
Subject: [PATCH 3/3] Revert "[LAYOUTS] Unify the implementation of
 getShapePerCTATile (#5183)"

This reverts commit d36e57998e7aaf7495c5d916b51948b2a42ae5f3.
---
 include/triton/Dialect/TritonGPU/IR/Dialect.h |   4 +-
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  16 ++
 lib/Analysis/Allocation.cpp                   |   6 +-
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp |   4 +-
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp        |   2 +-
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 138 +++++++++++++++---
 .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp   |   3 +-
 .../ExtractSliceOpToLLVM.cpp                  |   2 +-
 .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp  |   2 +-
 .../ConvertLayoutOpToLLVM.cpp                 |   1 +
 .../LoadStoreOpToLLVM.cpp                     |   2 +-
 .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp  |   2 +
 12 files changed, 153 insertions(+), 29 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index f2715043d7..a9b49448c1 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -116,7 +116,9 @@ SmallVector<unsigned> getCTAOrder(Attribute layout);
  * (3) In the implementation of emitIndices, ShapePerCTATile will
  *     be replicated or wrapped to fit ShapePerCTA.
  */
-SmallVector<unsigned> getShapePerCTATile(Attribute layout);
+SmallVector<unsigned>
+getShapePerCTATile(Attribute layout,
+                   ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>());
 
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape);
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index e02f0f7b4f..93723e2282 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -502,6 +502,11 @@ We call each individual tile "rep".
                     "SmallVector<unsigned>",
                     "getCTASplitNum">,
 
+    InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA",
+                    "SmallVector<unsigned>",
+                    "getShapePerCTATile",
+                    (ins "ArrayRef<int64_t>":$tensorShape)>,
+
     InterfaceMethod<"Gets the number of contiguous elements per thread.",
                     "SmallVector<unsigned>",
                     "getContigPerThread">,
@@ -560,6 +565,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
     SmallVector<unsigned> getThreadOrder() const;
 
     SmallVector<unsigned> getSizePerThread() const;
+    SmallVector<unsigned> getShapePerCTATile(ArrayRef<int64_t> tensorShape = ArrayRef<int64_t>()) const;
 
     std::optional<LinearLayout> toLinearLayout(ArrayRef<int64_t> shape) const;
   }];
@@ -759,6 +765,13 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
                     "bool",
                     "supportReduction">,
 
+    InterfaceMethod<"Return shape per CTA.",
+                    "SmallVector<unsigned>",
+                    "getShapePerCTATileForOperand",
+                    (ins "ArrayRef<int64_t>":$tensorShape,
+                         "int":$kWidth,
+                         "int":$opIdx)>,
+
     InterfaceMethod<"Return size per thread for dot operands.",
                     "SmallVector<unsigned>",
                     "getSizePerThreadForOperand",
@@ -892,6 +905,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
       return true;
     }
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
     unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
     SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
@@ -1003,6 +1017,7 @@ Row | warp 0 warp 2
       return true;
     }
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
     unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
     SmallVector<int64_t> getElemsPerInstrForOperands() const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
@@ -1138,6 +1153,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
       return false;
     };
     SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
+    SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
 
     SmallVector<unsigned> getContigPerThread() {
       assert(isAmpere() || isHopper());
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 3403674e02..a494af168b 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -41,8 +41,10 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
   auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
-  auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout);
-  auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout);
+  auto srcShapePerCTATile =
+      gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
+  auto dstShapePerCTATile =
+      gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
 
   assert(srcTy.getRank() == dstTy.getRank() &&
          "src and dst must have the same rank");
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index a8572243dc..8e8b089549 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -174,8 +174,8 @@ struct ConvertLayoutOpConversion
     SmallVector<unsigned> outNumCTAsEachRep(rank);
     SmallVector<unsigned> inNumCTAs(rank);
     SmallVector<unsigned> outNumCTAs(rank);
-    auto srcShapePerCTATile = getShapePerCTATile(srcLayout);
-    auto dstShapePerCTATile = getShapePerCTATile(dstLayout);
+    auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
+    auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
     auto shapePerCTA = getShapePerCTA(srcLayout, shape);
 
     for (unsigned d = 0; d < rank; ++d) {
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 088dbd9976..26dc8a5379 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -421,7 +421,7 @@ struct ReduceOpConversion
     auto resultIndices = emitIndices(loc, rewriter, targetInfo, resultLayout,
                                      resultTy, true);
     auto resultShape = resultTy.getShape();
-    auto resultCTATile = getShapePerCTATile(resultLayout);
+    auto resultCTATile = getShapePerCTATile(resultLayout, resultShape);
     assert(resultIndices.size() == resultElems);
 
     SmallVector<Value> resultVals(resultElems);
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 5e0ee92549..8dc7829c38 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -204,25 +204,12 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
   }
   return ret;
 }
-SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
+
+SmallVector<unsigned> getShapePerCTATile(Attribute layout,
+                                         ArrayRef<int64_t> tensorShape) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
-    auto sizePerThread = distributedLayout.getSizePerThread();
-    auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
-    // ThreadsPerWarp does not align with this function for slice layout
-    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
-      threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
-      threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
-    }
-    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
-    assert(sizePerThread.size() == threadsPerWarp.size() &&
-           sizePerThread.size() == warpsPerCTA.size());
-    SmallVector<unsigned> shape;
-    for (auto [size, thread, warp] :
-         llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
-      shape.push_back(size * thread * warp);
-    }
-    return shape;
+    return distributedLayout.getShapePerCTATile(tensorShape);
   } else {
     llvm::report_fatal_error("getShapePerCTATile not implemented");
     return SmallVector<unsigned>();
@@ -704,6 +691,14 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
 SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
   return SmallVector<unsigned>(getSizePerThread__());
 }
+SmallVector<unsigned>
+BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  SmallVector<unsigned> shape;
+  for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
+    shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
+                    getWarpsPerCTA()[d]);
+  return shape;
+}
 
 template <class T>
 SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -805,6 +800,12 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   sizePerThread.erase(sizePerThread.begin() + getDim());
   return sizePerThread;
 }
+SmallVector<unsigned>
+SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
+  shape.erase(shape.begin() + getDim());
+  return shape;
+}
 
 //
@@ -998,9 +999,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
-    auto shapePerCTATile = getShapePerCTATile(blockedLayout);
+    auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
     auto order = blockedLayout.getOrder();
-    auto sizePerThread = blockedLayout.getSizePerThread();
+    auto sizePerThread = ::getSizePerThread(blockedLayout);
     int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
     int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
@@ -1071,6 +1072,19 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
                                  /*kMajor*/ true);
   }
 }
+SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
+    ArrayRef<int64_t> tensorShape) const {
+  auto parentLayout = getParent();
+  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
+  if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
+    return parentMmaLayout.getShapePerCTATileForOperand(
+        tensorShape, getKWidth(), getOpIdx());
+  } else {
+    llvm::report_fatal_error(
+        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
+        "supported yet");
+  }
+}
 
 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1592,6 +1606,16 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 
 // TODO: there is a lot of common code with MmaEncoding here
+SmallVector<unsigned>
+AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  auto warpsPerCTA = getWarpsPerCTA();
+  auto rank = warpsPerCTA.size();
+  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
+  shapePerCTATile[rank - 1] *= getMDim();
+  shapePerCTATile[rank - 2] *= getNDim();
+  return shapePerCTATile;
+}
+
 SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -1735,10 +1759,43 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
+SmallVector<unsigned>
+AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
+                                                  int kWidth, int opIdx) const {
+  assert(getMDim() == 32 || getMDim() == 16);
+  auto parentShapePerCTATile = getShapePerCTATile(shape);
+  auto rank = parentShapePerCTATile.size();
+  if (opIdx == 0) {
+    if (rank == 2)
+      return {parentShapePerCTATile[rank - 2], 32};
+    else
+      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
+  } else if (opIdx == 1) {
+    if (rank == 2)
+      return {32, parentShapePerCTATile[rank - 1]};
+    else
+      return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
+  } else {
+    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
+  }
+  llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
+}
+
 //===----------------------------------------------------------------------===//
 // Wmma encoding
 //===----------------------------------------------------------------------===//
+SmallVector<unsigned>
+AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  auto warpsPerCTA = getWarpsPerCTA();
+  auto rank = warpsPerCTA.size();
+  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
+
+  auto mnkDim = getMNKDimPerInstr();
+  shapePerCTATile[rank - 2] *= mnkDim[0];
+  shapePerCTATile[rank - 1] *= mnkDim[1];
+  return shapePerCTATile;
+}
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
   auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
 }
@@ -1803,6 +1860,21 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
+SmallVector<unsigned>
+AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
+                                                  int kWidth, int opIdx) const {
+  auto parentShapePerCTA = getShapePerCTATile(shape);
+  auto rank = shape.size();
+  assert(rank == 2);
+  if (opIdx == 0) {
+    return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
+  } else if (opIdx == 1) {
+    return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
+  } else {
+    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
+  }
+}
+
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
@@ -1921,6 +1993,24 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
   llvm_unreachable("Unexpected mma version");
 }
+SmallVector<unsigned>
+NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  if (isAmpere()) {
+    auto warpsPerCTA = getWarpsPerCTA();
+    auto rank = warpsPerCTA.size();
+    SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
+                                          warpsPerCTA.end());
+    shapePerCTATile[rank - 1] *= 8;
+    shapePerCTATile[rank - 2] *= 16;
+    return shapePerCTATile;
+  }
+  if (isHopper()) {
+    auto instrShape = getInstrShape();
+    return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
+  }
+  llvm::report_fatal_error("Unexpected MMA layout version found");
+}
+
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
@@ -1961,6 +2051,16 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
   }
 }
+SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
+    ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
+  assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
+  auto shapePerCTATile = getShapePerCTATile(shape);
+  auto rank = shapePerCTATile.size();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
+  // 4 threads * 2 subtiles
+  shapePerCTATile[kDim] = kWidth * 2 * 4;
+  return shapePerCTATile;
+}
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   auto rank = getWarpsPerCTA().size();
diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
index 0e2a9304eb..7c2473dbe5 100644
--- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
+++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
@@ -78,7 +78,8 @@ LogicalResult ExtractSliceOp::verify() {
   }
 
   auto srcShape = srcTy.getShape();
-  auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout);
+  auto shapePerCTATile =
+      mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape);
   shapePerCTATile[0] =
       std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
   shapePerCTATile[1] =
diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp
index ad56bd2d41..c0100812f2 100644
--- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp
@@ -70,7 +70,7 @@ struct ExtractSliceOpConversion
     auto order = triton::gpu::getOrder(srcLayout);
 
     // Calculate valid total number of workers in each dimension
-    auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout);
+    auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape);
     shapePerCTATile[0] =
         std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
     shapePerCTATile[1] =
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
index 825697e0e9..2e28dec802 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -48,7 +48,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   } else {
     warpOrder = triton::gpu::getWarpOrder(layout);
   }
-  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
+  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
   Value warpSize = i32_val(triton::gpu::getWarpSize(layout));
   Value laneId = urem(tid, warpSize);
   Value warpId = udiv(tid, warpSize);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 1324511aeb..185b653f72 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -17,6 +17,7 @@ using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::getOrder;
 using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getShapePerCTATile;
 using ::mlir::triton::gpu::getSizePerThread;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
index cc52507121..a439b89270 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -46,7 +46,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   } else {
     warpOrder = triton::gpu::getWarpOrder(layout);
   }
-  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
+  auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
   Value warpSize = i32_val(32);
   Value laneId = urem(tid, warpSize);
   Value warpId = udiv(tid, warpSize);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
index 76b5653654..d1cef15a35 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -12,6 +12,8 @@ using namespace mlir;
 
 using mlir::LLVM::getWrappedMultiDimOffset;
 using ::mlir::LLVM::linearize;
+using ::mlir::triton::gpu::getShapePerCTA;
+using ::mlir::triton::gpu::getShapePerCTATile;
 
 namespace {
 // declare vprintf(i8*, i8*) as external function
 LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) {
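One way to see why patch 1 flagged the WMMA path: the operand-tile helper
restored by this revert folds the tensor shape into its result, so the "tile"
grows with the tensor. A toy reproduction of that shape dependence (values
hypothetical; `oldOperandATile` mimics the `opIdx == 0` branch of the restored
`AMDWmmaEncodingAttr::getShapePerCTATileForOperand`):

```cpp
#include <cstdio>
#include <vector>

// Mimics the restored opIdx == 0 branch: the second tile dimension is taken
// straight from the tensor shape rather than from the layout.
std::vector<unsigned> oldOperandATile(const std::vector<unsigned> &parentTile,
                                      const std::vector<long> &tensorShape) {
  return {parentTile[0], static_cast<unsigned>(tensorShape[1])};
}

int main() {
  std::vector<unsigned> parentTile = {32, 32}; // hypothetical WMMA warp tile
  for (long n : {64L, 256L}) {
    auto tile = oldOperandATile(parentTile, {128, n});
    // Same layout, different tensors -> different "tiles".
    std::printf("tensor N=%ld -> tile %u x %u\n", n, tile[0], tile[1]);
  }
}
```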