@@ -217,7 +217,7 @@ bool isExpensiveView(Type srcType, Type dstType) {
   return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
 }
 
-/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
+/* Utility function used by get.*Order methods of SliceEncodingAttr.
  * Erase dim and decrease all values larger than dim by 1.
  * Example:    order = [0, 2, 4, 3, 1], dim = 2
  *          resOrder = [0, 3, 2, 1]
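The body of `eraseOrder` is not shown in this hunk; as a reference for the Slice changes below, here is a minimal sketch consistent with the documented example (a hypothetical reconstruction, not necessarily the actual implementation):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Drop `dim` from `order` and renumber: every value larger than `dim`
// decreases by 1, so the result is a valid order over rank - 1 dims.
// eraseOrder({0, 2, 4, 3, 1}, 2) == {0, 3, 2, 1}
llvm::SmallVector<unsigned> eraseOrder(llvm::ArrayRef<unsigned> order,
                                       unsigned dim) {
  llvm::SmallVector<unsigned> resOrder;
  for (unsigned d : order) {
    if (d == dim)
      continue; // erase the sliced dimension
    resOrder.push_back(d > dim ? d - 1 : d);
  }
  return resOrder;
}
```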
@@ -262,29 +262,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
 }
 
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
-  if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getWarpOrder(dotLayout.getParent());
-    }
-  }
-  auto order = getOrder(layout);
-  // FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
-  // M-major This is awkward. Since we can choose any warpOrder in Ampere, we
-  // should probably choose M-major and change `LinearLayoutConversion.cpp` and
-  // `MMAv2.cpp` to match.
-  if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    if (mmaLayout.isHopper()) {
-      // Hopper MMA instructions force warps to be column-major
-      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
-      return getMatrixOrder(order.size(), /*rowMajor*/ false);
-    }
-  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // It's quite weird to talk about warp order when that the warps
-    // are broadcasted along the K dimension
-    llvm::report_fatal_error(
-        "DotOperandEncoding::getWarpOrder not implemented");
-  }
-  return order;
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getWarpOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getWarpOrder");
+  return {};
 }
 
 SmallVector<unsigned> getOrder(Attribute layout) {
@@ -293,7 +275,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
     // Order doesn't really matter. We just have to be consistent when unpacking
-    // the elements in the MMAv2/V3 lowerings. We choose row-major
+    // the output elements in the LLVM lowerings. We choose row-major
     auto distributedLayout = cast<DistributedEncodingTrait>(layout);
     auto rank = distributedLayout.getWarpsPerCTA().size();
     return getMatrixOrder(rank, /*rowMajor*/ true);
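`getMatrixOrder` itself is not part of this diff. Assuming the usual Triton convention that `order[0]` is the fastest-varying dimension, here is a sketch of the behavior the new code relies on (an assumption-based reconstruction, not the repository's definitive implementation):

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include "llvm/ADT/SmallVector.h"

// Sketch of getMatrixOrder under the stated assumption: batch dimensions
// stay slowest; rowMajor makes the last dim fastest, otherwise the
// second-to-last dim is fastest.
// getMatrixOrder(2, /*rowMajor*/ true)  == {1, 0}   (row-major)
// getMatrixOrder(2, /*rowMajor*/ false) == {0, 1}   (column-major)
// getMatrixOrder(3, /*rowMajor*/ true)  == {2, 1, 0}
llvm::SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  assert(rank >= 2);
  llvm::SmallVector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0u); // {rank-1, ..., 1, 0}
  if (!rowMajor)
    std::swap(order[0], order[1]); // make the row dimension fastest
  return order;
}
```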
@@ -318,15 +300,15 @@ SmallVector<unsigned> getOrder(Attribute layout) {
 
   llvm::report_fatal_error("Unimplemented usage of getOrder");
   return {};
-};
+}
 
 SmallVector<unsigned> getThreadOrder(Attribute layout) {
   if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
     return distributedLayout.getThreadOrder();
   else
     llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
   return {};
-};
+}
 
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
@@ -769,7 +751,8 @@ SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
   return warpsPerCTA;
 }
 SmallVector<unsigned> SliceEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto parentWarpOrder = ::getWarpOrder(getParent());
+  return eraseOrder(parentWarpOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   auto parent = getParent();
@@ -781,7 +764,8 @@ SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   return threadsPerWarp;
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto parentThreadOrder = ::getThreadOrder(getParent());
+  return eraseOrder(parentThreadOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   auto sizePerThread = ::getSizePerThread(getParent());
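A worked example of the new Slice logic, reusing the `eraseOrder` sketch above (the parent order values and the function name here are purely illustrative):

```cpp
#include <cassert>
#include "llvm/ADT/SmallVector.h"

// A rank-3 parent with threadOrder {2, 1, 0}, sliced along dim = 1:
// eraseOrder drops the 1 and renumbers 2 -> 1, giving {1, 0} for the
// rank-2 slice. The same recipe applies to getWarpOrder above.
void sliceOrderExample() {
  llvm::SmallVector<unsigned> parentThreadOrder = {2, 1, 0};
  auto sliced = eraseOrder(parentThreadOrder, /*dim=*/1);
  assert((sliced == llvm::SmallVector<unsigned>{1, 0}));
}
```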
@@ -1049,7 +1033,14 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
   return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  // FIXME(Lezcano): Preexisting. Do we want to have this path at all?
+  if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
+    return ::getWarpOrder(getParent());
+  }
+  // It's quite weird to talk about warp order when the warps
+  // are broadcasted along the K dimension
+  llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
+  return {};
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
@@ -1597,7 +1588,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
   auto order = ::getOrder(*this);
@@ -1766,7 +1757,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
   return ::getOrder(*this);
@@ -1890,7 +1881,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto rank = getWarpsPerCTA().size();
+  // Hopper (wgmma) uses column-major, as this is embedded in the instruction.
+  // For Ampere we can choose either row-major or column-major.
+  // We choose row-major to match the legacy path.
+  return getMatrixOrder(rank, /*rowMajor*/ !isHopper());
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
   auto rank = getWarpsPerCTA().size();
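Concretely, for a rank-2 layout the new code picks the following orders (using the `getMatrixOrder` sketch from earlier; the check function name is illustrative):

```cpp
#include <cassert>
#include "llvm/ADT/SmallVector.h"

// Ampere: row-major warp order; Hopper: wgmma forces column-major warps,
// matching the PTX wgmma fragment layout linked in the removed comment.
void nvidiaWarpOrderExample() {
  assert((getMatrixOrder(2, /*rowMajor*/ true) == // Ampere
          llvm::SmallVector<unsigned>{1, 0}));
  assert((getMatrixOrder(2, /*rowMajor*/ false) == // Hopper
          llvm::SmallVector<unsigned>{0, 1}));
}
```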
@@ -1914,10 +1909,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
       "getThreadsPerWarp not implemented for unknown Mma version ");
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto rank = getWarpsPerCTA().size();
+  return getMatrixOrder(rank, /*rowMajor*/ true);
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
-  auto rank = ::getOrder(*this).size();
+  auto rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
   if (isAmpere()) {
     res[rank - 2] = 2;
@@ -2158,11 +2154,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   if (opIdx == 0) {
     sizePerThread[rank - 2] = 2;
     sizePerThread[rank - 1] = 2 * kWidth;
-  } else if (opIdx == 1) {
+  } else {
+    assert(opIdx == 1);
     sizePerThread[rank - 2] = 2 * kWidth;
     sizePerThread[rank - 1] = 1;
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
   return sizePerThread;
 }
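A quick sanity check of the branch above for a rank-2 operand with kWidth = 2 (hypothetical values and function name, just to illustrate the resulting shapes):

```cpp
#include <cassert>
#include "llvm/ADT/SmallVector.h"

// opIdx == 0 (A operand): {2, 2 * kWidth} == {2, 4}
// opIdx == 1 (B operand): {2 * kWidth, 1} == {4, 1}
void sizePerThreadForOperandExample() {
  int kWidth = 2;
  unsigned rank = 2;
  llvm::SmallVector<unsigned> a(rank, 1), b(rank, 1);
  a[rank - 2] = 2;          // opIdx == 0
  a[rank - 1] = 2 * kWidth;
  b[rank - 2] = 2 * kWidth; // opIdx == 1
  b[rank - 1] = 1;
  assert((a == llvm::SmallVector<unsigned>{2, 4}));
  assert((b == llvm::SmallVector<unsigned>{4, 1}));
}
```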