@@ -41,6 +41,17 @@ SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
4141 return ret;
4242}
4343
44+ // TODO Have order be a mandatory argument of standardOutDimNames.
45+ SmallVector<StringAttr> permuteDimNames (const SmallVector<StringAttr> &names,
46+ const SmallVector<unsigned > &order) {
47+ assert (names.size () == order.size ());
48+ SmallVector<StringAttr> ret;
49+ for (unsigned i : order) {
50+ ret.push_back (names[i]);
51+ }
52+ return ret;
53+ }
54+
4455void assertIsRegisterLayout (const LinearLayout &layout) {
4556 assert (layout.getNumInDims () > 0 );
4657 MLIRContext *ctx = layout.getInDimNames ().begin ()->getContext ();
@@ -281,15 +292,19 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,
281292
282293 MLIRContext *ctx = mma.getContext ();
283294 SmallVector<StringAttr> dimNames = standardOutDimNames (ctx, rank);
295+ auto orderedDimNames = permuteDimNames (dimNames, getOrder (mma));
296+ // `orderedDimNames` is `dimNames` permuted by `getOrder(mma)`; the assert
297+ assert (getOrder (mma) == getMatrixOrder (rank, /* rowMajor=*/ true ));
284298
285299 LinearLayout ctaLayout (
286300 {{S (" register" ), {{1 , 0 }, {0 , 8 }}},
287301 {S (" lane" ), {{2 , 0 }, {4 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }}}},
288- llvm::to_vector (llvm::reverse (ArrayRef (dimNames).take_back (2 ))));
289-
290- ctaLayout *= identityND (
291- S (" warp" ), mma.getWarpsPerCTA (),
292- llvm::to_vector (llvm::reverse (llvm::seq<unsigned >(rank))), dimNames);
302+ ArrayRef (orderedDimNames).take_front (2 ));
303+ assert (getWarpOrder (mma) == getMatrixOrder (rank, /* rowMajor=*/ true ));
304+ // FIXME(Lezcano). identityND should not have an `order` param as it's
305+ // redundant with the order of the out dims.
306+ ctaLayout *=
307+ identityND (S (" warp" ), mma.getWarpsPerCTA (), mma.getWarpOrder (), dimNames);
293308
294309 return combineCtaCgaWithShape (ctaLayout, mma.getCTALayout (), shape);
295310}
@@ -322,10 +337,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
322337 ctaLayout *= LinearLayout::identity1D (n / ctaLayout.getOutDimSize (S (" dim1" )),
323338 S (" register" ), S (" dim1" ));
324339
325- // Expand the `warp` dimension according to warpsPerCTA.
326- //
327- // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but
328- // this really does seem to be correct.
340+ // The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
341+ // Since the warpOrder needs to be M-major, we need to transpose the out
342+ // dimensions AND transpose the order
343+ // FIXME(Lezcano). identityND should not have an `order` param as it's
344+ // redundant. The order is already given by the order of the
345+ // out dims, and if it has an order, it shouldn't change the
346+ // order of the out dims.
347+ assert (getWarpOrder (mma) == SmallVector<unsigned >({0 , 1 }));
329348 ctaLayout *= identityND (S (" warp" ), mma.getWarpsPerCTA (), /* order=*/ {0 , 1 },
330349 {S (" dim0" ), S (" dim1" )})
331350 .transposeOuts (llvm::to_vector (ctaLayout.getOutDimNames ()));
@@ -843,18 +862,24 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
843862
844863LinearLayout ampereDotToLinearLayout (ArrayRef<int64_t > shape,
845864 DotOperandEncodingAttr dot) {
846- // TODO,BE. Implement ampereMMA in terms of this one
865+ // Note that, even though MMAv2 looks similar to this layout, they are only
866+ // the same at a register and lane level. The warps treatment is different!
847867 int rank = shape.size ();
848868 auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent ());
849869 int kWidth = dot.getKWidth ();
850870 bool isA = dot.getOpIdx () == 0 ;
851871
852- assert (mma.isAmpere ());
853872 assert ((rank == 2 && mma.getInstrShape () == ArrayRef<unsigned >({16 , 8 })) ||
854873 (rank == 3 && mma.getInstrShape () == ArrayRef<unsigned >({1 , 16 , 8 })));
874+ assert (mma.isAmpere ());
855875
856876 MLIRContext *ctx = mma.getContext ();
857- SmallVector<StringAttr> dimNames = standardOutDimNames (ctx, rank);
877+ // A and B have kMajor order
878+ assert (getOrder (dot) ==
879+ getOrderForDotOperand (dot.getOpIdx (), rank, /* kMajor=*/ true ));
880+
881+ auto kMajorDims =
882+ permuteDimNames (standardOutDimNames (ctx, rank), getOrder (dot));
858883
859884 // Implement A. For B transpose in the end
860885 std::vector<std::vector<int32_t >> registers;
@@ -881,24 +906,51 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
881906 }
882907 registers.push_back ({i, 0 });
883908
884- if (!isA) {
885- for (auto &r : registers) {
886- std::swap (r[0 ], r[1 ]);
909+ LinearLayout ctaLayout ({{S (" register" ), registers}, {S (" lane" ), lanes}},
910+ ArrayRef (kMajorDims ).take_front (2 ));
911+
912+ // Let warpsPerCTAMma = {2, 2}, then
913+ // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
914+ // assume warpOrder = {1, 0}
915+ // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
916+ // the C is owned as per the following layout:
917+ // C: 0 | 1
918+ // - | -
919+ // 2 | 3
920+ // In order to be able to compute C, we need the following warp tiling of
921+ // A and B:
922+ // A: 0 1 | 0 1 B: 0 2 | 1 3
923+ // - - | - - - - | - -
924+ // 2 3 | 2 3 0 2 | 1 3
925+ // In particular, for A and B we need to broadcast along K
926+
927+ assert (mma.getWarpOrder () == getMatrixOrder (rank, /* rowMajor=*/ true ));
928+ auto warpsPerCTAMma = mma.getWarpsPerCTA ();
929+ std::vector<std::vector<int32_t >> warps;
930+ if (isA) {
931+ for (int i = 1 ; i < warpsPerCTAMma[1 ]; i *= 2 ) {
932+ warps.push_back ({0 , 0 });
933+ }
934+ for (int i = 1 ; i < warpsPerCTAMma[0 ]; i *= 2 ) {
935+ warps.push_back ({0 , i});
936+ }
937+ } else {
938+ for (int i = 1 ; i < warpsPerCTAMma[1 ]; i *= 2 ) {
939+ warps.push_back ({0 , i});
887940 }
888- for (auto &l : lanes) {
889- std::swap (l[0 ], l[1 ]);
941+ for (int i = 1 ; i < warpsPerCTAMma[0 ]; i *= 2 ) {
942+ warps.push_back ({0 , 0 });
943+ }
944+ }
945+ if (rank == 3 ) {
946+ for (auto &w : warps) {
947+ w.push_back (0 );
890948 }
891949 }
892950
893- LinearLayout ctaLayout (
894- {{S (" register" ), registers}, {S (" lane" ), lanes}},
895- llvm::to_vector (llvm::reverse (ArrayRef (dimNames).take_back (2 ))));
896-
897- auto order = dot.getCTAOrder ();
898- assert (order[0 ] == rank - 1 && order[1 ] == rank - 2 );
899- ctaLayout *= identityND (S (" warp" ), dot.getWarpsPerCTA (), order, dimNames);
951+ ctaLayout *= LinearLayout ({{S (" warp" ), warps}}, kMajorDims );
900952
901- return combineCtaCgaWithShape (ctaLayout, mma. getCTALayout (), shape);
953+ return combineCtaCgaWithShape (ctaLayout, getCTALayout (dot ), shape);
902954}
903955
904956std::optional<LinearLayout>
@@ -907,7 +959,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
907959 if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
908960 return mfmaDotToLinearLayout (*this , shape);
909961 } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
910- if (mma.getVersionMajor () == 2 && mma. getVersionMinor () == 0 ) {
962+ if (mma.isAmpere () ) {
911963 return ampereDotToLinearLayout (shape, *this );
912964 }
913965 }
0 commit comments