@@ -32,7 +32,7 @@ namespace {
 
 #define S(v) StringAttr::get(ctx, (v))
 
-// Returns ["out0", "out1", ..., "out<rank-1>"].
+// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
 SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
   SmallVector<StringAttr> ret;
   for (int i = 0; i < rank; i++) {
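A minimal usage sketch of the helper whose comment was just corrected (editor's illustration, not part of the commit; it assumes the sketch would sit in this same anonymous namespace with an MLIRContext available):

// Hypothetical check: per the updated comment, rank == 3 yields "dim0",
// "dim1", "dim2", in that order.
void exampleStandardOutDimNames(MLIRContext *ctx) {
  SmallVector<StringAttr> names = standardOutDimNames(ctx, /*rank=*/3);
  assert(names.size() == 3);
  assert(names.front().str() == "dim0" && names.back().str() == "dim2");
}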
@@ -71,14 +71,18 @@ void assertIsRegisterLayout(const LinearLayout &layout) {
                          expectedOuts.end()));
 }
 
-// Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of
-// size product(shape) and then reshaping to permute(shape, order).
-LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
-                        ArrayRef<unsigned> order,
-                        ArrayRef<StringAttr> outDimNames) {
+// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
+// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
+// permute(shape, order).
+LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
+                                ArrayRef<unsigned> order) {
   assert(shape.size() == order.size());
-
   MLIRContext *ctx = inDimName.getContext();
+  auto rank = shape.size();
+
+  // The order in triton is written wrt. [dim0, dim1, ...].
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
   LinearLayout ret = LinearLayout::empty();
   for (int i = 0; i < shape.size(); i++) {
     // Start with the most-minor dimension, which is order[0].
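To make the renamed helper concrete, a hedged sketch derived from its own comment (a 1D -> 1D identity of size product(shape), reshaped to permute(shape, order)) and the "most-minor dimension is order[0]" note; the shape and the expected index arithmetic below are the editor's illustration, not asserted by the commit:

// Illustration only. With shape = {4, 2} over [dim0, dim1]:
//  - order = {0, 1}: dim0 is most minor, so a 1D index i in [0, 8) should
//      land at dim0 = i % 4, dim1 = i / 4;
//  - order = {1, 0}: dim1 is most minor, so dim1 = i % 2, dim0 = i / 2.
LinearLayout exampleIdentityStandardND(MLIRContext *ctx) {
  return identityStandardND(S("register"), /*shape=*/{4, 2}, /*order=*/{0, 1});
}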
@@ -491,7 +495,7 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // And each warp takes the same register and lane sub-layout. So multiply with
   // an identity layout for the warp.
   LinearLayout warpLayout =
-      identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
+      identityStandardND(S("warp"), getWarpsPerCTA(), order);
   LinearLayout ctaLayout = tileLayout * warpLayout;
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
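Since identityStandardND now derives the [dim0, dim1, ...] names itself, call sites like this one (and the WMMA and blocked ones below) only pass the shape and the order. A hedged before/after sketch of the call-site change, with placeholder variables standing in for whatever is in scope:

// Before (removed): the standard dim names had to be passed explicitly.
//   LinearLayout warpLayout =
//       identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
// After: the helper calls standardOutDimNames(ctx, rank) internally, so the
// result should match whenever outDimNames already held the standard names.
LinearLayout warpLayout = identityStandardND(S("warp"), getWarpsPerCTA(), order);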
@@ -601,8 +605,7 @@ mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
     tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
   }
 
-  LinearLayout warpLayout =
-      identityND(kWarp, warpsPerCTA, warpOrder, outDimNames);
+  LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
 
   LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
                            warpLayout.transposeOuts(outDimNames);
@@ -684,7 +687,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // And each warp takes the same register and lane sub-layout. So multiply with
   // an identity layout for the warp.
   LinearLayout warpLayout =
-      identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
+      identityStandardND(S("warp"), getWarpsPerCTA(), order);
   LinearLayout ctaLayout = tileLayout * warpLayout;
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
@@ -700,9 +703,9 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
   const auto &order = getOrder();
   LinearLayout ctaLayout =
-      identityND(S("register"), getSizePerThread(), order, outDimNames) *
-      identityND(S("lane"), getThreadsPerWarp(), order, outDimNames) *
-      identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
+      identityStandardND(S("register"), getSizePerThread(), order) *
+      identityStandardND(S("lane"), getThreadsPerWarp(), order) *
+      identityStandardND(S("warp"), getWarpsPerCTA(), order);
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }
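A hedged worked example of the blocked composition above, with illustrative attribute values the commit does not pin down (same register * lane * warp product, concrete numbers made up for exposition):

// Illustration only: a 2D blocked layout with dim1 fastest. Registers vary
// first, then lanes, then warps, mirroring the product above;
// combineCtaCgaWithShape then folds/broadcasts it onto the tensor shape.
LinearLayout exampleBlockedCta(MLIRContext *ctx) {
  SmallVector<unsigned> order = {1, 0}; // dim1 is the most-minor dimension
  return identityStandardND(S("register"), /*sizePerThread=*/{1, 4}, order) *
         identityStandardND(S("lane"), /*threadsPerWarp=*/{8, 4}, order) *
         identityStandardND(S("warp"), /*warpsPerCTA=*/{4, 1}, order);
}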
@@ -711,11 +714,12 @@ LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
                            ArrayRef<unsigned> repOrder) {
   // Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
+  // Like LinearLayout::empty() but with a rank and an order
   int rank = repOrder.size();
   auto dimNames = standardOutDimNames(ctx, rank);
   auto trivialShape = SmallVector<unsigned>(rank, 1);
   LinearLayout ctaLayout =
-      identityND(S("register"), trivialShape, repOrder, dimNames);
+      identityStandardND(S("register"), trivialShape, repOrder);
 
   assert(rank >= 2);
   auto inner = order[0];
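The added comment ("Like LinearLayout::empty() but with a rank and an order") can be read as follows: with trivialShape all ones, every dimension has size 1, so the layout maps the single index 0 to (0, ..., 0) and only fixes the rank and the out-dim order. A hedged restatement as a standalone sketch, with a hypothetical helper name:

// Illustration only: a data-movement-free seed layout that still records how
// many standard out dims exist and in which order they should appear.
LinearLayout exampleTrivialSeed(MLIRContext *ctx, ArrayRef<unsigned> repOrder) {
  auto trivialShape = SmallVector<unsigned>(repOrder.size(), 1);
  return identityStandardND(S("register"), trivialShape, repOrder);
}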
@@ -769,11 +773,7 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // The triton orders are defined on [dim0, dim1, ...], so we need to pass
   // those dims. Then, for some reason, operator* requires the orders to match,
   // so we need to reorder the outs to match.
-  // FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
-  //                 The order in triton assumes the standardDims, so it should
-  //                 use those.
-  ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(),
-                          standardOutDimNames(ctx, rank))
+  ctaLayout *= identityStandardND(S("warp"), getWarpsPerCTA(), getWarpOrder())
                    .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
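The surviving comment explains the .transposeOuts(...) call: operator* evidently requires both operands to list their out dims in the same order, so the warp factor built over the standard [dim0, dim1, ...] names is reordered to match ctaLayout before multiplying. The same pattern recurs at the stmatrix call sites below; a hedged sketch of it in isolation, assuming it sits alongside the helpers in this file:

// Pattern sketch with placeholder arguments: build a factor over the standard
// dim names, then reorder its outs to the target layout's order before *=.
void multiplyWithMatchingOuts(LinearLayout &ctaLayout, StringAttr warpDimName,
                              ArrayRef<unsigned> warpsPerCTA,
                              ArrayRef<unsigned> warpOrder) {
  ctaLayout *= identityStandardND(warpDimName, warpsPerCTA, warpOrder)
                   .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
}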
@@ -797,11 +797,8 @@ LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
   // In other words, we need to broadcast along K
   auto rank = mmaWarpOrder.size();
   auto inner = isA ? rank - 1 : rank - 2;
-  auto outer = isA ? rank - 2 : rank - 1;
   auto dimNames = standardOutDimNames(ctx, rank);
-  auto trivialShape = SmallVector<unsigned>(rank, 1);
-  LinearLayout warpLayout =
-      identityND(S("warp"), trivialShape, mmaWarpOrder, dimNames);
+  LinearLayout warpLayout = LinearLayout::empty();
 
   // We have to broadcast along the inner dimension
   // For A, when moving along M we go from 0 to 2.
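Two removals here: the unused `outer` variable, and the all-ones warp seed, now replaced by LinearLayout::empty(). empty() is the usual starting point for the *= accumulation idiom this file already relies on (identityStandardND above is built the same way); a minimal sketch of that idiom, with a hypothetical helper name:

// Accumulation idiom from this file: start empty and multiply factors in,
// most-minor factor first. With zero factors the result simply stays empty.
LinearLayout accumulateFactors(ArrayRef<LinearLayout> factors) {
  LinearLayout ret = LinearLayout::empty();
  for (const LinearLayout &factor : factors)
    ret *= factor;
  return ret;
}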
@@ -1086,9 +1083,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
 
   // Expand the `warp` dimension according to warpsPerCTA.
   auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
-  layout *=
-      identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
-          .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
+                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
 
   // Expand the `register` dimension so the size of columns matches `n`.
   int n = mma.getInstrShape()[1];
@@ -1126,9 +1122,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
       LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);
 
   // Expand the `warp` dimension according to warpsPerCTA.
-  layout *=
-      identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
-          .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
+                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
   auto ret =
       combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
   auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
@@ -1138,9 +1133,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
   ret = ensureLayoutNotSmallerThan(ret, namedTensorShape);
   ret = ensureLayoutNotLargerThan(ret, namedTensorShape);
   return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
-      .reshapeOuts({{S("offset"), ret.getTotalOutDimSize()},
-                    {S("iteration"), 1}}) *
-         identityND(kBlock, {1, 1}, {0, 1}, {S("offset"), S("iteration")});
+      .reshapeOuts(
+          {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
 } // anonymous namespace
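One closing note on the last hunk: the return value is now just the transposed layout with its outs flattened by reshapeOuts into one offset dimension of size ret.getTotalOutDimSize() plus a trivial iteration dimension; the former multiplication by a {1, 1} block identity is dropped. A hedged sketch of that flattening pattern on its own, with a placeholder layout and a hypothetical helper name:

// Sketch only: re-express all out dims of `someLayout` as (offset, iteration),
// where offset spans the whole output and iteration is a trivial size-1 dim.
LinearLayout flattenToOffset(MLIRContext *ctx, const LinearLayout &someLayout) {
  return someLayout.reshapeOuts(
      {{S("offset"), someLayout.getTotalOutDimSize()}, {S("iteration"), 1}});
}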