Commit 71145a6

lezcano authored and guacamoleo committed
[BACKEND] Fix DotOperand(Ampere) LinearLayoutConversion (triton-lang#5038)
We also clean up `TritonGPU/IR/Dialect.cpp` a bit, using a few auxiliary functions to make the intent clearer, and we add a few asserts in the `LinearLayoutConversion` to make it clear why we do certain things here and there. We also kill `getCvtOrder`, as it was not used anywhere.
1 parent d68daff commit 71145a6

5 files changed: +166 −80 lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 11 additions & 0 deletions
@@ -130,6 +130,17 @@ unsigned getNumWarpsPerCTA(Attribute layout);
 
 unsigned getNumCTAs(Attribute layout);
 
+// Return the order that represents that the batch is in row-major or
+// column-major order for a batch of matrices of shape [*, m, n] with
+// len(shape) == rank.
+SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor);
+
+// Return the order that represents that the dot operand is in kMajor
+// (contiguous in the inner dimension) or it's contiguous on the outer
+// dimension.
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                            bool kMajor);
+
 bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
 
 // Return true if a view between the two types cannot be implemented as a no-op.
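For intuition, here is a minimal standalone sketch (not part of the commit) of the orders these two declarations are documented to return. It mirrors the definitions added to `Dialect.cpp` below, with `std::vector` standing in for `SmallVector` so it compiles on its own:

#include <cassert>
#include <numeric>
#include <vector>

// Illustrative re-implementation mirroring the new helpers in Dialect.cpp.
std::vector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  assert(rank >= 2);
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0u); // [rank-1, ..., 1, 0]
  if (!rowMajor)
    std::swap(order[0], order[1]);
  return order;
}

std::vector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
                                            bool kMajor) {
  // opIdx == 0 -> A is [*, m, k]; opIdx == 1 -> B is [*, k, n].
  auto rowMajor = bool(opIdx) != kMajor;
  return getMatrixOrder(rank, rowMajor);
}

int main() {
  using V = std::vector<unsigned>;
  assert(getMatrixOrder(2, /*rowMajor=*/true) == V({1, 0}));
  assert(getMatrixOrder(3, /*rowMajor=*/false) == V({1, 2, 0}));
  // kMajor=true: K is the fastest-running dimension.
  assert(getOrderForDotOperand(0, 2, /*kMajor=*/true) == V({1, 0})); // A: [m, k]
  assert(getOrderForDotOperand(1, 2, /*kMajor=*/true) == V({0, 1})); // B: [k, n]
  return 0;
}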

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 34 additions & 40 deletions
@@ -235,6 +235,19 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }
 
+SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
+  // Return the order that represents that the batch is in row-major or
+  // column-major order for a batch of matrices of shape [*, m, n] with
+  // len(shape) == rank.
+  assert(rank >= 2);
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  if (!rowMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
                                             bool kMajor) {
   // kMajor: if true, the matrix is fastest-running on k,
@@ -244,15 +257,8 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
   // batch (if rank == 3) is always the slowest running dimension
   assert(rank == 2 || rank == 3);
   assert(opIdx == 0 || opIdx == 1);
-  SmallVector<unsigned> order(rank);
-  std::iota(order.rbegin(), order.rend(), 0);
-  // If opIdx is 1 and kMajor is true, the order is [0, 1]
-  // (resp. [1, 2, 0] if rank == 3)
-  // Same if opIdx is 0 and kMajor is false
-  if (bool(opIdx) == kMajor) {
-    std::swap(order[0], order[1]);
-  }
-  return order;
+  auto rowMajor = bool(opIdx) != kMajor;
+  return getMatrixOrder(rank, rowMajor);
 }
 
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
@@ -262,20 +268,21 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
     }
   }
   auto order = getOrder(layout);
-  // FIXME: This mmaLayout if should just return
-  // getOrderForDotOperand(0, order.size(), kMajor=false)
-  // as mma has the same order as DotOperand(opIdx=0)
+  // FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
+  // M-major This is awkward. Since we can choose any warpOrder in Ampere, we
+  // should probably choose M-major and change `LinearLayoutConversion.cpp` and
+  // `MMAv2.cpp` to match.
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
-      // Hopper MMA instructions force a warp order of [0, 1]. See docs:
+      // Hopper MMA instructions force warps to be column-major
       // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
-      auto it = std::find(order.begin(), order.end(), 0);
-      order.erase(it);
-      order.insert(order.begin(), 0);
+      return getMatrixOrder(order.size(), /*rowMajor*/ false);
     }
   } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
-                                  /*kMajor*/ false);
+    // It's quite weird to talk about warp order when that the warps
+    // are broadcasted along the K dimension
+    llvm::report_fatal_error(
+        "DotOperandEncoding::getWarpOrder not implemented");
   }
   return order;
 }
@@ -285,11 +292,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return llvm::to_vector(blockedLayout.getOrder());
   }
   if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
+    // Order doesn't really matter. We just have to be consistent when unpacking
+    // the elements in the MMAv2/V3 lowerings. We choose row-major
    auto distributedLayout = cast<DistributedEncodingTrait>(layout);
    auto rank = distributedLayout.getWarpsPerCTA().size();
-    SmallVector<unsigned> order(rank);
-    std::iota(order.rbegin(), order.rend(), 0);
-    return order;
+    return getMatrixOrder(rank, /*rowMajor*/ true);
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = dotLayout.getWarpsPerCTA().size();
@@ -421,7 +428,7 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
   else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout))
     warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   else if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout))
-    return getNumWarpsPerCTA(dotLayout.getParent());
+    warpsPerCTA = dotLayout.getWarpsPerCTA();
   else if (auto sharedLayout = dyn_cast<SharedEncodingAttr>(layout))
     llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr");
   else
@@ -2136,25 +2143,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
     ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
   assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
-  auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto rank = parentShapePerCTATile.size();
+  auto shapePerCTATile = getShapePerCTATile(shape);
+  auto rank = shapePerCTATile.size();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
   // 4 threads * 2 subtiles
-  unsigned kWidthTile = kWidth * 2 * 4;
-  if (opIdx == 0) {
-    if (rank == 2)
-      return {parentShapePerCTATile[rank - 2], kWidthTile};
-    else
-      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
-              kWidthTile};
-  } else if (opIdx == 1) {
-    if (rank == 2)
-      return {kWidthTile, parentShapePerCTATile[rank - 1]};
-    else
-      return {parentShapePerCTATile[0], kWidthTile,
-              parentShapePerCTATile[rank - 1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
+  shapePerCTATile[kDim] = kWidth * 2 * 4;
+  return shapePerCTATile;
 }
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
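As a rough illustration of the simplified getShapePerCTATileForOperand above: the per-operand tile is now just the parent tile with its K dimension overwritten by kWidth * 2 * 4 (4 threads * 2 subtiles). A standalone sketch, where the input tile shapes are made-up values rather than anything taken from the commit:

#include <cassert>
#include <vector>

// Standalone sketch of the rewritten logic; not Triton code.
std::vector<unsigned> shapePerCTATileForOperand(std::vector<unsigned> tile,
                                                int kWidth, int opIdx) {
  auto rank = tile.size();
  auto kDim = opIdx == 0 ? rank - 1 : rank - 2; // A is [*, m, k], B is [*, k, n]
  tile[kDim] = kWidth * 2 * 4;                  // 4 threads * 2 subtiles
  return tile;
}

int main() {
  using V = std::vector<unsigned>;
  // opIdx == 0 (A): the last dimension becomes kWidth * 8.
  assert(shapePerCTATileForOperand({32, 16}, /*kWidth=*/4, /*opIdx=*/0) == V({32, 32}));
  // opIdx == 1 (B): the second-to-last dimension becomes kWidth * 8.
  assert(shapePerCTATileForOperand({16, 64}, /*kWidth=*/4, /*opIdx=*/1) == V({32, 64}));
  return 0;
}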

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 78 additions & 26 deletions
@@ -41,6 +41,17 @@ SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
   return ret;
 }
 
+// TODO Have order be a mandatory argument of standardOutDimNames.
+SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
+                                        const SmallVector<unsigned> &order) {
+  assert(names.size() == order.size());
+  SmallVector<StringAttr> ret;
+  for (unsigned i : order) {
+    ret.push_back(names[i]);
+  }
+  return ret;
+}
+
 void assertIsRegisterLayout(const LinearLayout &layout) {
   assert(layout.getNumInDims() > 0);
   MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
@@ -281,15 +292,19 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,
 
   MLIRContext *ctx = mma.getContext();
   SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+  auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
+  // By using `reverse(dimNames)` below, we set the order to be row-major
+  assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
 
   LinearLayout ctaLayout(
       {{S("register"), {{1, 0}, {0, 8}}},
        {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
-      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
-
-  ctaLayout *= identityND(
-      S("warp"), mma.getWarpsPerCTA(),
-      llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank))), dimNames);
+      ArrayRef(orderedDimNames).take_front(2));
+  assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
+  // FIXME(Lezcano). identityND should not have an `order` param as it's
+  // redundant with the order of the out dims.
+  ctaLayout *=
+      identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);
 
   return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
 }
@@ -322,10 +337,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
   ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
                                         S("register"), S("dim1"));
 
-  // Expand the `warp` dimension according to warpsPerCTA.
-  //
-  // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but
-  // this really does seem to be correct.
+  // The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
+  // Since the warpOrder needs to be M-major, we need to transpose the out
+  // dimensions AND transpose the order
+  // FIXME(Lezcano). identityND should not have an `order` param as it's
+  // redundant. The order is already given by the order of the
+  // out dims, and if it has an order, it shouldn't change the
+  // order of the out dims.
+  assert(getWarpOrder(mma) == SmallVector<unsigned>({0, 1}));
   ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
                           {S("dim0"), S("dim1")})
                    .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
@@ -843,18 +862,24 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
                                      DotOperandEncodingAttr dot) {
-  // TODO,BE. Implement ampereMMA in terms of this one
+  // Note that, even though MMAv2 looks similar to this layout, they are just
+  // the same at a register and lane level. The warps treatment is different!
   int rank = shape.size();
   auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
   int kWidth = dot.getKWidth();
   bool isA = dot.getOpIdx() == 0;
 
-  assert(mma.isAmpere());
   assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
          (rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
+  assert(mma.isAmpere());
 
   MLIRContext *ctx = mma.getContext();
-  SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+  // A and B have kMajor order
+  assert(getOrder(dot) ==
+         getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));
+
+  auto kMajorDims =
+      permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
 
   // Implement A. For B transpose in the end
   std::vector<std::vector<int32_t>> registers;
@@ -881,24 +906,51 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
   }
   registers.push_back({i, 0});
 
-  if (!isA) {
-    for (auto &r : registers) {
-      std::swap(r[0], r[1]);
+  LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
+                         ArrayRef(kMajorDims).take_front(2));
+
+  // Let warpsPerCTAMma = {2, 2}, then
+  // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
+  // assume warpOrder = {1, 0}
+  // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
+  // the C is owned as per the following layout:
+  // C: 0 | 1
+  //    - | -
+  //    2 | 3
+  // In order to be able to compute C, we need the following warp tiling of
+  // A and B:
+  // A: 0 1 | 0 1    B: 0 2 | 1 3
+  //    - - | - -       - - | - -
+  //    2 3 | 2 3       0 2 | 1 3
+  // In particular, for A and B we need to broadcast along K
+
+  assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
+  auto warpsPerCTAMma = mma.getWarpsPerCTA();
+  std::vector<std::vector<int32_t>> warps;
+  if (isA) {
+    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+      warps.push_back({0, 0});
+    }
+    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+      warps.push_back({0, i});
+    }
+  } else {
+    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+      warps.push_back({0, i});
     }
-    for (auto &l : lanes) {
-      std::swap(l[0], l[1]);
+    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+      warps.push_back({0, 0});
+    }
+  }
+  if (rank == 3) {
+    for (auto &w : warps) {
+      w.push_back(0);
     }
   }
 
-  LinearLayout ctaLayout(
-      {{S("register"), registers}, {S("lane"), lanes}},
-      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
-
-  auto order = dot.getCTAOrder();
-  assert(order[0] == rank - 1 && order[1] == rank - 2);
-  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
+  ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);
 
-  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
+  return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }
 
 std::optional<LinearLayout>
@@ -907,7 +959,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
     return mfmaDotToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-    if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) {
+    if (mma.isAmpere()) {
      return ampereDotToLinearLayout(shape, *this);
    }
  }
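To make the warp-broadcast comment in ampereDotToLinearLayout concrete, here is a standalone sketch (not part of the commit) of the warp-basis construction for rank 2, with the out dims in K-major order (K first, then M or N). For warpsPerCTAMma = {2, 2}, the A operand gets a zero basis for the warp bit that tiles N (broadcast along N), and the B operand gets a zero basis for the warp bit that tiles M:

#include <cassert>
#include <cstdint>
#include <vector>

using Basis = std::vector<std::vector<int32_t>>;

// Mirrors the `warps` construction in ampereDotToLinearLayout (rank 2 only).
Basis dotOperandWarpBases(std::vector<int32_t> warpsPerCTAMma, bool isA) {
  Basis warps;
  if (isA) {
    // Warps that tile N see the same A data -> zero (broadcast) bases.
    for (int32_t i = 1; i < warpsPerCTAMma[1]; i *= 2)
      warps.push_back({0, 0});
    // Warps that tile M each own a different M stripe of A.
    for (int32_t i = 1; i < warpsPerCTAMma[0]; i *= 2)
      warps.push_back({0, i});
  } else {
    // Warps that tile N each own a different N stripe of B.
    for (int32_t i = 1; i < warpsPerCTAMma[1]; i *= 2)
      warps.push_back({0, i});
    // Warps that tile M see the same B data -> zero (broadcast) bases.
    for (int32_t i = 1; i < warpsPerCTAMma[0]; i *= 2)
      warps.push_back({0, 0});
  }
  return warps;
}

int main() {
  // In the out-dim order (K, M) for A and (K, N) for B, the second coordinate
  // is the non-K dimension that the warp actually walks.
  assert(dotOperandWarpBases({2, 2}, /*isA=*/true) == Basis({{0, 0}, {0, 1}}));
  assert(dotOperandWarpBases({2, 2}, /*isA=*/false) == Basis({{0, 1}, {0, 0}}));
  return 0;
}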

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

Lines changed: 6 additions & 10 deletions
@@ -121,19 +121,15 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
   }
 
   if (dot.getOpIdx() == 1) {
-    // there are kWidth * 2 elems packed as bf16x2
     int elemsInTile = dot.getKWidth();
-    // n0 and n1 are unrolled in the legacy path
-    // Unrolling n1 makes some sense, but unrolling n0 makes absolutely no
-    // sense IMO
+    // n0 is unrolled in the legacy path, which makes no sense
     n0 *= 2;
-    n1 *= 2;
     for (auto b = 0; b < batch; ++b)
-      for (auto j = 0; j < n1 / elemsInTile; ++j)
-        for (auto i = 0; i < n0; ++i)
-          for (auto k = 0; k < elemsInTile; ++k) {
-            vals[{b, i, elemsInTile * j + k}] = elems[offset++];
-          }
+      for (auto i = 0; i < n0; ++i)
+        for (auto j = 0; j < n1; ++j) {
+          vals[{b, i, 2 * j}] = elems[offset++];
+          vals[{b, i, 2 * j + 1}] = elems[offset++];
+        }
     return vals;
   }
 }
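A small standalone sketch (not part of the commit) of the element ordering produced by the rewritten opIdx == 1 loop above; batch, n0, and n1 are made-up sizes, and offset stands for the position in the unpacked LLVM struct:

#include <cstdio>

int main() {
  int batch = 1, n0 = 1, n1 = 2; // made-up sizes
  n0 *= 2;                       // the rewritten path only doubles n0
  int offset = 0;
  for (int b = 0; b < batch; ++b)
    for (int i = 0; i < n0; ++i)
      for (int j = 0; j < n1; ++j) {
        // Two consecutive struct elements land at adjacent last-index slots.
        std::printf("elems[%d] -> vals[{%d, %d, %d}]\n", offset++, b, i, 2 * j);
        std::printf("elems[%d] -> vals[{%d, %d, %d}]\n", offset++, b, i, 2 * j + 1);
      }
  return 0;
}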

unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp

Lines changed: 37 additions & 4 deletions
@@ -555,14 +555,14 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
                      {2, 0},
                      {4, 0},
                      {32, 0},
+                     {64, 0},
                      {0, 8},
                      {0, 16},
-                     {0, 32},
-                     {64, 0}}},
+                     {0, 32}}},
                 {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
                 {
                     S("warp"),
-                    {},
+                    {{0, 0}, {0, 0}},
                 },
                 {S("block"), {}},
            },
@@ -582,13 +582,46 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
                {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
                {
                    S("warp"),
-                    {},
+                    {{0, 0}, {0, 0}},
                },
                {S("block"), {}},
            },
            {S("dim0"), S("dim1")}));
 }
 
+TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) {
+  EXPECT_EQ(
+      toLinearLayout({32, 64}, dotMMAv2(0, 8, {2, 2})),
+      LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
+                    {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+                    {S("warp"), {{0, 0}, {16, 0}}},
+                    {S("block"), {}}},
+                   {S("dim0"), S("dim1")}));
+  EXPECT_EQ(
+      toLinearLayout({64, 16}, dotMMAv2(1, 8, {2, 2})),
+      LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {S("warp"), {{0, 8}, {0, 0}}},
+                    {S("block"), {}}},
+                   {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(0, 8, {2, 2})),
+            LinearLayout(
+                {{S("register"),
+                  {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {32, 0}}},
+                 {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+                 {S("warp"), {{0, 0}, {16, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(
+      toLinearLayout({128, 32}, dotMMAv2(1, 8, {2, 2})),
+      LinearLayout(
+          {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 16}}},
+           {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+           {S("warp"), {{0, 8}, {0, 0}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+}
+
 TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
   auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
                      /*isTransposed=*/false);
