@@ -204,12 +204,25 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
   }
   return ret;
 }
-
-SmallVector<unsigned> getShapePerCTATile(Attribute layout,
-                                         ArrayRef<int64_t> tensorShape) {
+SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
-    return distributedLayout.getShapePerCTATile(tensorShape);
+    auto sizePerThread = distributedLayout.getSizePerThread();
+    auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
+    // ThreadsPerWarp does not align with this function for slice layout
+    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
+      threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
+      threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
+    }
+    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
+    assert(sizePerThread.size() == threadsPerWarp.size() &&
+           sizePerThread.size() == warpsPerCTA.size());
+    SmallVector<unsigned> shape;
+    for (auto [size, thread, warp] :
+         llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
+      shape.push_back(size * thread * warp);
+    }
+    return shape;
   } else {
     llvm::report_fatal_error("getShapePerCTATile not implemented");
     return SmallVector<unsigned>();
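
Editor's note (not part of the diff): the new generic path computes the per-CTA tile as the elementwise product sizePerThread x threadsPerWarp x warpsPerCTA, which is exactly what the deleted per-attribute overrides below computed dimension by dimension. A minimal standalone sketch of the computation; the layout parameters are illustrative, not taken from the commit:

    // Hypothetical 2-D blocked layout: each thread owns 1x4 elements,
    // a warp is 8x4 threads, and the CTA holds 2x2 warps.
    #include <cstdio>
    #include <vector>
    int main() {
      std::vector<unsigned> sizePerThread = {1, 4};
      std::vector<unsigned> threadsPerWarp = {8, 4};
      std::vector<unsigned> warpsPerCTA = {2, 2};
      std::vector<unsigned> tile;
      for (size_t d = 0; d < sizePerThread.size(); ++d)
        tile.push_back(sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d]);
      std::printf("%u x %u\n", tile[0], tile[1]); // prints "16 x 32"
    }

The tensor shape no longer participates: the tile shape is a property of the layout alone, which is why the tensorShape parameter could be dropped.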
@@ -691,14 +704,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
 SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
   return SmallVector<unsigned>(getSizePerThread__());
 }
-SmallVector<unsigned>
-BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape;
-  for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
-    shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
-                    getWarpsPerCTA()[d]);
-  return shape;
-}
 
 template <class T>
 SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -800,12 +805,6 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   sizePerThread.erase(sizePerThread.begin() + getDim());
   return sizePerThread;
 }
-SmallVector<unsigned>
-SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
-  shape.erase(shape.begin() + getDim());
-  return shape;
-}
 
 //
 
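Editor's note (not part of the diff): the deleted SliceEncodingAttr override computed the parent's tile and then erased the sliced dimension. The generic helper in the first hunk reproduces this: the slice's own getSizePerThread and getWarpsPerCTA already drop the sliced dimension, and the slice special case fetches threadsPerWarp from the parent and erases the same dimension. A comment-only worked example with assumed values:

    // Assumed parent blocked layout (illustrative):
    //   sizePerThread {1, 4}, threadsPerWarp {8, 4}, warpsPerCTA {2, 2}
    //   parent tile = {16, 32}
    // Slicing dim 0:
    //   old path: erase dim 0 from the parent tile      -> {32}
    //   new path: {4} * {4} * {2} (dim 0 erased in each) -> {32}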
@@ -999,9 +998,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
-    auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
+    auto shapePerCTATile = getShapePerCTATile(blockedLayout);
     auto order = blockedLayout.getOrder();
-    auto sizePerThread = ::getSizePerThread(blockedLayout);
+    auto sizePerThread = blockedLayout.getSizePerThread();
 
     int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
     int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
@@ -1072,19 +1071,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
                                  /*kMajor*/ true);
   }
 }
-SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
-    ArrayRef<int64_t> tensorShape) const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
-    return parentMmaLayout.getShapePerCTATileForOperand(
-        tensorShape, getKWidth(), getOpIdx());
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
-        "supported yet");
-  }
-}
 
 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1606,16 +1592,6 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 // TODO: there is a lot of common code with MmaEncoding here
 
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-  shapePerCTATile[rank - 1] *= getMDim();
-  shapePerCTATile[rank - 2] *= getNDim();
-  return shapePerCTATile;
-}
-
 SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -1759,43 +1735,10 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
-SmallVector<unsigned>
-AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  assert(getMDim() == 32 || getMDim() == 16);
-  auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto rank = parentShapePerCTATile.size();
-  if (opIdx == 0) {
-    if (rank == 2)
-      return {parentShapePerCTATile[rank - 2], 32};
-    else
-      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
-  } else if (opIdx == 1) {
-    if (rank == 2)
-      return {32, parentShapePerCTATile[rank - 1]};
-    else
-      return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-  llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
-}
-
 //===----------------------------------------------------------------------===//
 // Wmma encoding
 //===----------------------------------------------------------------------===//
 
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  auto warpsPerCTA = getWarpsPerCTA();
-  auto rank = warpsPerCTA.size();
-  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
-
-  auto mnkDim = getMNKDimPerInstr();
-  shapePerCTATile[rank - 2] *= mnkDim[0];
-  shapePerCTATile[rank - 1] *= mnkDim[1];
-  return shapePerCTATile;
-}
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
   auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -1860,21 +1803,6 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
-SmallVector<unsigned>
-AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
-                                                  int kWidth, int opIdx) const {
-  auto parentShapePerCTA = getShapePerCTATile(shape);
-  auto rank = shape.size();
-  assert(rank == 2);
-  if (opIdx == 0) {
-    return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
-  } else if (opIdx == 1) {
-    return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
-}
-
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
@@ -1993,24 +1921,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
   llvm_unreachable("Unexpected mma version");
 }
 
-SmallVector<unsigned>
-NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
-  if (isAmpere()) {
-    auto warpsPerCTA = getWarpsPerCTA();
-    auto rank = warpsPerCTA.size();
-    SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
-                                          warpsPerCTA.end());
-    shapePerCTATile[rank - 1] *= 8;
-    shapePerCTATile[rank - 2] *= 16;
-    return shapePerCTATile;
-  }
-  if (isHopper()) {
-    auto instrShape = getInstrShape();
-    return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
-  }
-  llvm::report_fatal_error("Unexpected MMA layout version found");
-}
-
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
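
Editor's note (not part of the diff): the hard-coded Ampere multipliers removed above (16 along rank-2, 8 along rank-1) fall out of the generic product, assuming the usual Triton mma v2 parameters of sizePerThread {2, 2} and threadsPerWarp {8, 4}; those exact values are an assumption here, not stated in the commit:

    // Sanity check with assumed Ampere mma v2 parameters.
    unsigned sizePerThread[2] = {2, 2};  // assumed
    unsigned threadsPerWarp[2] = {8, 4}; // assumed
    unsigned warpsPerCTA[2] = {4, 1};    // example CTA
    // generic: {2*8*4, 2*4*1} = {64, 8}
    // deleted: {16*4,  8*1}   = {64, 8} -> identical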
@@ -2051,16 +1961,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
   }
 }
 
-SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
-    ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
-  assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
-  auto shapePerCTATile = getShapePerCTATile(shape);
-  auto rank = shapePerCTATile.size();
-  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
-  // 4 threads * 2 subtiles
-  shapePerCTATile[kDim] = kWidth * 2 * 4;
-  return shapePerCTATile;
-}
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   auto rank = getWarpsPerCTA().size();