Skip to content

Commit 7bce361

Browse files
authored
[LAYOUTS] Implement LL conversion for DotOperand(Hopper) (#5193)
We also rewrite the way we implement DotOperand(Ampere) and mma Ampere to promote code reuse. I also started using what I believe is a rather compact pattern to write these things, where you first call `identityND` with the `repOrder`, which gives you an LL with the dims in the correct order, and then you construct the final layout by specifying the tiles by multiplying `identity1D` maps. Using this allowed me to heavily simplify the handling of the `warps` of `DotOperand`, which used to be a tad messy.
1 parent 54c840b commit 7bce361

File tree

3 files changed

+267
-201
lines changed

3 files changed

+267
-201
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
10171017
assert(rank == 2 || rank == 3 && "Invalid dotLayout");
10181018

10191019
// Do not split CTA in K dimension
1020-
getOpIdx() == 0 ? res[rank - 1] = 1 : res[rank - 2] = 1;
1020+
auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
1021+
res[kDim] = 1;
10211022
return res;
10221023
}
10231024
SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 143 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -280,78 +280,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
280280
return ret;
281281
}
282282

283-
LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,
284-
NvidiaMmaEncodingAttr mma) {
285-
int rank = shape.size();
286-
287-
assert(mma.isAmpere());
288-
assert(rank == 2 || rank == 3);
289-
assert(mma.getInstrShape().size() == rank);
290-
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
291-
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
292-
293-
MLIRContext *ctx = mma.getContext();
294-
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
295-
296-
auto orderedDimNames = permuteDimNames(dimNames, mma.getRepOrder());
297-
assert(mma.getRepOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
298-
299-
LinearLayout ctaLayout(
300-
{{S("register"), {{1, 0}, {0, 8}}},
301-
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
302-
ArrayRef(orderedDimNames).take_front(2));
303-
assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
304-
// FIXME(Lezcano). identityND should not have an `order` param as it's
305-
// redundant with the order of the out dims.
306-
ctaLayout *=
307-
identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);
308-
309-
return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
310-
}
311-
312-
LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
313-
NvidiaMmaEncodingAttr mma) {
314-
int rank = shape.size();
315-
assert(mma.isHopper());
316-
assert(rank == 2);
317-
318-
// wgmma operates on groups of 4 warps.
319-
assert(product(mma.getWarpsPerCTA()) % 4 == 0);
320-
321-
// Check that it's a known MMA layout.
322-
assert(mma.getInstrShape().size() == 3);
323-
int m = mma.getInstrShape()[0];
324-
int n = mma.getInstrShape()[1];
325-
int k = mma.getInstrShape()[2];
326-
assert(m == 16);
327-
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
328-
assert(k == 8 || k == 16 || k == 32);
329-
330-
MLIRContext *ctx = mma.getContext();
331-
LinearLayout ctaLayout(
332-
{{S("register"), {{1, 0}, {0, 8}}},
333-
{S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
334-
{S("dim1"), S("dim0")});
335-
336-
// Expand the `register` dimension so the size of dim1 matches `n`.
337-
ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
338-
S("register"), S("dim1"));
339-
340-
// The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
341-
// Since the warpOrder needs to be M-major, we need to transpose the out
342-
// dimensions AND transpose the order
343-
// FIXME(Lezcano). identityND should not have an `order` param as it's
344-
// redundant. The order is already given by the order of the
345-
// out dims, and if it has an order, it shouldn't change the
346-
// order of the out dims.
347-
assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
348-
ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
349-
{S("dim0"), S("dim1")})
350-
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
351-
352-
return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
353-
}
354-
355283
LinearLayout sharedToLinearLayoutNoLeadingOffset(ArrayRef<int64_t> shape,
356284
SharedEncodingAttr shared) {
357285
assert(!shared.getHasLeadingOffset());
@@ -779,13 +707,153 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
779707
return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
780708
}
781709

