
Commit 2bd83aa

Ognjen Plavsic (oplavsic) authored and committed
[AMD] Implement dotOperandMfma to linear layout conversion (triton-lang#4817)
This commit implements the dotOperandMfma layout to linear layout conversion under the following conditions:
- opIdx == 1
- kWidth == 8
- warpsPerCTA[0] == 1

This serves as preparation for a follow-up PR, which will add a bypassLDS pass.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
1 parent 0f611bc commit 2bd83aa

File tree: 3 files changed (+293, -2 lines)

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 50 additions & 1 deletion
@@ -247,6 +247,30 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
   return order;
 }
 
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
+  SmallVector<unsigned> order(rank);
+  // The 'order' field typically represents a descending sorted array of
+  // dimensions based on contiguity. For instance, in axisInfo utilities that
+  // retrieve tensor contiguity, it's assumed that the dimension with the
+  // highest contiguity corresponds to order[0].
+  //
+  // The relation between contiguity and order is only relevant if the layout
+  // interfaces with HBM, as is the case when we load a tensor from HBM to
+  // registers in the dot layout to bypass LDS. When bypassing LDS, we make the
+  // following assumptions about tensor layouts:
+  // - Tensor A (opIdx == 0) is considered to be row-major.
+  // - Tensor B (opIdx == 1) is considered to be column-major.
+  //
+  // Based on these assumptions, we define the following orders:
+  // - For opIdx == 0, we assume an order of [1, 0].
+  // - For opIdx == 1, we assume an order of [0, 1].
+  std::iota(order.rbegin(), order.rend(), 0);
+  if (opIdx == 1) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getOrder(Attribute layout) {
   if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
     return llvm::to_vector(blockedLayout.getOrder());
@@ -261,7 +285,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
     SmallVector<unsigned> order(rank);
-    std::iota(order.rbegin(), order.rend(), 0);
+    if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
+      return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
+    } else {
+      std::iota(order.rbegin(), order.rend(), 0);
+    }
     return order;
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
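
To make the new order convention concrete, here is a small standalone C++ sketch (not part of the commit) that mirrors the getOrderForDotOperand logic with std::vector instead of SmallVector:

// Standalone mirror of getOrderForDotOperand, for illustration only; the real
// helper returns a SmallVector and lives in Dialect.cpp.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

static std::vector<unsigned> orderForDotOperand(unsigned opIdx, unsigned rank) {
  std::vector<unsigned> order(rank);
  // Fill in reverse: order = [rank - 1, ..., 1, 0].
  std::iota(order.rbegin(), order.rend(), 0);
  // For operand B, swap the first two entries: [1, 0] -> [0, 1].
  if (opIdx == 1)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  for (unsigned opIdx : {0u, 1u}) {
    auto o = orderForDotOperand(opIdx, /*rank=*/2);
    std::printf("opIdx=%u -> order=[%u, %u]\n", opIdx, o[0], o[1]);
  }
  // Prints:
  //   opIdx=0 -> order=[1, 0]   (A assumed row-major)
  //   opIdx=1 -> order=[0, 1]   (B assumed column-major)
  return 0;
}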
@@ -925,6 +953,27 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
 SmallVector<unsigned>
 DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                           Type eltTy) const {
+
+  if (auto parent = mlir::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
+    auto rank = shape.size();
+    assert(rank == 2 || rank == 3);
+
+    auto idx = getOpIdx();
+    assert(idx == 0 || idx == 1);
+
+    SmallVector<unsigned> elemsPerThread(rank);
+
+    auto kWidth = getKWidth();
+    auto rep = parent.getMFMARepForOperands(shape, kWidth, idx);
+
+    if (rank == 3)
+      elemsPerThread[0] = rep[0];
+    elemsPerThread[rank - 2] = (idx == 0) ? rep[1] : rep[1] * kWidth;
+    elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
+
+    return elemsPerThread;
+  }
+
   llvm_unreachable("getElemsPerThread is not supported for dot operand");
   return SmallVector<unsigned>();
 }
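
To illustrate the arithmetic in the new getElemsPerThread branch, a minimal standalone sketch follows. The rep values below are hypothetical placeholders; the real ones come from AMDMfmaEncodingAttr::getMFMARepForOperands, which is not part of this diff:

// Illustration only: mirrors the rank == 2 arithmetic above with made-up
// repetition counts rep = {repBatch, repRow, repCol}.
#include <cstdio>

int main() {
  const unsigned rep[3] = {1, 4, 2}; // hypothetical values
  const unsigned kWidth = 8;

  for (unsigned idx : {0u, 1u}) {
    unsigned elemsPerThread[2];
    // Per the code above, the operand's K dimension (the last dim of A,
    // the second-to-last dim of B) is the one scaled by kWidth.
    elemsPerThread[0] = (idx == 0) ? rep[1] : rep[1] * kWidth;
    elemsPerThread[1] = (idx == 0) ? rep[2] * kWidth : rep[2];
    std::printf("opIdx=%u -> elemsPerThread=[%u, %u]\n", idx,
                elemsPerThread[0], elemsPerThread[1]);
  }
  // Prints:
  //   opIdx=0 -> elemsPerThread=[4, 16]
  //   opIdx=1 -> elemsPerThread=[32, 2]
  return 0;
}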

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 95 additions & 1 deletion
@@ -553,6 +553,96 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }
 
+std::optional<LinearLayout>
+dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
+                             ArrayRef<int64_t> shape) {
+
+  // Current linear layout conversion for dot operand is only necessary to
+  // enable LDS bypass for operand B in the MFMA dot path. To achieve
+  // performance gains from bypassing LDS, the following conditions must be met:
+  //
+  // 1) opIdx == 1: Currently, only the B tensor (e.g. weights in moe-like
+  // kernels) bypasses LDS. This constraint is not strict and support for
+  // bypassing operand A (e.g. Q tensor in flash attention) will be added in
+  // the future.
+  //
+  // 2) B tensor must be column-major: This is required to support vectorized
+  // global load instructions, as MFMA instructions expect threads to hold B
+  // operand elements along the K dimension.
+  //
+  // 3) kWidth == 8: Ensures maximum global load vectorization for fp16
+  // operations.
+  // TODO: Generalize conversion to handle maximum kWidth for other types
+  // (e.g. fp8).
+  //
+  // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is
+  // held by exactly one thread, maintaining the same number of global loads
+  // as in a blocked layout.
+  auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
+
+  if (dotMfmaLayout.getOpIdx() == 0) {
+    return std::nullopt;
+  }
+  auto rank = shape.size();
+  bool hasBatchDim = rank == 3;
+  int mIndex = 0 + hasBatchDim;
+
+  auto kWidth = dotMfmaLayout.getKWidth();
+  auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
+
+  if (kWidth != 8 || warpsPerCTA[mIndex] != 1) {
+    return std::nullopt;
+  }
+
+  MLIRContext *ctx = dotMfmaLayout.getContext();
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
+  StringAttr kRegister = S("register");
+  StringAttr kLane = S("lane");
+
+  SmallVector<unsigned> order = triton::gpu::getOrder(dotMfmaLayout);
+  auto tileLayout = LinearLayout::empty();
+
+  if (mfmaLayout.getMDim() == 32) {
+    // Based on the canonical MFMA linear layout, which handles 4 consecutive
+    // elements along the register dimension, kWidth=8 means we have 8
+    // consecutive elements, so we have an additional {4, 0} base vector here.
+    // For the lane dim, since the MFMA thread arrangement is {K, N} = {2, 32},
+    // the mapping of the first 5 base vectors (up to thread 16) will be an
+    // identity along the N dim. Thread 32 will be mapped to element 8 in the K
+    // dimension, because kWidth == 8.
+    tileLayout = LinearLayout(
+        {{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
+         {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}},
+        {outDimNames[order[0]], outDimNames[order[1]]});
+  } else {
+    assert(mfmaLayout.getMDim() == 16);
+    // For the lane dim, since the MFMA thread arrangement is {K, N} = {4, 16},
+    // the mapping of the first 4 base vectors (up to thread 16) will be an
+    // identity along the N dim. Thread 16 will be mapped to element 8 in the K
+    // dimension, because kWidth == 8. Thread 32 is mapped to element 16, as
+    // that is 2*kWidth in the K dim.
+    tileLayout = LinearLayout(
+        {{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
+         {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}},
+        {outDimNames[order[0]], outDimNames[order[1]]});
+  }
+
+  if (hasBatchDim) {
+    assert(order[2] == 0);
+    // Extend the base vector with one value to accommodate the batch
+    // dimension, which comes last.
+    tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
+    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
+  }
+
+  LinearLayout warpLayout =
+      identityND(S("warp"), warpsPerCTA, order, outDimNames);
+  LinearLayout ctaLayout = tileLayout * warpLayout;
+
+  return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+}
+
 std::optional<LinearLayout>
 AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   int rank = shape.size();
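
As a sanity check on the lane bases chosen for the mDim == 32 case, here is a small standalone sketch (not part of the commit) that decodes a few lane IDs into (K, N) offsets for the B operand, following the comment above that the first coordinate is K. Bases of a Triton LinearLayout combine by XOR over GF(2):

// Illustration only: decode lane IDs with the mDim == 32, kWidth == 8 lane
// bases used above. Bit i of the lane ID contributes bases[i]; contributions
// combine with XOR.
#include <cstdio>

int main() {
  const int bases[6][2] = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}};
  for (int lane : {0, 1, 15, 31, 32, 63}) {
    int k = 0, n = 0;
    for (int bit = 0; bit < 6; ++bit) {
      if (lane & (1 << bit)) {
        k ^= bases[bit][0];
        n ^= bases[bit][1];
      }
    }
    std::printf("lane %2d -> (K=%d, N=%d)\n", lane, k, n);
  }
  // Lanes 0..31 stay at K = 0 and walk the N dimension (identity along N);
  // lanes 32..63 start at K = 8, matching kWidth == 8.
  return 0;
}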
@@ -731,9 +821,13 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return ret;
 }
 
-// TODO: DotOperandEncoding doesn't support LinearLayout conversion yet.
 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
+
+  if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
+    return dotOperandMfmaToLinearLayout(*this, shape);
+  }
+
   return std::nullopt;
 }

unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp

Lines changed: 148 additions & 0 deletions
@@ -51,6 +51,11 @@ class LinearLayoutConversionsTest : public ::testing::Test {
                      isTransposed, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd));
   }
 
+  DotOperandEncodingAttr amdDot(AMDMfmaEncodingAttr mfma, unsigned opIdx,
+                                unsigned kWidth) {
+    return DotOperandEncodingAttr::get(&ctx, opIdx, mfma, kWidth);
+  }
+
   AMDWmmaEncodingAttr wmma(ArrayRef<unsigned> warps) {
     SmallVector<unsigned> cpg(warps.size(), 1u);
     SmallVector<unsigned> cSplit(warps.size(), 1u);
@@ -651,6 +656,149 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4x1Warps) {
                 {S("dim0"), S("dim1"), S("dim2")}));
 }
 
+TEST_F(LinearLayoutConversionsTest, warp1onK_mfma32_rhs_kwidth8) {
+  auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/32, /*nDim=*/32,
+                             /*isTransposed=*/false);
+  auto amdDot_1_8 = amdDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8);
+  EXPECT_EQ(
+      toLinearLayout({128, 128}, amdDot_1_8),
+      LinearLayout(
+          {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}},
+           {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
+           {S("warp"), {{0, 32}, {0, 64}, {0, 0}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+
+  EXPECT_EQ(
+      toLinearLayout({128, 256}, amdDot_1_8),
+      LinearLayout(
+          {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}}},
+           {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
+           {S("warp"), {{0, 32}, {0, 64}, {0, 128}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+
+  EXPECT_EQ(toLinearLayout({32, 64}, amdDot_1_8),
+            LinearLayout(
+                {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {16, 0}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
+                 {S("warp"), {{0, 32}, {0, 0}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+
+  EXPECT_EQ(
+      toLinearLayout({256, 256}, amdDot_1_8),
+      LinearLayout(
+          {{S("register"),
+            {{1, 0}, {2, 0}, {4, 0}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}},
+           {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
+           {S("warp"), {{0, 32}, {0, 64}, {0, 128}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+
+  auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/32, /*nDim=*/32,
+                             /*isTransposed=*/false);
+  auto amdDot_1_4 = amdDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8);
+  EXPECT_EQ(toLinearLayout({256, 256}, amdDot_1_4),
+            LinearLayout(
+                {{S("register"),
+                  {{1, 0},
+                   {2, 0},
+                   {4, 0},
+                   {16, 0},
+                   {32, 0},
+                   {64, 0},
+                   {128, 0},
+                   {0, 128}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}},
+                 {S("warp"), {{0, 32}, {0, 64}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+}
+
+TEST_F(LinearLayoutConversionsTest, warp1onK_mfma16_rhs_kwidth8) {
+  auto parentMfma_1_4 = mfma(/*warps=*/{1, 4}, /*mDim=*/16, /*nDim=*/16,
+                             /*isTransposed=*/false);
+  auto amdDot_1_4 = amdDot(parentMfma_1_4, /*opIdx=*/1, /*kWidth=*/8);
+  EXPECT_EQ(
+      toLinearLayout({128, 128}, amdDot_1_4),
+      LinearLayout(
+          {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 64}}},
+           {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}},
+           {S("warp"), {{0, 16}, {0, 32}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+
+  EXPECT_EQ(toLinearLayout({1, 128}, amdDot_1_4),
+            LinearLayout(
+                {{S("register"), {{0, 0}, {0, 0}, {0, 0}, {0, 64}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}, {0, 0}}},
+                 {S("warp"), {{0, 16}, {0, 32}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+
+  EXPECT_EQ(toLinearLayout({128, 1}, amdDot_1_4),
+            LinearLayout(
+                {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}}},
+                 {S("lane"), {{0, 0}, {0, 0}, {0, 0}, {0, 0}, {8, 0}, {16, 0}}},
+                 {S("warp"), {{0, 0}, {0, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+
+  EXPECT_EQ(toLinearLayout({256, 256}, amdDot_1_4),
+            LinearLayout(
+                {{S("register"),
+                  {{1, 0},
+                   {2, 0},
+                   {4, 0},
+                   {32, 0},
+                   {64, 0},
+                   {128, 0},
+                   {0, 64},
+                   {0, 128}}},
+                 {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}},
+                 {S("warp"), {{0, 16}, {0, 32}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+
+  auto parentMfma_1_8 = mfma(/*warps=*/{1, 8}, /*mDim=*/16, /*nDim=*/16,
+                             /*isTransposed=*/false);
+  auto amdDot_1_8 = amdDot(parentMfma_1_8, /*opIdx=*/1, /*kWidth=*/8);
+  EXPECT_EQ(
+      toLinearLayout({256, 256}, amdDot_1_8),
+      LinearLayout(
+          {{S("register"),
+            {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {128, 0}, {0, 128}}},
+           {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}},
+           {S("warp"), {{0, 16}, {0, 32}, {0, 64}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+
+  auto parentMfma_1_8_1 = mfma(/*warps=*/{1, 1, 8}, /*mDim=*/16, /*nDim=*/16,
+                               /*isTransposed=*/false);
+  auto amdDot_1_8_1 = amdDot(parentMfma_1_8_1, /*opIdx=*/1, /*kWidth=*/8);
+
+  EXPECT_EQ(toLinearLayout({1, 256, 256}, amdDot_1_8_1),
+            LinearLayout({{S("register"),
+                           {{0, 1, 0},
+                            {0, 2, 0},
+                            {0, 4, 0},
+                            {0, 32, 0},
+                            {0, 64, 0},
+                            {0, 128, 0},
+                            {0, 0, 128}}},
+                          {S("lane"),
+                           {{0, 0, 1},
+                            {0, 0, 2},
+                            {0, 0, 4},
+                            {0, 0, 8},
+                            {0, 8, 0},
+                            {0, 16, 0}}},
+                          {S("warp"), {{0, 0, 16}, {0, 0, 32}, {0, 0, 64}}},
+                          {S("block"), {}}},
+                         {S("dim0"), S("dim1"), S("dim2")}));
+}
+
 TEST_F(LinearLayoutConversionsTest, WMMA_2x4Warps) {
   auto legacy = wmma(/*warps=*/{2, 4});
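
For completeness, a minimal sketch (not part of the commit) of what a negative test could look like, reusing the mfma and amdDot fixture helpers above and the std::optional returned by DotOperandEncodingAttr::toLinearLayout; configurations outside the supported opIdx == 1 / kWidth == 8 case are expected to fall back to std::nullopt:

// Hypothetical test, not present in this commit: unsupported dot-operand
// configurations should return std::nullopt rather than a LinearLayout.
TEST_F(LinearLayoutConversionsTest, DotOperandMfmaFallsBackToNullopt) {
  auto parent = mfma(/*warps=*/{1, 4}, /*mDim=*/32, /*nDim=*/32,
                     /*isTransposed=*/false);
  // opIdx == 0 (operand A) is not converted yet.
  EXPECT_FALSE(amdDot(parent, /*opIdx=*/0, /*kWidth=*/8)
                   .toLinearLayout({128, 128})
                   .has_value());
  // kWidth != 8 is not converted yet.
  EXPECT_FALSE(amdDot(parent, /*opIdx=*/1, /*kWidth=*/4)
                   .toLinearLayout({128, 128})
                   .has_value());
}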
