@@ -220,7 +220,7 @@ bool isExpensiveView(Type srcType, Type dstType) {
   return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
 }
 
-/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
+/* Utility function used by get.*Order methods of SliceEncodingAttr.
  * Erase dim and decrease all values larger than dim by 1.
  * Example: order = [0, 2, 4, 3, 1], dim = 2
  *       resOrder = [0, 3, 2, 1]
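`eraseOrder` itself is not part of this hunk. A minimal sketch consistent with the comment above (the `static` placement and exact signature are assumptions):

```cpp
// Sketch only: drop `dim` from `order` and renumber larger dims down by one,
// matching the example above: eraseOrder([0, 2, 4, 3, 1], 2) == [0, 3, 2, 1].
static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
                                        unsigned dim) {
  SmallVector<unsigned> resOrder;
  for (unsigned d : order) {
    if (d == dim)
      continue; // erase the sliced dimension
    resOrder.push_back(d > dim ? d - 1 : d);
  }
  return resOrder;
}
```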
@@ -265,29 +265,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
 }
 
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
-  if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getWarpOrder(dotLayout.getParent());
-    }
-  }
-  auto order = getOrder(layout);
-  // FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
-  // M-major. This is awkward. Since we can choose any warpOrder in Ampere, we
-  // should probably choose M-major and change `LinearLayoutConversion.cpp` and
-  // `MMAv2.cpp` to match.
-  if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    if (mmaLayout.isHopper()) {
-      // Hopper MMA instructions force warps to be column-major
-      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
-      return getMatrixOrder(order.size(), /*rowMajor*/ false);
-    }
-  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // It's quite weird to talk about warp order when the warps
-    // are broadcasted along the K dimension
-    llvm::report_fatal_error(
-        "DotOperandEncoding::getWarpOrder not implemented");
-  }
-  return order;
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getWarpOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getWarpOrder");
+  return {};
 }
 
 SmallVector<unsigned> getOrder(Attribute layout) {
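With this change the free function no longer special-cases concrete encodings: every distributed encoding implements `getWarpOrder` behind `DistributedEncodingTrait`, and the per-layout logic moves into the attribute classes below. A hypothetical call site (names illustrative):

```cpp
// Hypothetical usage: dispatch goes through the encoding's interface method.
Attribute layout = tensorType.getEncoding();
SmallVector<unsigned> warpOrder = getWarpOrder(layout);
```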
@@ -296,7 +278,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
     // Order doesn't really matter. We just have to be consistent when unpacking
-    // the elements in the MMAv2/V3 lowerings. We choose row-major
+    // the output elements in the LLVM lowerings. We choose row-major
     auto distributedLayout = cast<DistributedEncodingTrait>(layout);
     auto rank = distributedLayout.getWarpsPerCTA().size();
     return getMatrixOrder(rank, /*rowMajor*/ true);
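`getMatrixOrder` is not shown in this diff. From its call sites, a plausible implementation (an assumption, not the committed code) returns the reversed identity permutation, swapping the two minor dimensions for column-major, under Triton's convention that `order[0]` is the fastest-varying dimension:

```cpp
// Plausible sketch (assumes <numeric> for std::iota):
// rowMajor: [rank-1, rank-2, ..., 0]; column-major swaps the two minor dims.
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  assert(rank >= 2);
  SmallVector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0); // order = [rank-1, ..., 1, 0]
  if (!rowMajor)
    std::swap(order[0], order[1]);
  return order;
}
```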
@@ -331,15 +313,15 @@ SmallVector<unsigned> getOrder(Attribute layout) {
 
   llvm::report_fatal_error("Unimplemented usage of getOrder");
   return {};
-};
+}
 
 SmallVector<unsigned> getThreadOrder(Attribute layout) {
   if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
     return distributedLayout.getThreadOrder();
   else
     llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
   return {};
-};
+}
 
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
@@ -782,7 +764,8 @@ SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
   return warpsPerCTA;
 }
 SmallVector<unsigned> SliceEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto parentWarpOrder = ::getWarpOrder(getParent());
+  return eraseOrder(parentWarpOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   auto parent = getParent();
@@ -794,7 +777,8 @@ SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   return threadsPerWarp;
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto parentThreadOrder = ::getThreadOrder(getParent());
+  return eraseOrder(parentThreadOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   auto sizePerThread = ::getSizePerThread(getParent());
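The two hunks above make a sliced layout derive both orders from its parent and then drop the sliced dimension, rather than recomputing them from scratch. A worked example using the `eraseOrder` semantics documented earlier (values illustrative):

```cpp
// Parent warp order [2, 1, 0], slice dim = 1:
//   eraseOrder([2, 1, 0], 1) drops the 1 and renumbers 2 -> 1, giving [1, 0].
// getThreadOrder now follows the same parent-then-erase derivation.
```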
@@ -1065,7 +1049,14 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
   return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  // FIXME(Lezcano): Preexisting. Do we want to have this path at all?
+  if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
+    return ::getWarpOrder(getParent());
+  }
+  // It's quite weird to talk about warp order when the warps
+  // are broadcasted along the K dimension
+  llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
+  return {};
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   // FIXME: delete if branch for `DpasEncodingAttr` and provide more
@@ -1637,7 +1628,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
   auto order = ::getOrder(*this);
@@ -1806,7 +1797,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
   return ::getOrder(*this);
@@ -1930,7 +1921,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto rank = getWarpsPerCTA().size();
+  // Hopper (wgmma) uses column-major, as this is embedded in the instruction.
+  // For Ampere we can choose either row-major or column-major;
+  // we choose row-major, as the legacy path did.
+  return getMatrixOrder(rank, /*rowMajor*/ !isHopper());
 }
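Concretely, under the convention that `order[0]` is the fastest-varying dimension, the rank-2 results are (assuming the `getMatrixOrder` behavior sketched earlier):

```cpp
// Hopper (wgmma): getMatrixOrder(2, /*rowMajor*/ false) == {0, 1} // column-major
// Ampere:         getMatrixOrder(2, /*rowMajor*/ true)  == {1, 0} // row-major
```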
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
   auto rank = getWarpsPerCTA().size();
@@ -1954,10 +1949,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
       "getThreadsPerWarp not implemented for unknown Mma version");
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto rank = getWarpsPerCTA().size();
+  return getMatrixOrder(rank, /*rowMajor*/ true);
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
-  auto rank = ::getOrder(*this).size();
+  auto rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
   if (isAmpere()) {
     res[rank - 2] = 2;
@@ -2198,11 +2194,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   if (opIdx == 0) {
     sizePerThread[rank - 2] = 2;
     sizePerThread[rank - 1] = 2 * kWidth;
-  } else if (opIdx == 1) {
+  } else {
+    assert(opIdx == 1);
     sizePerThread[rank - 2] = 2 * kWidth;
     sizePerThread[rank - 1] = 1;
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
   return sizePerThread;
 }
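A quick sanity check of the restructured branch, reading the values straight off the code for a rank-2 operand with `kWidth = 2`:

```cpp
// opIdx == 0 (A operand): sizePerThread = {2, 2 * kWidth} = {2, 4}
// opIdx == 1 (B operand): sizePerThread = {2 * kWidth, 1} = {4, 1}
// Any other opIdx now fails the assert instead of the removed fatal error.
```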