@@ -204,25 +204,12 @@ SmallVector<unsigned> getUniqueContigPerThread(Attribute layout,
   }
   return ret;
 }
-SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
+
+SmallVector<unsigned> getShapePerCTATile(Attribute layout,
+                                         ArrayRef<int64_t> tensorShape) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
-    auto sizePerThread = distributedLayout.getSizePerThread();
-    auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
-    // ThreadsPerWarp does not align with this function for slice layout
-    if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
-      threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
-      threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
-    }
-    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
-    assert(sizePerThread.size() == threadsPerWarp.size() &&
-           sizePerThread.size() == warpsPerCTA.size());
-    SmallVector<unsigned> shape;
-    for (auto [size, thread, warp] :
-         llvm::zip(sizePerThread, threadsPerWarp, warpsPerCTA)) {
-      shape.push_back(size * thread * warp);
-    }
-    return shape;
+    return distributedLayout.getShapePerCTATile(tensorShape);
   } else {
     llvm::report_fatal_error("getShapePerCTATile not implemented");
     return SmallVector<unsigned>();
@@ -704,6 +691,14 @@ SmallVector<unsigned> BlockedEncodingAttr::getThreadOrder() const {
 SmallVector<unsigned> BlockedEncodingAttr::getSizePerThread() const {
   return SmallVector<unsigned>(getSizePerThread__());
 }
+SmallVector<unsigned>
+BlockedEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  SmallVector<unsigned> shape;
+  for (unsigned d = 0, n = getOrder().size(); d < n; ++d)
+    shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] *
+                    getWarpsPerCTA()[d]);
+  return shape;
+}
 
 template <class T>
 SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
@@ -805,6 +800,12 @@ SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   sizePerThread.erase(sizePerThread.begin() + getDim());
   return sizePerThread;
 }
+SmallVector<unsigned>
+SliceEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  SmallVector<unsigned> shape = ::getShapePerCTATile(getParent(), tensorShape);
+  shape.erase(shape.begin() + getDim());
+  return shape;
+}
 
 //
 
@@ -998,9 +999,9 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
-    auto shapePerCTATile = getShapePerCTATile(blockedLayout);
+    auto shapePerCTATile = ::getShapePerCTATile(blockedLayout);
     auto order = blockedLayout.getOrder();
-    auto sizePerThread = blockedLayout.getSizePerThread();
+    auto sizePerThread = ::getSizePerThread(blockedLayout);
 
     int K = getOpIdx() == 0 ? shapePerCTA[1] : shapePerCTA[0];
     int otherDim = getOpIdx() == 1 ? shapePerCTA[1] : shapePerCTA[0];
@@ -1071,6 +1072,19 @@ SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
                                  /*kMajor*/ true);
   }
 }
+SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
+    ArrayRef<int64_t> tensorShape) const {
+  auto parentLayout = getParent();
+  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
+  if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
+    return parentMmaLayout.getShapePerCTATileForOperand(
+        tensorShape, getKWidth(), getOpIdx());
+  } else {
+    llvm::report_fatal_error(
+        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
+        "supported yet");
+  }
+}
 
 LogicalResult DotOperandEncodingAttr::verify(
     ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1592,6 +1606,16 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
 //===----------------------------------------------------------------------===//
 // TODO: there is a lot of common code with MmaEncoding here
 
+SmallVector<unsigned>
+AMDMfmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  auto warpsPerCTA = getWarpsPerCTA();
+  auto rank = warpsPerCTA.size();
+  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
+  shapePerCTATile[rank - 1] *= getMDim();
+  shapePerCTATile[rank - 2] *= getNDim();
+  return shapePerCTATile;
+}
+
 SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }
@@ -1735,10 +1759,43 @@ AMDMfmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
+SmallVector<unsigned>
+AMDMfmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
+                                                  int kWidth, int opIdx) const {
+  assert(getMDim() == 32 || getMDim() == 16);
+  auto parentShapePerCTATile = getShapePerCTATile(shape);
+  auto rank = parentShapePerCTATile.size();
+  if (opIdx == 0) {
+    if (rank == 2)
+      return {parentShapePerCTATile[rank - 2], 32};
+    else
+      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2], 32};
+  } else if (opIdx == 1) {
+    if (rank == 2)
+      return {32, parentShapePerCTATile[rank - 1]};
+    else
+      return {parentShapePerCTATile[0], 32, parentShapePerCTATile[rank - 1]};
+  } else {
+    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
+  }
+  llvm_unreachable("DotOperandEncodingAttr opIdx must be 0 or 1");
+}
+
 //===----------------------------------------------------------------------===//
 // Wmma encoding
 //===----------------------------------------------------------------------===//
 
+SmallVector<unsigned>
+AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  auto warpsPerCTA = getWarpsPerCTA();
+  auto rank = warpsPerCTA.size();
+  SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(), warpsPerCTA.end());
+
+  auto mnkDim = getMNKDimPerInstr();
+  shapePerCTATile[rank - 2] *= mnkDim[0];
+  shapePerCTATile[rank - 1] *= mnkDim[1];
+  return shapePerCTATile;
+}
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
   auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
@@ -1803,6 +1860,21 @@ AMDWmmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   return sizePerThread;
 }
 
+SmallVector<unsigned>
+AMDWmmaEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
+                                                  int kWidth, int opIdx) const {
+  auto parentShapePerCTA = getShapePerCTATile(shape);
+  auto rank = shape.size();
+  assert(rank == 2);
+  if (opIdx == 0) {
+    return {parentShapePerCTA[0], static_cast<unsigned>(shape[1])};
+  } else if (opIdx == 1) {
+    return {static_cast<unsigned>(shape[0]), parentShapePerCTA[1]};
+  } else {
+    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
+  }
+}
+
 unsigned AMDWmmaEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
   auto rep = getRepForOperand(shape, eltTy, kWidth, opIdx);
@@ -1921,6 +1993,24 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
   llvm_unreachable("Unexpected mma version");
 }
 
+SmallVector<unsigned>
+NvidiaMmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
+  if (isAmpere()) {
+    auto warpsPerCTA = getWarpsPerCTA();
+    auto rank = warpsPerCTA.size();
+    SmallVector<unsigned> shapePerCTATile(warpsPerCTA.begin(),
+                                          warpsPerCTA.end());
+    shapePerCTATile[rank - 1] *= 8;
+    shapePerCTATile[rank - 2] *= 16;
+    return shapePerCTATile;
+  }
+  if (isHopper()) {
+    auto instrShape = getInstrShape();
+    return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]};
+  }
+  llvm::report_fatal_error("Unexpected MMA layout version found");
+}
+
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
   auto rank = getWarpsPerCTA().size();
@@ -1961,6 +2051,16 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
   }
 }
 
+SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
+    ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
+  assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
+  auto shapePerCTATile = getShapePerCTATile(shape);
+  auto rank = shapePerCTATile.size();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
+  // 4 threads * 2 subtiles
+  shapePerCTATile[kDim] = kWidth * 2 * 4;
+  return shapePerCTATile;
+}
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   auto rank = getWarpsPerCTA().size();
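
For context, a minimal sketch of how a call site might use the reinstated helper after this change; the wrapper function name and the mlir::triton::gpu namespace qualification are illustrative assumptions, not part of the commit:

// Hypothetical helper: query the tile shape covered by one CTA for a tensor's
// layout. The free getShapePerCTATile(layout, tensorShape) now simply
// dispatches to the layout's DistributedEncodingTrait implementation.
SmallVector<unsigned> getCTATileForTensor(RankedTensorType tensorTy) {
  Attribute layout = tensorTy.getEncoding();
  return mlir::triton::gpu::getShapePerCTATile(layout, tensorTy.getShape());
}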