710+
LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
711+
unsigned kWidth, ArrayRef<unsigned> order,
712+
ArrayRef<unsigned> repOrder) {
713+
// Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
714+
int rank = repOrder.size();
715+
auto dimNames = standardOutDimNames(ctx, rank);
716+
auto trivialShape = SmallVector<unsigned>(rank, 1);
717+
LinearLayout ctaLayout =
718+
identityND(S("register"), trivialShape, repOrder, dimNames);
719+
720+
assert(rank >= 2);
721+
auto inner = order[0];
722+
auto outer = order[1];
723+
724+
assert(tileShape.size() == rank);
725+
int m = tileShape[outer];
726+
int n = tileShape[inner];
727+
728+
// The relative order of registers and lanes is given by:
729+
// - Inner dim: kWidth registers
730+
// - Inner dim: 4 lanes
731+
// - Outer dim: 8 lanes
732+
// - Outer dim: repeat m / 8 times
733+
// - Inner dim: repeat n / (kWidth * 4) times
734+
assert(m % 8 == 0);
735+
assert(n % (kWidth * 4) == 0);
736+
// There is at least one subtile on the inner-most dimension
737+
// FIXME. We should implement operator* in terms of operator*=
738+
// and chain *= instead of using *
739+
auto outDimNames = llvm::to_vector(ctaLayout.getOutDimNames());
740+
ctaLayout = ctaLayout *
741+
LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]) *
742+
LinearLayout::identity1D(4, S("lane"), dimNames[inner]) *
743+
LinearLayout::identity1D(8, S("lane"), dimNames[outer]) *
744+
LinearLayout::identity1D(m / 8, S("register"), dimNames[outer]) *
745+
LinearLayout::identity1D(n / (kWidth * 4), S("register"),
746+
dimNames[inner]);
747+
return ctaLayout;
748+
}
749+
782750
std::optional<LinearLayout>
783751
NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
752+
auto ctx = getContext();
753+
int rank = shape.size();
754+
755+
SmallVector<unsigned> tileShape;
784756
if (isAmpere()) {
785-
return ampereMmaToLinearLayout(shape, *this);
757+
// Ampere.getInstrShape() returns the tile shape
758+
tileShape = SmallVector<unsigned>(getInstrShape());
759+
} else {
760+
assert(isHopper());
761+
auto instrShapeMNK = getInstrShape();
762+
tileShape = SmallVector<unsigned>({instrShapeMNK[0], instrShapeMNK[1]});
786763
}
787-
if (isHopper()) {
788-
return hopperMmaToLinearLayout(shape, *this);
764+
// nvidiamma layout always assumes kWidth = 2
765+
constexpr auto kWidth = 2;
766+
auto ctaLayout =
767+
nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(*this), getRepOrder());
768+
769+
// The triton orders are defined on [dim0, dim1, ...], so we need to pass
770+
// those dims Then, for some reason, operator* requires the orders to match
771+
// so we need to reorder the outs to match
772+
// FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
773+
// The order in triton assumes the standardDims, so it should
774+
// use those.
775+
ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(),
776+
standardOutDimNames(ctx, rank))
777+
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
778+
779+
return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
780+
}
781+
782+
LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
783+
ArrayRef<unsigned> mmaWarpOrder, bool isA) {
784+
// Let warpsPerCTAMma = {2, 2}, then
785+
// warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
786+
// assume warpOrder = {1, 0}
787+
// Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
788+
// the C is owned as per the following layout:
789+
// C: 0 | 1
790+
// - | -
791+
// 2 | 3
792+
// In order to be able to compute C, we need the following warp tiling of
793+
// A and B:
794+
// A: 0 1 | 0 1 B: 0 2 | 1 3
795+
// - - | - - - - | - -
796+
// 2 3 | 2 3 0 2 | 1 3
797+
// In other words, we need to broadcast along K
798+
auto rank = mmaWarpOrder.size();
799+
auto inner = isA ? rank - 1 : rank - 2;
800+
auto outer = isA ? rank - 2 : rank - 1;
801+
auto dimNames = standardOutDimNames(ctx, rank);
802+
auto trivialShape = SmallVector<unsigned>(rank, 1);
803+
LinearLayout warpLayout =
804+
identityND(S("warp"), trivialShape, mmaWarpOrder, dimNames);
805+
806+
// We have to broadcast along the inner dimension
807+
// For A, when moving along M we go from 0 to 2.
808+
// For B, when moving along N we go from 0 to 1.
809+
// As such, choosing the order of A {1, 0}, gives us the correct broadcasting
810+
// Same happens if the mmaWarpOrder is {0, 1}, like in Hopper
811+
for (auto d : mmaWarpOrder) {
812+
if (d == inner) {
813+
warpLayout *=
814+
LinearLayout::zeros1D(mmaWarpShape[d], S("warp"), dimNames[d]);
815+
} else {
816+
warpLayout *=
817+
LinearLayout::identity1D(mmaWarpShape[d], S("warp"), dimNames[d]);
818+
}
819+
}
820+
return warpLayout;
821+
}
822+
823+
LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
824+
DotOperandEncodingAttr dot) {
825+
int rank = shape.size();
826+
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
827+
int kWidth = dot.getKWidth();
828+
bool isA = dot.getOpIdx() == 0;
829+
MLIRContext *ctx = mma.getContext();
830+
831+
SmallVector<unsigned> tileShape(rank, 1);
832+
if (isA) {
833+
tileShape[rank - 2] = 16;
834+
tileShape[rank - 1] = kWidth * 8;
835+
} else {
836+
// Hopper takes the rhs via shared memory
837+
assert(mma.isAmpere());
838+
tileShape[rank - 2] = kWidth * 8;
839+
tileShape[rank - 1] = 8;
840+
}
841+
auto ctaLayout =
842+
nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
843+
ctaLayout *=
844+
warpsNvidiaDot(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), isA)
845+
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
846+
847+
return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
848+
}
849+
850+
std::optional<LinearLayout>
851+
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
852+
auto parent = getParent();
853+
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
854+
return mfmaDotToLinearLayout(*this, shape);
855+
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
856+
return nvidiaDotToLinearLayout(shape, *this);
789857
}
790858
return std::nullopt;
791859
}
@@ -860,116 +928,6 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
860928
return ret;
861929
}
862930

