@@ -280,78 +280,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
280280 return ret;
281281}
282282
// Builds the LinearLayout of an Ampere `NvidiaMmaEncodingAttr` for a tensor of
// the given `shape`.
//
// Asserted preconditions: `mma.isAmpere()`, rank is 2 or 3, the instruction
// shape is {16, 8} (rank 2) or {1, 16, 8} (rank 3), and both the rep order and
// the warp order equal `getMatrixOrder(rank, true)`.
//
// Returns the register/lane/warp layout after it has been combined with
// `mma.getCTALayout()` and `shape` by `combineCtaCgaWithShape`.
283- LinearLayout ampereMmaToLinearLayout (ArrayRef<int64_t > shape,
284- NvidiaMmaEncodingAttr mma) {
285- int rank = shape.size ();
286-
287- assert (mma.isAmpere ());
288- assert (rank == 2 || rank == 3 );
289- assert (mma.getInstrShape ().size () == rank);
290- assert ((rank == 2 && mma.getInstrShape () == ArrayRef<unsigned >({16 , 8 })) ||
291- (rank == 3 && mma.getInstrShape () == ArrayRef<unsigned >({1 , 16 , 8 })));
292-
293- MLIRContext *ctx = mma.getContext ();
294- SmallVector<StringAttr> dimNames = standardOutDimNames (ctx, rank);
295-
296- auto orderedDimNames = permuteDimNames (dimNames, mma.getRepOrder ());
297- assert (mma.getRepOrder () == getMatrixOrder (rank, /* rowMajor=*/ true ));
298-
// The register and lane bases are hard-coded: two register bases and five lane
// bases, expressed over the first two names of `orderedDimNames`.
299- LinearLayout ctaLayout (
300- {{S (" register" ), {{1 , 0 }, {0 , 8 }}},
301- {S (" lane" ), {{2 , 0 }, {4 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }}}},
302- ArrayRef (orderedDimNames).take_front (2 ));
303- assert (getWarpOrder (mma) == getMatrixOrder (rank, /* rowMajor=*/ true ));
304- // FIXME(Lezcano). identityND should not have an `order` param as it's
305- // redundant with the order of the out dims.
// Append the warp dimension, sized by `mma.getWarpsPerCTA()`.
306- ctaLayout *=
307- identityND (S (" warp" ), mma.getWarpsPerCTA (), mma.getWarpOrder (), dimNames);
308-
309- return combineCtaCgaWithShape (ctaLayout, mma.getCTALayout (), shape);
310- }
311-
// Builds the LinearLayout of a Hopper `NvidiaMmaEncodingAttr` for a tensor of
// the given `shape`.
//
// Asserted preconditions: `mma.isHopper()`, rank is 2, the product of
// `mma.getWarpsPerCTA()` is a multiple of 4, and the instruction shape has
// three entries (m, n, k) with m == 16, n in {8, 16, 32, 64, 128, 256} and
// k in {8, 16, 32}.
//
// Returns the register/lane/warp layout after it has been combined with
// `mma.getCTALayout()` and `shape` by `combineCtaCgaWithShape`.
312- LinearLayout hopperMmaToLinearLayout (ArrayRef<int64_t > shape,
313- NvidiaMmaEncodingAttr mma) {
314- int rank = shape.size ();
315- assert (mma.isHopper ());
316- assert (rank == 2 );
317-
318- // wgmma operates on groups of 4 warps.
319- assert (product (mma.getWarpsPerCTA ()) % 4 == 0 );
320-
321- // Check that it's a known MMA layout.
322- assert (mma.getInstrShape ().size () == 3 );
323- int m = mma.getInstrShape ()[0 ];
324- int n = mma.getInstrShape ()[1 ];
// NOTE(review): `k` is only read by the assertion on it below.
325- int k = mma.getInstrShape ()[2 ];
326- assert (m == 16 );
327- assert (n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256 );
328- assert (k == 8 || k == 16 || k == 32 );
329-
330- MLIRContext *ctx = mma.getContext ();
// Same hard-coded register and lane bases as the Ampere tile, with the out
// dims listed as (dim1, dim0).
331- LinearLayout ctaLayout (
332- {{S (" register" ), {{1 , 0 }, {0 , 8 }}},
333- {S (" lane" ), {{2 , 0 }, {4 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }}}},
334- {S (" dim1" ), S (" dim0" )});
335-
336- // Expand the `register` dimension so the size of dim1 matches `n`.
337- ctaLayout *= LinearLayout::identity1D (n / ctaLayout.getOutDimSize (S (" dim1" )),
338- S (" register" ), S (" dim1" ));
339-
340- // The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
341- // Since the warpOrder needs to be M-major, we need to transpose the out
342- // dimensions AND transpose the order
343- // FIXME(Lezcano). identityND should not have an `order` param as it's
344- // redundant. The order is already given by the order of the
345- // out dims, and if it has an order, it shouldn't change the
346- // order of the out dims.
347- assert (getWarpOrder (mma) == SmallVector<unsigned >({0 , 1 }));
348- ctaLayout *= identityND (S (" warp" ), mma.getWarpsPerCTA (), /* order=*/ {0 , 1 },
349- {S (" dim0" ), S (" dim1" )})
350- .transposeOuts (llvm::to_vector (ctaLayout.getOutDimNames ()));
351-
352- return combineCtaCgaWithShape (ctaLayout, mma.getCTALayout (), shape);
353- }
354-
355283LinearLayout sharedToLinearLayoutNoLeadingOffset (ArrayRef<int64_t > shape,
356284 SharedEncodingAttr shared) {
357285 assert (!shared.getHasLeadingOffset ());
@@ -779,13 +707,153 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
779707 return combineCtaCgaWithShape (ctaLayout, getCTALayout (), shape);
780708}
781709
// Builds the register/lane layout of a single tile.
//
// Parameters:
//   ctx       - context passed to `standardOutDimNames`.
//   tileShape - extent of the tile per dimension; its size is asserted to equal
//               the rank.
//   kWidth    - size of the first `register` factor along the inner dimension.
//   order     - `order[0]` is taken as the inner dim, `order[1]` as the outer.
//   repOrder  - its size defines the rank; it is also passed as the order of
//               the initial trivial (all-ones shape) `register` layout.
//
// Asserted preconditions: rank >= 2, the outer tile extent is a multiple of 8,
// and the inner tile extent is a multiple of `kWidth * 4`.
//
// Returns the product of the trivial layout with the five 1-D identity factors
// listed in the comment inside the body.
710+ LinearLayout nvidiaMmaTile (MLIRContext *ctx, ArrayRef<unsigned > tileShape,
711+ unsigned kWidth , ArrayRef<unsigned > order,
712+ ArrayRef<unsigned > repOrder) {
713+ // Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
714+ int rank = repOrder.size ();
715+ auto dimNames = standardOutDimNames (ctx, rank);
716+ auto trivialShape = SmallVector<unsigned >(rank, 1 );
717+ LinearLayout ctaLayout =
718+ identityND (S (" register" ), trivialShape, repOrder, dimNames);
719+
720+ assert (rank >= 2 );
721+ auto inner = order[0 ];
722+ auto outer = order[1 ];
723+
724+ assert (tileShape.size () == rank);
// m is the tile extent along the outer dim, n along the inner dim.
725+ int m = tileShape[outer];
726+ int n = tileShape[inner];
727+
728+ // The relative order of registers and lanes is given by:
729+ // - Inner dim: kWidth registers
730+ // - Inner dim: 4 lanes
731+ // - Outer dim: 8 lanes
732+ // - Outer dim: repeat m / 8 times
733+ // - Inner dim: repeat n / (kWidth * 4) times
734+ assert (m % 8 == 0 );
735+ assert (n % (kWidth * 4 ) == 0 );
736+ // There is at least one subtile on the inner-most dimension
737+ // FIXME. We should implement operator* in terms of operator*=
738+ // and chain *= instead of using *
// NOTE(review): `outDimNames` is computed here but not read again in this
// function; it looks like a candidate for removal.
739+ auto outDimNames = llvm::to_vector (ctaLayout.getOutDimNames ());
740+ ctaLayout = ctaLayout *
741+ LinearLayout::identity1D (kWidth , S (" register" ), dimNames[inner]) *
742+ LinearLayout::identity1D (4 , S (" lane" ), dimNames[inner]) *
743+ LinearLayout::identity1D (8 , S (" lane" ), dimNames[outer]) *
744+ LinearLayout::identity1D (m / 8 , S (" register" ), dimNames[outer]) *
745+ LinearLayout::identity1D (n / (kWidth * 4 ), S (" register" ),
746+ dimNames[inner]);
747+ return ctaLayout;
748+ }
749+
// Converts this NVIDIA MMA encoding into a LinearLayout for a tensor of the
// given `shape`.
//
// In the added (`+`) lines the tile shape is taken from `getInstrShape()`: the
// whole vector when `isAmpere()`, otherwise (asserted `isHopper()`) only its
// first two entries. The tile is built by `nvidiaMmaTile` with kWidth fixed to
// 2, multiplied by a `warp` identity layout, and finally combined with
// `getCTALayout()` and `shape` by `combineCtaCgaWithShape`.
// The removed (`-`) lines instead dispatched to `ampereMmaToLinearLayout` and
// `hopperMmaToLinearLayout`.
782750std::optional<LinearLayout>
783751NvidiaMmaEncodingAttr::toLinearLayout (ArrayRef<int64_t > shape) const {
752+ auto ctx = getContext ();
753+ int rank = shape.size ();
754+
755+ SmallVector<unsigned > tileShape;
784756 if (isAmpere ()) {
785- return ampereMmaToLinearLayout (shape, *this );
757+ // Ampere.getInstrShape() returns the tile shape
758+ tileShape = SmallVector<unsigned >(getInstrShape ());
759+ } else {
760+ assert (isHopper ());
761+ auto instrShapeMNK = getInstrShape ();
762+ tileShape = SmallVector<unsigned >({instrShapeMNK[0 ], instrShapeMNK[1 ]});
786763 }
787- if (isHopper ()) {
788- return hopperMmaToLinearLayout (shape, *this );
764+ // nvidiamma layout always assumes kWidth = 2
765+ constexpr auto kWidth = 2 ;
766+ auto ctaLayout =
767+ nvidiaMmaTile (ctx, tileShape, kWidth , getOrder (*this ), getRepOrder ());
768+
769+ // The triton orders are defined on [dim0, dim1, ...], so we need to pass
770+ // those dims. Then, for some reason, operator* requires the orders to match
771+ // so we need to reorder the outs to match
772+ // FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
773+ // The order in triton assumes the standardDims, so it should
774+ // use those.
775+ ctaLayout *= identityND (S (" warp" ), getWarpsPerCTA (), getWarpOrder (),
776+ standardOutDimNames (ctx, rank))
777+ .transposeOuts (llvm::to_vector (ctaLayout.getOutDimNames ()));
778+
779+ return combineCtaCgaWithShape (ctaLayout, getCTALayout (), shape);
780+ }
781+
// Builds the `warp` layout of a dot operand from the warp shape and warp order
// of its parent MMA layout.
//
// Parameters:
//   ctx          - context passed to `standardOutDimNames`.
//   mmaWarpShape - number of warps per dimension, indexed by dimension.
//   mmaWarpOrder - dimensions in the order they are visited; its size defines
//                  the rank.
//   isA          - true for operand A, false for operand B.
//
// For every dimension `d` in `mmaWarpOrder`, the dimension equal to `inner`
// contributes a `zeros1D` factor and every other dimension contributes an
// `identity1D` factor, each of size `mmaWarpShape[d]`.
782+ LinearLayout warpsNvidiaDot (MLIRContext *ctx, ArrayRef<unsigned > mmaWarpShape,
783+ ArrayRef<unsigned > mmaWarpOrder, bool isA) {
784+ // Let warpsPerCTAMma = {2, 2}, then
785+ // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
786+ // assume warpOrder = {1, 0}
787+ // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
788+ // the C is owned as per the following layout:
789+ // C: 0 | 1
790+ // - | -
791+ // 2 | 3
792+ // In order to be able to compute C, we need the following warp tiling of
793+ // A and B:
794+ // A: 0 1 | 0 1 B: 0 2 | 1 3
795+ // - - | - - - - | - -
796+ // 2 3 | 2 3 0 2 | 1 3
797+ // In other words, we need to broadcast along K
798+ auto rank = mmaWarpOrder.size ();
// `inner` is the last dim for A and the second-to-last dim for B.
799+ auto inner = isA ? rank - 1 : rank - 2 ;
// NOTE(review): `outer` is assigned but not read anywhere in this function.
800+ auto outer = isA ? rank - 2 : rank - 1 ;
801+ auto dimNames = standardOutDimNames (ctx, rank);
802+ auto trivialShape = SmallVector<unsigned >(rank, 1 );
803+ LinearLayout warpLayout =
804+ identityND (S (" warp" ), trivialShape, mmaWarpOrder, dimNames);
805+
806+ // We have to broadcast along the inner dimension
807+ // For A, when moving along M we go from 0 to 2.
808+ // For B, when moving along N we go from 0 to 1.
809+ // As such, choosing the order of A {1, 0}, gives us the correct broadcasting
810+ // Same happens if the mmaWarpOrder is {0, 1}, like in Hopper
811+ for (auto d : mmaWarpOrder) {
812+ if (d == inner) {
813+ warpLayout *=
814+ LinearLayout::zeros1D (mmaWarpShape[d], S (" warp" ), dimNames[d]);
815+ } else {
816+ warpLayout *=
817+ LinearLayout::identity1D (mmaWarpShape[d], S (" warp" ), dimNames[d]);
818+ }
819+ }
820+ return warpLayout;
821+ }
822+
// Builds the LinearLayout of a dot operand whose parent is an
// `NvidiaMmaEncodingAttr`, for a tensor of the given `shape`.
//
// The tile shape is all ones except for the last two dimensions:
//   operand A (opIdx == 0): 16 x (kWidth * 8)
//   operand B:              (kWidth * 8) x 8, and the parent is asserted to be
//                           Ampere.
// The tile layout from `nvidiaMmaTile` is multiplied by the warp layout from
// `warpsNvidiaDot` (with its out dims transposed to match) and then combined
// with `getCTALayout(dot)` and `shape` by `combineCtaCgaWithShape`.
823+ LinearLayout nvidiaDotToLinearLayout (ArrayRef<int64_t > shape,
824+ DotOperandEncodingAttr dot) {
825+ int rank = shape.size ();
826+ auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent ());
827+ int kWidth = dot.getKWidth ();
828+ bool isA = dot.getOpIdx () == 0 ;
829+ MLIRContext *ctx = mma.getContext ();
830+
831+ SmallVector<unsigned > tileShape (rank, 1 );
832+ if (isA) {
833+ tileShape[rank - 2 ] = 16 ;
834+ tileShape[rank - 1 ] = kWidth * 8 ;
835+ } else {
836+ // Hopper takes the rhs via shared memory
837+ assert (mma.isAmpere ());
838+ tileShape[rank - 2 ] = kWidth * 8 ;
839+ tileShape[rank - 1 ] = 8 ;
840+ }
841+ auto ctaLayout =
842+ nvidiaMmaTile (ctx, tileShape, kWidth , getOrder (dot), dot.getRepOrder ());
// The warp layout is derived from the parent MMA's warps-per-CTA and warp order.
843+ ctaLayout *=
844+ warpsNvidiaDot (ctx, mma.getWarpsPerCTA (), mma.getWarpOrder (), isA)
845+ .transposeOuts (llvm::to_vector (ctaLayout.getOutDimNames ()));
846+
847+ return combineCtaCgaWithShape (ctaLayout, getCTALayout (dot), shape);
848+ }
849+
// Converts this dot-operand encoding into a LinearLayout for a tensor of the
// given `shape`, dispatching on the type of the parent encoding:
//   AMDMfmaEncodingAttr   -> mfmaDotToLinearLayout
//   NvidiaMmaEncodingAttr -> nvidiaDotToLinearLayout
// Returns std::nullopt for any other parent.
850+ std::optional<LinearLayout>
851+ DotOperandEncodingAttr::toLinearLayout (ArrayRef<int64_t > shape) const {
852+ auto parent = getParent ();
// `mfmaLayout` and `mma` are only used as type tests; the helpers receive *this.
853+ if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
854+ return mfmaDotToLinearLayout (*this , shape);
855+ } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
856+ return nvidiaDotToLinearLayout (shape, *this );
789857 }
790858 return std::nullopt ;
791859}
@@ -860,116 +928,6 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
860928 return ret;
861929}
862930
// Builds the LinearLayout of a dot operand whose parent is an Ampere
// `NvidiaMmaEncodingAttr`, for a tensor of the given `shape`.
//
// Asserted preconditions: the parent instruction shape is {16, 8} (rank 2) or
// {1, 16, 8} (rank 3), the parent is Ampere, `dot.getRepOrder()` and
// `getOrder(dot)` both equal the kMajor order returned by
// `getOrderForDotOperand`, and the parent warp order equals
// `getMatrixOrder(rank, true)`.
//
// The register, lane and warp bases are assembled by hand as explicit basis
// vectors, then combined with `getCTALayout(dot)` and `shape` by
// `combineCtaCgaWithShape`.
863- LinearLayout ampereDotToLinearLayout (ArrayRef<int64_t > shape,
864- DotOperandEncodingAttr dot) {
865- // Note that, even though MMAv2 looks similar to this layout, they are just
866- // the same at a register and lane level. The warps treatment is different!
867- int rank = shape.size ();
868- auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent ());
869- int kWidth = dot.getKWidth ();
870- bool isA = dot.getOpIdx () == 0 ;
871-
872- assert ((rank == 2 && mma.getInstrShape () == ArrayRef<unsigned >({16 , 8 })) ||
873- (rank == 3 && mma.getInstrShape () == ArrayRef<unsigned >({1 , 16 , 8 })));
874- assert (mma.isAmpere ());
875-
876- MLIRContext *ctx = mma.getContext ();
877-
878- // The A and B operands are tiled in a kMajor fashion
879- auto kMajorOrder = dot.getRepOrder ();
880- assert (kMajorOrder ==
881- getOrderForDotOperand (dot.getOpIdx (), rank, /* kMajor=*/ true ));
882-
883- auto kMajorDims =
884- permuteDimNames (standardOutDimNames (ctx, rank), kMajorOrder );
885- // This agrees with the order of the elements, which means that we can share
886- // the code below for both A and B without having to perform any swaps
887- assert (getOrder (dot) == kMajorOrder );
888-
889- std::vector<std::vector<int32_t >> registers;
890- std::vector<std::vector<int32_t >> lanes;
// `i` is the running power-of-two stride along the first coordinate; it is
// shared by the register and lane bases pushed below.
891- int32_t i = 1 ;
892- // kWidth contiguous elements
893- while (i < kWidth ) {
894- registers.push_back ({i, 0 });
895- i *= 2 ;
896- }
897- // 4 threads per chunk
898- for (int j = 0 ; j < 2 ; j++) {
899- lanes.push_back ({i, 0 });
900- i *= 2 ;
901- }
902- // 8 threads going down
903- lanes.push_back ({0 , 1 });
904- lanes.push_back ({0 , 2 });
905- lanes.push_back ({0 , 4 });
906- // 2 tiles in column-major order
907- // Just one if it's the B operand
908- if (isA) {
909- registers.push_back ({0 , 8 });
910- }
911- registers.push_back ({i, 0 });
912-
913- LinearLayout ctaLayout ({{S (" register" ), registers}, {S (" lane" ), lanes}},
914- ArrayRef (kMajorDims ).take_front (2 ));
915-
// NOTE(review): the next comment says "assume warpOrder = {0, 1}" and then
// "Since warpOrder={1, 0}"; the two statements disagree - confirm which one
// is intended.
916- // Let warpsPerCTAMma = {2, 2}, then
917- // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
918- // assume warpOrder = {0, 1}
919- // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
920- // the C is owned as per the following layout:
921- // C: 0 | 1
922- // - | -
923- // 2 | 3
924- // In order to be able to compute C, we need the following warp tiling of
925- // A and B:
926- // A: 0 1 | 0 1 B: 0 2 | 1 3
927- // - - | - - - - | - -
928- // 2 3 | 2 3 0 2 | 1 3
929- // In particular, for A and B we need to broadcast along K
930-
931- assert (mma.getWarpOrder () == getMatrixOrder (rank, /* rowMajor=*/ true ));
932- auto warpsPerCTAMma = mma.getWarpsPerCTA ();
// Warp bases: a {0, 0} basis is a broadcast, a {0, i} basis advances the second
// coordinate. For A the bases driven by warpsPerCTAMma[1] are the broadcast
// ones; for B it is the bases driven by warpsPerCTAMma[0].
933- std::vector<std::vector<int32_t >> warps;
934- if (isA) {
935- for (int i = 1 ; i < warpsPerCTAMma[1 ]; i *= 2 ) {
936- warps.push_back ({0 , 0 });
937- }
938- for (int i = 1 ; i < warpsPerCTAMma[0 ]; i *= 2 ) {
939- warps.push_back ({0 , i});
940- }
941- } else {
942- for (int i = 1 ; i < warpsPerCTAMma[1 ]; i *= 2 ) {
943- warps.push_back ({0 , i});
944- }
945- for (int i = 1 ; i < warpsPerCTAMma[0 ]; i *= 2 ) {
946- warps.push_back ({0 , 0 });
947- }
948- }
// For rank 3, pad every warp basis with a zero third coordinate.
949- if (rank == 3 ) {
950- for (auto &w : warps) {
951- w.push_back (0 );
952- }
953- }
954-
955- ctaLayout *= LinearLayout ({{S (" warp" ), warps}}, kMajorDims );
956-
957- return combineCtaCgaWithShape (ctaLayout, getCTALayout (dot), shape);
958- }
959-
// Converts this dot-operand encoding into a LinearLayout for a tensor of the
// given `shape`, dispatching on the type of the parent encoding:
//   AMDMfmaEncodingAttr            -> mfmaDotToLinearLayout
//   NvidiaMmaEncodingAttr (Ampere) -> ampereDotToLinearLayout
// Returns std::nullopt otherwise, including for a non-Ampere NVIDIA MMA parent.
960- std::optional<LinearLayout>
961- DotOperandEncodingAttr::toLinearLayout (ArrayRef<int64_t > shape) const {
962- auto parent = getParent ();
963- if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
964- return mfmaDotToLinearLayout (*this , shape);
965- } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
966- if (mma.isAmpere ()) {
967- return ampereDotToLinearLayout (shape, *this );
968- }
969- }
970- return std::nullopt ;
971- }
972-
973931std::optional<LinearLayout>
974932toLinearLayout (ArrayRef<int64_t > shape, Attribute layout,
975933 std::optional<int32_t > elemBitWidth /* = std::nullopt*/ ) {
0 commit comments