Commit 1214ac7
[LLs] [BE] Simplify identityND (#5199)
The auxiliary function `identityND` used to take an `order` parameter, which comes from Triton, and a set of output dimensions. Since the order in Triton is defined with respect to `dim0..dim<rank-1>`, the dimensions argument was redundant, which made the function quite confusing. Indeed, every use of `identityND` passed the canonical dimensions, except for one call that we simply remove because it was unnecessary. We drop the dims argument and always return a layout with output dims `dim0..dim<rank-1>`.
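For illustration, here is the representative call site from BlockedEncodingAttr::toLinearLayout, excerpted from the diff below; the explicit output-dim argument disappears because the order is already interpreted over dim0..dim<rank-1>:

    // Before: every caller had to spell out the (redundant) standard output dims.
    LinearLayout ctaLayout =
        identityND(S("register"), getSizePerThread(), order, outDimNames) *
        identityND(S("lane"), getThreadsPerWarp(), order, outDimNames) *
        identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);

    // After: identityStandardND builds the dim0..dim<rank-1> names itself.
    LinearLayout ctaLayout =
        identityStandardND(S("register"), getSizePerThread(), order) *
        identityStandardND(S("lane"), getThreadsPerWarp(), order) *
        identityStandardND(S("warp"), getWarpsPerCTA(), order);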
1 parent 95e569e commit 1214ac7

File tree

1 file changed: +27 −33 lines
lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 27 additions & 33 deletions
@@ -32,7 +32,7 @@ namespace {
 
 #define S(v) StringAttr::get(ctx, (v))
 
-// Returns ["out0", "out1", ..., "out<rank-1>"].
+// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
 SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
   SmallVector<StringAttr> ret;
   for (int i = 0; i < rank; i++) {
@@ -71,14 +71,18 @@ void assertIsRegisterLayout(const LinearLayout &layout) {
                          expectedOuts.end()));
 }
 
-// Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of
-// size product(shape) and then reshaping to permute(shape, order).
-LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
-                        ArrayRef<unsigned> order,
-                        ArrayRef<StringAttr> outDimNames) {
+// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
+// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
+// permute(shape, order).
+LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
+                                ArrayRef<unsigned> order) {
   assert(shape.size() == order.size());
-
   MLIRContext *ctx = inDimName.getContext();
+  auto rank = shape.size();
+
+  // The order in triton is written wrt. [dim0, dim1, ...].
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
   LinearLayout ret = LinearLayout::empty();
   for (int i = 0; i < shape.size(); i++) {
     // Start with the most-minor dimension, which is order[0].
@@ -491,7 +495,7 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // And each warp takes the same register and lane sub-layout. So mulitply with
   // an identity layout for the warp.
   LinearLayout warpLayout =
-      identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
+      identityStandardND(S("warp"), getWarpsPerCTA(), order);
   LinearLayout ctaLayout = tileLayout * warpLayout;
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
@@ -601,8 +605,7 @@ mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
     tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
   }
 
-  LinearLayout warpLayout =
-      identityND(kWarp, warpsPerCTA, warpOrder, outDimNames);
+  LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
 
   LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
                            warpLayout.transposeOuts(outDimNames);
@@ -684,7 +687,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // And each warp takes the same register and lane sub-layout. So mulitply with
   // an identity layout for the warp.
   LinearLayout warpLayout =
-      identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
+      identityStandardND(S("warp"), getWarpsPerCTA(), order);
   LinearLayout ctaLayout = tileLayout * warpLayout;
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
@@ -700,9 +703,9 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
   const auto &order = getOrder();
   LinearLayout ctaLayout =
-      identityND(S("register"), getSizePerThread(), order, outDimNames) *
-      identityND(S("lane"), getThreadsPerWarp(), order, outDimNames) *
-      identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
+      identityStandardND(S("register"), getSizePerThread(), order) *
+      identityStandardND(S("lane"), getThreadsPerWarp(), order) *
+      identityStandardND(S("warp"), getWarpsPerCTA(), order);
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }
@@ -711,11 +714,12 @@ LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
                            ArrayRef<unsigned> repOrder) {
   // Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
+  // Like LinearLayout::empty() but with a rank and an order
   int rank = repOrder.size();
   auto dimNames = standardOutDimNames(ctx, rank);
   auto trivialShape = SmallVector<unsigned>(rank, 1);
   LinearLayout ctaLayout =
-      identityND(S("register"), trivialShape, repOrder, dimNames);
+      identityStandardND(S("register"), trivialShape, repOrder);
 
   assert(rank >= 2);
   auto inner = order[0];
@@ -769,11 +773,7 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // The triton orders are defined on [dim0, dim1, ...], so we need to pass
   // those dims Then, for some reason, operator* requires the orders to match
   // so we need to reorder the outs to match
-  // FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
-  // The order in triton assumes the standardDims, so it should
-  // use those.
-  ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(),
-                          standardOutDimNames(ctx, rank))
+  ctaLayout *= identityStandardND(S("warp"), getWarpsPerCTA(), getWarpOrder())
                    .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
@@ -797,11 +797,8 @@ LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
   // In other words, we need to broadcast along K
   auto rank = mmaWarpOrder.size();
   auto inner = isA ? rank - 1 : rank - 2;
-  auto outer = isA ? rank - 2 : rank - 1;
   auto dimNames = standardOutDimNames(ctx, rank);
-  auto trivialShape = SmallVector<unsigned>(rank, 1);
-  LinearLayout warpLayout =
-      identityND(S("warp"), trivialShape, mmaWarpOrder, dimNames);
+  LinearLayout warpLayout = LinearLayout::empty();
 
   // We have to broadcast along the inner dimension
   // For A, when moving along M we go from 0 to 2.
@@ -1086,9 +1083,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
 
   // Expand the `warp` dimension according to warpsPerCTA.
   auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
-  layout *=
-      identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
-          .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
+                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
 
   // Expand the `register` dimension so the size of columns matches `n`.
   int n = mma.getInstrShape()[1];
@@ -1126,9 +1122,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
       LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);
 
   // Expand the `warp` dimension according to warpsPerCTA.
-  layout *=
-      identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
-          .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
+                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
   auto ret =
       combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
   auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
@@ -1138,9 +1133,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
   ret = ensureLayoutNotSmallerThan(ret, namedTensorShape);
   ret = ensureLayoutNotLargerThan(ret, namedTensorShape);
   return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
-      .reshapeOuts({{S("offset"), ret.getTotalOutDimSize()},
-                    {S("iteration"), 1}}) *
-      identityND(kBlock, {1, 1}, {0, 1}, {S("offset"), S("iteration")});
+      .reshapeOuts(
+          {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
 } // anonymous namespace
