From 1fc391518ca03c603ff165103b8a6fe74e566506 Mon Sep 17 00:00:00 2001 From: "Lu, Chengjun" Date: Thu, 21 Nov 2024 13:10:45 +0000 Subject: [PATCH 1/4] Revert "Revert "[LAYOUTS] Unify the implementation of getShapePerCTA (#5183)"" This reverts commit 016ee0055ccfc5a1340134747a2ad973699b9816. --- include/triton/Dialect/TritonGPU/IR/Dialect.h | 4 +- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 16 -- lib/Analysis/Allocation.cpp | 6 +- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 4 +- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 2 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 138 +++--------------- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 3 +- .../ExtractSliceOpToLLVM.cpp | 2 +- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 2 +- .../ConvertLayoutOpToLLVM.cpp | 1 - .../LoadStoreOpToLLVM.cpp | 2 +- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 2 - 12 files changed, 29 insertions(+), 153 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index a9b49448c1..f2715043d7 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -116,9 +116,7 @@ SmallVector getCTAOrder(Attribute layout); * (3) In the implementation of emitIndices, ShapePerCTATile will * be replicated or wrapped to fit ShapePerCTA. */ -SmallVector -getShapePerCTATile(Attribute layout, - ArrayRef tensorShape = ArrayRef()); +SmallVector getShapePerCTATile(Attribute layout); SmallVector getShapePerCTA(ArrayRef CTASplitNum, ArrayRef shape); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 93723e2282..e02f0f7b4f 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -502,11 +502,6 @@ We call each individual tile "rep". "SmallVector", "getCTASplitNum">, - InterfaceMethod<"Gets the shape of the encoding's tile, e.g. sizePerThread * threadsPerWarp * warpsPerCTA", - "SmallVector", - "getShapePerCTATile", - (ins "ArrayRef":$tensorShape)>, - InterfaceMethod<"Gets the number of contiguous elements per thread.", "SmallVector", "getContigPerThread">, @@ -565,7 +560,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, SmallVector getThreadOrder() const; SmallVector getSizePerThread() const; - SmallVector getShapePerCTATile(ArrayRef tensorShape = ArrayRef()) const; std::optional toLinearLayout(ArrayRef shape) const; }]; @@ -765,13 +759,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> { "bool", "supportReduction">, - InterfaceMethod<"Return shape per CTA.", - "SmallVector", - "getShapePerCTATileForOperand", - (ins "ArrayRef":$tensorShape, - "int":$kWidth, - "int":$opIdx)>, - InterfaceMethod<"Return size per thread for dot operands.", "SmallVector", "getSizePerThreadForOperand", @@ -905,7 +892,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129, return true; } SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getInstrShapeForOperand(int kWidth, int opIdx) const; SmallVector getRepForOperand(ArrayRef operandShape, int kWidth, int opIdx) const; @@ -1017,7 +1003,6 @@ Row | warp 0 warp 2 return true; } SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; SmallVector getElemsPerInstrForOperands() const; SmallVector getRepForOperand(ArrayRef operandShape, @@ -1153,7 +1138,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: return false; }; SmallVector getSizePerThreadForOperand(int kWidth, int opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; SmallVector getContigPerThread() { assert(isAmpere() || isHopper()); diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index a494af168b..3403674e02 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -41,10 +41,8 @@ static SmallVector getRepShapeForCvt(RankedTensorType srcTy, auto srcShapePerCTA = gpu::getShapePerCTA(srcTy); auto dstShapePerCTA = gpu::getShapePerCTA(dstTy); - auto srcShapePerCTATile = - gpu::getShapePerCTATile(srcLayout, srcTy.getShape()); - auto dstShapePerCTATile = - gpu::getShapePerCTATile(dstLayout, dstTy.getShape()); + auto srcShapePerCTATile = gpu::getShapePerCTATile(srcLayout); + auto dstShapePerCTATile = gpu::getShapePerCTATile(dstLayout); assert(srcTy.getRank() == dstTy.getRank() && "src and dst must have the same rank"); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 2d06980809..0a1b0e30fc 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -174,8 +174,8 @@ struct ConvertLayoutOpConversion SmallVector outNumCTAsEachRep(rank); SmallVector inNumCTAs(rank); SmallVector outNumCTAs(rank); - auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); - auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout); auto shapePerCTA = getShapePerCTA(srcLayout, shape); for (unsigned d = 0; d < rank; ++d) { diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 26dc8a5379..088dbd9976 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -421,7 +421,7 @@ struct ReduceOpConversion auto resultIndices = emitIndices(loc, rewriter, targetInfo, resultLayout, resultTy, true); auto resultShape = resultTy.getShape(); - auto resultCTATile = getShapePerCTATile(resultLayout, resultShape); + auto resultCTATile = getShapePerCTATile(resultLayout); assert(resultIndices.size() == resultElems); SmallVector resultVals(resultElems); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 7cde755873..7cc2525733 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -204,12 +204,25 @@ SmallVector getUniqueContigPerThread(Attribute layout, } return ret; } - -SmallVector getShapePerCTATile(Attribute layout, - ArrayRef tensorShape) { +SmallVector getShapePerCTATile(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) { - return distributedLayout.getShapePerCTATile(tensorShape); + auto sizePerThread = distributedLayout.getSizePerThread(); + auto threadsPerWarp = distributedLayout.getThreadsPerWarp(); + // ThreadsPerWarp does not align with this function for slice layout + if (auto sliceLayout = mlir::dyn_cast(layout)) { + threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent()); + threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + } + auto warpsPerCTA = distributedLayout.getWarpsPerCTA(); + assert(sizePerThread.size() == threadsPerWarp.size() && + sizePerThread.size() == warpsPerCTA.size()); + SmallVector shape; + for (auto [size, thread, warp] : + llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) { + shape.push_back(size * thread * warp); + } + return shape; } else { llvm::report_fatal_error("getShapePerCTATile not implemented"); return SmallVector(); @@ -691,14 +704,6 @@ SmallVector BlockedEncodingAttr::getThreadOrder() const { SmallVector BlockedEncodingAttr::getSizePerThread() const { return SmallVector(getSizePerThread__()); } -SmallVector -BlockedEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - SmallVector shape; - for (unsigned d = 0, n = getOrder().size(); d < n; ++d) - shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] * - getWarpsPerCTA()[d]); - return shape; -} template SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { @@ -800,12 +805,6 @@ SmallVector SliceEncodingAttr::getSizePerThread() const { sizePerThread.erase(sizePerThread.begin() + getDim()); return sizePerThread; } -SmallVector -SliceEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - SmallVector shape = ::getShapePerCTATile(getParent(), tensorShape); - shape.erase(shape.begin() + getDim()); - return shape; -} // @@ -999,9 +998,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, } if (auto blockedLayout = mlir::dyn_cast(getParent())) { auto shapePerCTA = getShapePerCTA(*this, shape); - auto shapePerCTATile = ::getShapePerCTATile(blockedLayout); + auto shapePerCTATile = getShapePerCTATile(blockedLayout); auto order = blockedLayout.getOrder(); - auto sizePerThread = ::getSizePerThread(blockedLayout); + auto sizePerThread = blockedLayout.getSizePerThread(); int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0]; int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0]; @@ -1072,19 +1071,6 @@ SmallVector DotOperandEncodingAttr::getThreadOrder() const { /*kMajor*/ true); } } -SmallVector DotOperandEncodingAttr::getShapePerCTATile( - ArrayRef tensorShape) const { - auto parentLayout = getParent(); - assert(parentLayout && "DotOperandEncodingAttr must have a parent"); - if (auto parentMmaLayout = mlir::dyn_cast(parentLayout)) { - return parentMmaLayout.getShapePerCTATileForOperand( - tensorShape, getKWidth(), getOpIdx()); - } else { - llvm::report_fatal_error( - "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not " - "supported yet"); - } -} LogicalResult DotOperandEncodingAttr::verify( ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, @@ -1606,16 +1592,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const { //===----------------------------------------------------------------------===// // TODO: there is a lot of common code with MmaEncoding here -SmallVector -AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - auto warpsPerCTA = getWarpsPerCTA(); - auto rank = warpsPerCTA.size(); - SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); - shapePerCTATile[rank - 1] *= getMDim(); - shapePerCTATile[rank - 2] *= getNDim(); - return shapePerCTATile; -} - SmallVector AMDMfmaEncodingAttr::getCTAsPerCGA() const { return SmallVector(getCTALayout().getCTAsPerCGA()); } @@ -1759,43 +1735,10 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { return sizePerThread; } -SmallVector -AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, - int kWidth, int opIdx) const { - assert(getMDim() == 32 || getMDim() == 16); - auto parentShapePerCTATile = getShapePerCTATile(shape); - auto rank = parentShapePerCTATile.size(); - if (opIdx == 0) { - if (rank == 2) - return {parentShapePerCTATile[rank - 2], 32}; - else - return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32}; - } else if (opIdx == 1) { - if (rank == 2) - return {32, parentShapePerCTATile[rank - 1]}; - else - return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - } - llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1"); -} - //===----------------------------------------------------------------------===// // Wmma encoding //===----------------------------------------------------------------------===// -SmallVector -AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - auto warpsPerCTA = getWarpsPerCTA(); - auto rank = warpsPerCTA.size(); - SmallVector shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end()); - - auto mnkDim = getMNKDimPerInstr(); - shapePerCTATile[rank - 2] *= mnkDim[0]; - shapePerCTATile[rank - 1] *= mnkDim[1]; - return shapePerCTATile; -} SmallVector AMDWmmaEncodingAttr::getRepOrder() const { auto rank = getWarpsPerCTA().size(); return getMatrixOrder(rank, /*rowMajor*/ true); @@ -1860,21 +1803,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { return sizePerThread; } -SmallVector -AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, - int kWidth, int opIdx) const { - auto parentShapePerCTA = getShapePerCTATile(shape); - auto rank = shape.size(); - assert(rank == 2); - if (opIdx == 0) { - return {parentShapePerCTA[0], static_cast(shape[1])}; - } else if (opIdx == 1) { - return {static_cast(shape[0]), parentShapePerCTA[1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - } -} - unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand( ArrayRef shape, Type eltTy, int kWidth, int opIdx) const { auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx); @@ -1993,24 +1921,6 @@ SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { llvm_unreachable("Unexpected mma version"); } -SmallVector -NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - if (isAmpere()) { - auto warpsPerCTA = getWarpsPerCTA(); - auto rank = warpsPerCTA.size(); - SmallVector shapePerCTATile(warpsPerCTA.begin(), - warpsPerCTA.end()); - shapePerCTATile[rank - 1] *= 8; - shapePerCTATile[rank - 2] *= 16; - return shapePerCTATile; - } - if (isHopper()) { - auto instrShape = getInstrShape(); - return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]}; - } - llvm::report_fatal_error("Unexpected MMA layout version found"); -} - SmallVector NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { auto rank = getWarpsPerCTA().size(); @@ -2051,16 +1961,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, } } -SmallVector NvidiaMmaEncodingAttr::getShapePerCTATileForOperand( - ArrayRef shape, int kWidth, int opIdx) const { - assert(isAmpere() && "mmaLayout Hopper is not implemented yet"); - auto shapePerCTATile = getShapePerCTATile(shape); - auto rank = shapePerCTATile.size(); - auto kDim = opIdx == 0 ? rank - 1 : rank - 2; - // 4 threads * 2 subtiles - shapePerCTATile[kDim] = kWidth * 2 * 4; - return shapePerCTATile; -} SmallVector NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { auto rank = getWarpsPerCTA().size(); diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 7c2473dbe5..0e2a9304eb 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -78,8 +78,7 @@ LogicalResult ExtractSliceOp::verify() { } auto srcShape = srcTy.getShape(); - auto shapePerCTATile = - mlir::triton::gpu::getShapePerCTATile(srcLayout, srcShape); + auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcLayout); shapePerCTATile[0] = std::min(static_cast(srcShape[0]), shapePerCTATile[0]); shapePerCTATile[1] = diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp index c0100812f2..ad56bd2d41 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -70,7 +70,7 @@ struct ExtractSliceOpConversion auto order = triton::gpu::getOrder(srcLayout); // Calculate valid total number of workers in each dimension - auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout, srcShape); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcLayout); shapePerCTATile[0] = std::min(static_cast(srcShape[0]), shapePerCTATile[0]); shapePerCTATile[1] = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 2e28dec802..825697e0e9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -48,7 +48,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, } else { warpOrder = triton::gpu::getWarpOrder(layout); } - auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout); Value warpSize = i32_val(triton::gpu::getWarpSize(layout)); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 185b653f72..1324511aeb 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -17,7 +17,6 @@ using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index a439b89270..cc52507121 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -46,7 +46,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, } else { warpOrder = triton::gpu::getWarpOrder(layout); } - auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout); Value warpSize = i32_val(32); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index d1cef15a35..76b5653654 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -12,8 +12,6 @@ using namespace mlir; using mlir::LLVM::getWrappedMultiDimOffset; using ::mlir::LLVM::linearize; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; namespace { // declare vprintf(i8*, i8*) as external function LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) { From 6e2da6aaf8cd28ff1be275d57e9c507829bb6986 Mon Sep 17 00:00:00 2001 From: "Lu, Chengjun" Date: Thu, 21 Nov 2024 14:16:54 +0000 Subject: [PATCH 2/4] Clean the `getShapePerCTATile` and `getShapePerCTATileForOperand` interface in DPAS layout. --- .../IR/TritonIntelGPUAttrDefs.td | 1 - .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 48 ++++--------------- .../ConvertLayoutOpToLLVM.cpp | 4 +- .../LoadStoreOpToLLVM.cpp | 2 +- .../TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp | 2 +- 5 files changed, 12 insertions(+), 45 deletions(-) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index bb456eaf38..6036c54e84 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -84,7 +84,6 @@ along the row (resp. col) dimension. SmallVector getDPASRepetitions(ArrayRef shape, int opIdx) const; SmallVector getSizePerThreadForOperand(int kWidth,unsigned opIdx) const; SmallVector getElemsPerThreadForOperands(ArrayRef shape, Type eltTy, unsigned opIdx) const; - SmallVector getShapePerCTATileForOperand(ArrayRef shape, int kWidth, int opIdx) const; SmallVector getRepOrderForOperand(int opIdx) const; unsigned getTotalElemsPerThreadForOperand(ArrayRef shape, Type eltTy, int kWidth, int opIdx) const; diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index baa0e3e347..0546d0c8de 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -147,18 +147,6 @@ SmallVector DpasEncodingAttr::getSizePerThread() const { return res; } -SmallVector -DpasEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - auto shapeC = getShapeC(); - SmallVector warpsPerCTA = getWarpsPerCTA(); - size_t rank = shapeC.size(); - SmallVector shapePerCTATile(rank); - llvm::transform( - llvm::zip_equal(shapeC, warpsPerCTA), shapePerCTATile.begin(), - [](auto entry) { return std::get<0>(entry) * std::get<1>(entry); }); - return shapePerCTATile; -} - SmallVector DpasEncodingAttr::getRepOrder() const { llvm::report_fatal_error("NYI. DpasEncodingAttr::getRepOrder"); } @@ -174,7 +162,14 @@ DpasEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { assert((rank == 2 || rank == 3) && "Unexpected rank of mma layout"); SmallVector elemsPerThread(rank, 1); - auto shapePerCTATile = getShapePerCTATile(shape); + + auto shapeC = getShapeC(); + SmallVector warpsPerCTA = getWarpsPerCTA(); + SmallVector shapePerCTATile(rank); + llvm::transform( + llvm::zip_equal(shapeC, warpsPerCTA), shapePerCTATile.begin(), + [](auto entry) { return std::get<0>(entry) * std::get<1>(entry); }); + unsigned tilesRow = ceil(shape[rank - 2], shapePerCTATile[rank - 2]); unsigned tilesCol = @@ -306,33 +301,6 @@ SmallVector DpasEncodingAttr::getThreadsPerWarp() const { return res; } -SmallVector -DpasEncodingAttr::getShapePerCTATileForOperand(ArrayRef shape, - int kWidth, int opIdx) const { - auto parentShapePerCTATile = getShapePerCTATile(shape); - size_t rank = parentShapePerCTATile.size(); - assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout"); - if (opIdx == 0) { - auto shapeA = getShapeA(); - return (rank == 2) - ? SmallVector{parentShapePerCTATile[0], shapeA[1]} - : SmallVector{parentShapePerCTATile[0], - parentShapePerCTATile[rank - 2], - shapeA[rank - 1]}; - } - - if (opIdx == 1) { - auto shapeB = getShapeB(); - return (rank == 2) - ? SmallVector{shapeB[0], parentShapePerCTATile[1]} - : SmallVector{parentShapePerCTATile[0], - shapeB[rank - 2], - parentShapePerCTATile[rank - 1]}; - } - - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); -} - SmallVector DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const { ArrayRef repCluster = getRepCluster(); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 72d5f7e291..83f9042912 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -248,8 +248,8 @@ struct ConvertLayoutOpConversion SmallVector outNumCTAsEachRep(rank); SmallVector inNumCTAs(rank); SmallVector outNumCTAs(rank); - auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape()); - auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape); + auto srcShapePerCTATile = getShapePerCTATile(srcLayout); + auto dstShapePerCTATile = getShapePerCTATile(dstLayout); auto shapePerCTA = getShapePerCTA(srcLayout, shape); for (unsigned d = 0; d < rank; ++d) { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index b7ec2862bd..6604e2e8db 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -36,7 +36,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); auto order = triton::gpu::getOrder(layout); - auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout); Value warpSize = LLVM::intel::getModuleWarpSize(rewriter, loc); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp index 980ac33c3d..2ead653b07 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp @@ -408,7 +408,7 @@ struct ReduceOpConversion auto resultIndices = ::intel::emitIndices(loc, rewriter, targetInfo, resultLayout, resultTy, true); auto resultShape = resultTy.getShape(); - auto resultCTATile = getShapePerCTATile(resultLayout, resultShape); + auto resultCTATile = getShapePerCTATile(resultLayout); assert(resultIndices.size() == resultElems); SmallVector resultVals(resultElems); From d09322024b6f347df01240de17e55f3a1291e402 Mon Sep 17 00:00:00 2001 From: "Lu, Chengjun" Date: Fri, 22 Nov 2024 08:57:37 +0000 Subject: [PATCH 3/4] Fix bug of `getShapePerCTATile` for slice layout --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 7cc2525733..483e4e3274 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -209,12 +209,15 @@ SmallVector getShapePerCTATile(Attribute layout) { mlir::dyn_cast(layout)) { auto sizePerThread = distributedLayout.getSizePerThread(); auto threadsPerWarp = distributedLayout.getThreadsPerWarp(); - // ThreadsPerWarp does not align with this function for slice layout + auto warpsPerCTA = distributedLayout.getWarpsPerCTA(); + // ThreadsPerWarp and warpsPerCTA does not align with this function for + // slice layout if (auto sliceLayout = mlir::dyn_cast(layout)) { threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent()); threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + warpsPerCTA = getWarpsPerCTA(sliceLayout.getParent()); + warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim()); } - auto warpsPerCTA = distributedLayout.getWarpsPerCTA(); assert(sizePerThread.size() == threadsPerWarp.size() && sizePerThread.size() == warpsPerCTA.size()); SmallVector shape; From 1aebd43cafb87e3c5a737f7e5417fbb3cb7c47ee Mon Sep 17 00:00:00 2001 From: "Lu, Chengjun" Date: Fri, 22 Nov 2024 11:13:30 +0000 Subject: [PATCH 4/4] Enable LL for `intel::emitOffsetForLayout` --- third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h | 6 ------ 1 file changed, 6 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h index 2160b8f17d..66ab31d336 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h @@ -571,12 +571,6 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, inline SmallVector> emitOffsetForLayout(Attribute layout, RankedTensorType type) { - if (auto dpasLayout = dyn_cast(layout)) - return emitOffsetForDpasLayout(dpasLayout, type); - if (auto dotLayout = dyn_cast(layout)) - return emitOffsetForDotOpLayout(dotLayout, type); - if (auto sliceLayout = dyn_cast(layout)) - return ::intel::emitOffsetForSliceLayout(sliceLayout, type); return mlir::emitOffsetForLayout(layout, type); }