@@ -1105,20 +1105,39 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
11051105 LinearLayout::identity1D (encoding.getCTASplitN (), kCol , dims[1 ]);
11061106 auto newEncoding = TensorMemoryEncodingAttr::get (
11071107 ctx, encoding.getBlockM (), encoding.getBlockN (),
1108- encoding.getColStride (), encoding.getCTASplitM (), 1 );
1108+ encoding.getColStride (), encoding.getCTASplitM (), 1 ,
1109+ encoding.getTwoCTAs ());
11091110 return tensorMemoryToLinearLayout (
11101111 {shape[0 ], shape[1 ] / encoding.getCTASplitN ()}, newEncoding) *
11111112 split;
11121113 }
11131114 if (encoding.getCTASplitM () > 1 ) {
1114- auto split =
1115- LinearLayout::identity1D (encoding.getCTASplitM (), kCol , dims[0 ]);
1115+ auto splitM = encoding.getCTASplitM ();
1116+ auto blockM = encoding.getBlockM ();
1117+ bool isM64TwoCTA = blockM == 64 && encoding.getTwoCTAs ();
1118+ if (isM64TwoCTA) {
1119+ // blockM == 64 and twoCTAs is laid out as the transpose of 128xblockN
1120+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-b
1121+ blockM *= 2 ;
1122+ splitM /= 2 ;
1123+ }
1124+ auto split = LinearLayout::identity1D (splitM, kCol , dims[0 ]);
11161125 auto newEncoding = TensorMemoryEncodingAttr::get (
1117- ctx, encoding.getBlockM (), encoding.getBlockN (),
1118- encoding.getColStride (), 1 , encoding.getCTASplitN ());
1119- return tensorMemoryToLinearLayout (
1120- {shape[0 ] / encoding.getCTASplitM (), shape[1 ]}, newEncoding) *
1121- split;
1126+ ctx, blockM, encoding.getBlockN (), encoding.getColStride (), 1 ,
1127+ encoding.getCTASplitN (), encoding.getTwoCTAs ());
1128+ auto ret =
1129+ tensorMemoryToLinearLayout ({shape[0 ] / splitM, shape[1 ]}, newEncoding) *
1130+ split;
1131+ // In this case, we swap the basis of the last row and last column as per
1132+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-bny
1133+ if (isM64TwoCTA) {
1134+ auto bases = ret.getBases ();
1135+ auto &rowBases = bases[kRow ];
1136+ auto &colBases = bases[kCol ];
1137+ std::swap (rowBases[rowBases.size () - 1 ], colBases[colBases.size () - 1 ]);
1138+ ret = LinearLayout (bases, ret.getOutDims (), ret.isSurjective ());
1139+ }
1140+ return ret;
11221141 }
11231142 assert (encoding.getCTASplitM () == 1 && encoding.getCTASplitN () == 1 );
11241143
0 commit comments