863-
LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
864-
DotOperandEncodingAttr dot) {
865-
// Note that, even though MMAv2 looks similar to this layout, they are just
866-
// the same at a register and lane level. The warps treatment is different!
867-
int rank = shape.size();
868-
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
869-
int kWidth = dot.getKWidth();
870-
bool isA = dot.getOpIdx() == 0;
871-
872-
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
873-
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
874-
assert(mma.isAmpere());
875-
876-
MLIRContext *ctx = mma.getContext();
877-
878-
// The A and B operands are tiled in a kMajor fashion
879-
auto kMajorOrder = dot.getRepOrder();
880-
assert(kMajorOrder ==
881-
getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));
882-
883-
auto kMajorDims =
884-
permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder);
885-
// This agrees with the order of the elements, which means that we can share
886-
// the code below for both A and B without having to perform any swaps
887-
assert(getOrder(dot) == kMajorOrder);
888-
889-
std::vector<std::vector<int32_t>> registers;
890-
std::vector<std::vector<int32_t>> lanes;
891-
int32_t i = 1;
892-
// kWidth contiguous elements
893-
while (i < kWidth) {
894-
registers.push_back({i, 0});
895-
i *= 2;
896-
}
897-
// 4 threads per chunk
898-
for (int j = 0; j < 2; j++) {
899-
lanes.push_back({i, 0});
900-
i *= 2;
901-
}
902-
// 8 threads going down
903-
lanes.push_back({0, 1});
904-
lanes.push_back({0, 2});
905-
lanes.push_back({0, 4});
906-
// 2 tiles in column-major order
907-
// Just one if it's the B operand
908-
if (isA) {
909-
registers.push_back({0, 8});
910-
}
911-
registers.push_back({i, 0});
912-
913-
LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
914-
ArrayRef(kMajorDims).take_front(2));
915-
916-
// Let warpsPerCTAMma = {2, 2}, then
917-
// warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
918-
// assume warpOrder = {0, 1}
919-
// Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
920-
// the C is owned as per the following layout:
921-
// C: 0 | 1
922-
// - | -
923-
// 2 | 3
924-
// In order to be able to compute C, we need the following warp tiling of
925-
// A and B:
926-
// A: 0 1 | 0 1 B: 0 2 | 1 3
927-
// - - | - - - - | - -
928-
// 2 3 | 2 3 0 2 | 1 3
929-
// In particular, for A and B we need to broadcast along K
930-
931-
assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
932-
auto warpsPerCTAMma = mma.getWarpsPerCTA();
933-
std::vector<std::vector<int32_t>> warps;
934-
if (isA) {
935-
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
936-
warps.push_back({0, 0});
937-
}
938-
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
939-
warps.push_back({0, i});
940-
}
941-
} else {
942-
for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
943-
warps.push_back({0, i});
944-
}
945-
for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
946-
warps.push_back({0, 0});
947-
}
948-
}
949-
if (rank == 3) {
950-
for (auto &w : warps) {
951-
w.push_back(0);
952-
}
953-
}
954-
955-
ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);
956-
957-
return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
958-
}
959-
960-
std::optional<LinearLayout>
961-
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
962-
auto parent = getParent();
963-
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
964-
return mfmaDotToLinearLayout(*this, shape);
965-
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
966-
if (mma.isAmpere()) {
967-
return ampereDotToLinearLayout(shape, *this);
968-
}
969-
}
970-
return std::nullopt;
971-
}
972-
973931
std::optional<LinearLayout>
974932
toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
975933
std::optional<int32_t> elemBitWidth /*= std::nullopt*/) {

0 commit comments

Comments
 (0)