@@ -201,12 +201,25 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
   }
   return ret;
 }
-
-SmallVector<unsigned> getShapePerCTATile(Attribute layout,
-                                         ArrayRef<int64_t> tensorShape) {
+SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
-    return distributedLayout.getShapePerCTATile(tensorShape);
+    auto sizePerThread = distributedLayout.getSizePerThread();
+    auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
+    // ThreadsPerWarp does not align with this function for slice layout
+    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
+      threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
+      threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
+    }
+    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
+    assert(sizePerThread.size() == threadsPerWarp.size() &&
+           sizePerThread.size() == warpsPerCTA.size());
+    SmallVector<unsigned> shape;
+    for (auto [size, thread, warp] :
+         llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
+      shape.push_back(size * thread * warp);
+    }
+    return shape;
   } else {
     llvm::report_fatal_error("getShapePerCTATile not implemented");
     return SmallVector<unsigned>();
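
The new generic path above computes the per-CTA tile extent in each dimension as sizePerThread * threadsPerWarp * warpsPerCTA. A minimal standalone sketch of that product, using hypothetical layout values that do not come from this commit:

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical 2-D blocked layout: each thread holds a 1x4 slice,
  // each warp arranges its threads 8x4, and each CTA holds 4x1 warps.
  std::vector<unsigned> sizePerThread = {1, 4};
  std::vector<unsigned> threadsPerWarp = {8, 4};
  std::vector<unsigned> warpsPerCTA = {4, 1};
  assert(sizePerThread.size() == threadsPerWarp.size() &&
         sizePerThread.size() == warpsPerCTA.size());
  std::vector<unsigned> shapePerCTATile;
  for (size_t d = 0; d < sizePerThread.size(); ++d)
    shapePerCTATile.push_back(sizePerThread[d] * threadsPerWarp[d] *
                              warpsPerCTA[d]);
  // Dim 0: 1 * 8 * 4 = 32; dim 1: 4 * 4 * 1 = 16.
  std::printf("%u x %u\n", shapePerCTATile[0], shapePerCTATile[1]);
  return 0;
}

The BlockedEncodingAttr override deleted in the next hunk computed exactly this per-dimension product by hand, which is why it becomes redundant.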
@@ -678,14 +691,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
 SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
   return SmallVector<unsigned>(getSizePerThread__());
 }
-SmallVector<unsigned>
-BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape;
-  for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
-    shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
-                    getWarpsPerCTA()[d]);
-  return shape;
-}
 
 template <class T>
 SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -787,12 +792,6 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   sizePerThread.erase(sizePerThread.begin() + getDim());
   return sizePerThread;
 }
-SmallVector<unsigned>
-SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
-  shape.erase(shape.begin() + getDim());
-  return shape;
-}
 
 //
 
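For slice layouts, the special case in the first hunk erases the sliced dimension from the parent's threadsPerWarp before taking the product, matching what the override deleted here did at the level of the final tile: a parent tile of, say, [32, 16] sliced along getDim() == 0 yields the 1-D tile [16] either way.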
@@ -979,9 +978,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
-    auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
+    auto shapePerCTATile = getShapePerCTATile(blockedLayout);
     auto order = blockedLayout.getOrder();
-    auto sizePerThread = ::getSizePerThread(blockedLayout);
+    auto sizePerThread = blockedLayout.getSizePerThread();
 
     int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
     int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
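The K/otherDim selection above reads as follows: operand A (opIdx == 0) has shapePerCTA = [M, K], so K = shapePerCTA[1] and otherDim picks up M; operand B (opIdx == 1) has shapePerCTA = [K, N], so the indices flip. For example, an A operand with shapePerCTA = [128, 64] gives K = 64 and otherDim = 128.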
@@ -1043,19 +1042,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
                                /*kMajor*/ true);
 }
-SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
-    ArrayRef<int64_t> tensorShape) const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
-    return parentMmaLayout.getShapePerCTATileForOperand(
-        tensorShape, getKWidth(), getOpIdx());
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
-        "supported yet");
-  }
-}
 
 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1562,16 +1548,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 // TODO: there is a lot of common code with MmaEncoding here
 
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-  shapePerCTATile[rank - 1] *= getMDim();
-  shapePerCTATile[rank - 2] *= getNDim();
-  return shapePerCTATile;
-}
-
 SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
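Worked example for the MFMA override deleted above: it scaled warpsPerCTA by the instruction tile, so warpsPerCTA = [2, 2] with a 32x32 MFMA (getMDim() == getNDim() == 32) gave a per-CTA tile of [64, 64]. For the square 32x32 and 16x16 MFMA tiles, the result is the same whether mDim scales the last or the second-to-last dimension.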
@@ -1715,43 +1691,10 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  assert(getMDim() == 32 || getMDim() == 16);
-  auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto rank = parentShapePerCTATile.size();
-  if (opIdx == 0) {
-    if (rank == 2)
-      return {parentShapePerCTATile[rank - 2], 32};
-    else
-      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
-  } else if (opIdx == 1) {
-    if (rank == 2)
-      return {32, parentShapePerCTATile[rank - 1]};
-    else
-      return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-  llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
-}
-
 //===----------------------------------------------------------------------===//
 // Wmma encoding
 //===----------------------------------------------------------------------===//
 
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-
-  auto mnkDim = getMNKDimPerInstr();
-  shapePerCTATile[rank - 2] *= mnkDim[0];
-  shapePerCTATile[rank - 1] *= mnkDim[1];
-  return shapePerCTATile;
-}
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
   auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
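Worked examples for the two blocks deleted in this hunk: the MFMA operand override pinned the K extent of the operand tile to 32 elements, so a parent tile of [64, 64] gave {64, 32} for operand A (opIdx == 0) and {32, 64} for operand B. The WMMA override scaled warpsPerCTA by getMNKDimPerInstr(); assuming the usual 16x16 WMMA instruction tile, warpsPerCTA = [4, 1] gave a per-CTA tile of [64, 16].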
@@ -1816,21 +1759,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  auto parentShapePerCTA = getShapePerCTATile(shape);
-  auto rank = shape.size();
-  assert(rank == 2);
-  if (opIdx == 0) {
-    return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
-  } else if (opIdx == 1) {
-    return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-}
-
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
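Unlike its MFMA counterpart, the WMMA operand override deleted here kept the operand's full K extent from the tensor shape: for opIdx == 0 it returned {parentShapePerCTA[0], shape[1]}, so shape = [128, 64] with a parent tile of [64, 16] yielded {64, 64}.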
@@ -1949,24 +1877,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
   llvm_unreachable("Unexpected mma version");
 }
 
-SmallVector<unsigned>
-NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  if (isAmpere()) {
-    auto warpsPerCTA = getWarpsPerCTA();
-    auto rank = warpsPerCTA.size();
-    SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
-                                          warpsPerCTA.end());
-    shapePerCTATile[rank - 1] *= 8;
-    shapePerCTATile[rank - 2] *= 16;
-    return shapePerCTATile;
-  }
-  if (isHopper()) {
-    auto instrShape = getInstrShape();
-    return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
-  }
-  llvm::report_fatal_error("Unexpected MMA layout version found");
-}
-
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
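Worked example for the deletion above: the Ampere branch scaled warpsPerCTA by the 16x8 footprint of the m16n8 mma instructions, so warpsPerCTA = [4, 1] gave a tile of [64, 8]; the Hopper branch instead used 16 rows per warp and the instruction's own N extent, instrShape[1].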
@@ -2007,16 +1917,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
   }
 }
 
-SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
-    ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
-  assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
-  auto shapePerCTATile = getShapePerCTATile(shape);
-  auto rank = shapePerCTATile.size();
-  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
-  // 4 threads * 2 subtiles
-  shapePerCTATile[kDim] = kWidth * 2 * 4;
-  return shapePerCTATile;
-}
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   auto rank = getWarpsPerCTA().size();
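In the operand override deleted above, the K extent of the tile is kWidth * 2 * 4: the four threads of a quad along K times two subtiles, each contributing kWidth contiguous elements, per the comment in the deleted code. For kWidth = 4 that works out to 4 * 2 * 4 = 32 elements along K, while the non-K dimension keeps its parent tile extent.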