Skip to content

Commit a20ce64

Browse files
authored
[AMD] Add MFMA dot operand to LinearLayout conversion (#4961)
This PR implements general conversion of MFMA dot operand to Linear Layout.
1 parent c9a40b2 commit a20ce64

File tree

4 files changed

+464
-40
lines changed

4 files changed

+464
-40
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
256256
}
257257

258258
SmallVector<unsigned> getWarpOrder(Attribute layout) {
259+
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
260+
if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
261+
return getWarpOrder(dotLayout.getParent());
262+
}
263+
}
259264
auto order = getOrder(layout);
260265
// FIXME: This mmaLayout if should just return
261266
// getOrderForDotOperand(0, order.size(), kMajor=false)

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
472472
int nIndex = 1 + hasBatchDim;
473473
(void)mIndex, (void)nIndex;
474474

475-
assert(((shape[mIndex] == 1 || shape[mIndex] >= getMDim()) &&
476-
(shape[nIndex] == 1 || shape[nIndex] >= getNDim())) &&
477-
"Unsupported tensor shape for given mfma layout");
478-
479475
assert(((getMDim() == 32 && getNDim() == 32) ||
480476
(getMDim() == 16 && getNDim() == 16)) &&
481477
"Unsupported mfma type");
@@ -579,55 +575,76 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
579575
// 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is
580576
// held by exactly one thread, maintaining the same number of global loads
581577
// as in a blocked layout.
578+
//
579+
// Another use of the linear layout is to support rare corner cases,
580+
// for example when one instruction tile is larger than the tensor
582581
auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
583582

584-
if (dotMfmaLayout.getOpIdx() == 0) {
585-
return std::nullopt;
586-
}
587583
auto rank = shape.size();
588584
bool hasBatchDim = rank == 3;
589585
int mIndex = 0 + hasBatchDim;
590586

591-
auto kWidth = dotMfmaLayout.getKWidth();
587+
int32_t kWidth = dotMfmaLayout.getKWidth();
588+
auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
589+
int32_t kSize = shape[kDim];
592590
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
593591

594-
if (kWidth != 8 || warpsPerCTA[mIndex] != 1) {
595-
return std::nullopt;
596-
}
597-
598592
MLIRContext *ctx = dotMfmaLayout.getContext();
599593
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
600594

601595
StringAttr kRegister = S("register");
602596
StringAttr kLane = S("lane");
597+
StringAttr kWarp = S("warp");
603598

599+
// register order
600+
// operand A: [1, 0] / [2, 1, 0]
601+
// operand B: [0, 1] / [1, 2, 0]
602+
// for both cases it is [k, nonk]/[k, nonk, batch]
604603
SmallVector<unsigned> order = triton::gpu::getOrder(dotMfmaLayout);
605-
auto tileLayout = LinearLayout::empty();
604+
// warp order
605+
// common for both operand A and B: [0, 1] / [0, 1, 2]
606+
// in both cases it is [M dim, N dim]/[batch, M dim, N dim]
607+
SmallVector<unsigned> warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout);
608+
609+
// Lane holds kWidth consecutive elements along k dimension, so
610+
// base register vectors for one tile are initialized in following way:
611+
// {1, 0}, {2, 0} ... {kWidth/2, 0}
612+
std::vector<std::vector<int32_t>> registerBase;
613+
for (int32_t elem = 1; elem < kWidth; elem *= 2)
614+
registerBase.emplace_back(std::vector<int32_t>{elem, 0});
615+
616+
std::vector<std::vector<int32_t>> laneBase;
617+
int32_t kTileSize = -1;
606618

607619
if (mfmaLayout.getMDim() == 32) {
608-
// Based on canonical MFMA linear layout, which handles 4 consecutive
609-
// elements along the register dimension, kWidth=8 means we have 8
610-
// consecutive elements, so we have an additional {4, 0} base vector here.
611-
// For lane dim, since the MFMA thread arrangement is {K, N} = {2, 32}, this
612-
// means that mapping of first 5 base (up to thread 16) vectors will be an
613-
// identity along N dim. Thread 32 will be mapped to element 8 in K
614-
// dimension, because kWidth == 8.
615-
tileLayout = LinearLayout(
616-
{{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
617-
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}},
618-
{outDimNames[order[0]], outDimNames[order[1]]});
620+
// Canonical MFMA linear layout handles 4 consecutive elements along
621+
// the register dimension. Dot operand handles variable kWidth consecutive
622+
// elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2,
623+
// 32}, this means that mapping of first 5 base (up to thread 16) vectors
624+
// will be an identity along N dim. Thread 32 will be mapped to element
625+
// kWidth in K dimension.
626+
laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}};
627+
kTileSize = kWidth * 2;
619628
} else {
620629
assert(mfmaLayout.getMDim() == 16);
621630
// For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this
622631
// means that mapping of first 4 base (up to thread 16) vectors will be an
623-
// identity along N dim. Thread 16 will be mapped to element 8 in K
624-
// dimension, because kWidth == 8. Thread 32 is mapped to element 16 as that
625-
// is 2*kWidth in K dim.
626-
tileLayout = LinearLayout(
627-
{{kRegister, {{1, 0}, {2, 0}, {4, 0}}},
628-
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {8, 0}, {16, 0}}}},
629-
{outDimNames[order[0]], outDimNames[order[1]]});
632+
// identity along N dim. Thread 16 will be mapped to element kWidth in K
633+
// dimension. Thread 32 is mapped to element 2*kWidth in K dim.
634+
laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}};
635+
kTileSize = kWidth * 4;
630636
}
637+
assert(kTileSize != -1);
638+
// Add repeats of registers along K dimension to register base vectors
639+
for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
640+
registerBase.emplace_back(std::vector<int32_t>{elem, 0});
641+
642+
// Base vectors above are defined in a fixed order [non-k-dim, k-dim].
643+
// To assign them to actual matrix dimensions `order` array is used.
644+
// For operand A: non-k-dim -> dim0, k-dim -> dim1
645+
// For operand B: non-k-dim -> dim1, k-dim -> dim0
646+
LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
647+
{outDimNames[order[0]], outDimNames[order[1]]});
631648

632649
if (hasBatchDim) {
633650
assert(order[2] == 0);
@@ -638,8 +655,10 @@ dotOperandMfmaToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
638655
}
639656

640657
LinearLayout warpLayout =
641-
identityND(S("warp"), warpsPerCTA, order, outDimNames);
642-
LinearLayout ctaLayout = tileLayout * warpLayout;
658+
identityND(kWarp, warpsPerCTA, warpOrder, outDimNames);
659+
660+
LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
661+
warpLayout.transposeOuts(outDimNames);
643662

644663
return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
645664
}

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,12 @@ class AMDMfmaLayoutTest : public ::testing::Test {
543543
/*isTransposed=*/true, ctaLayout);
544544
}
545545

546+
triton::gpu::DotOperandEncodingAttr
547+
createDotOperand(int idx, triton::gpu::AMDMfmaEncodingAttr parent,
548+
int kWidth) {
549+
return triton::gpu::DotOperandEncodingAttr::get(&ctx, idx, parent, kWidth);
550+
}
551+
546552
protected:
547553
MLIRContext ctx;
548554
const SmallVector<unsigned> ctaPerCGA{1, 1, 1};
@@ -588,6 +594,32 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {
588594
ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
589595
}
590596

597+
TEST_F(AMDMfmaLayoutTest, mfma_dot_op) {
598+
auto mfma2d = createMFMA(32, 32, {2, 4});
599+
auto dot2dOp0 = createDotOperand(0, mfma2d, 4);
600+
auto dot2dOp1 = createDotOperand(1, mfma2d, 4);
601+
ASSERT_THAT(dot2dOp0.getWarpOrder(), mfma2d.getWarpOrder());
602+
ASSERT_THAT(dot2dOp1.getWarpOrder(), mfma2d.getWarpOrder());
603+
604+
auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
605+
auto tdot2dOp0 = createDotOperand(0, tmfma2d, 4);
606+
auto tdot2dOp1 = createDotOperand(1, tmfma2d, 4);
607+
ASSERT_THAT(tdot2dOp0.getWarpOrder(), tmfma2d.getWarpOrder());
608+
ASSERT_THAT(tdot2dOp1.getWarpOrder(), tmfma2d.getWarpOrder());
609+
610+
auto mfma3d = createMFMA(32, 32, {2, 4, 1});
611+
auto dot3dOp0 = createDotOperand(0, mfma3d, 4);
612+
auto dot3dOp1 = createDotOperand(1, mfma3d, 4);
613+
ASSERT_THAT(dot3dOp0.getWarpOrder(), mfma3d.getWarpOrder());
614+
ASSERT_THAT(dot3dOp1.getWarpOrder(), mfma3d.getWarpOrder());
615+
616+
auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
617+
auto tdot3dOp0 = createDotOperand(0, tmfma3d, 4);
618+
auto tdot3dOp1 = createDotOperand(1, tmfma3d, 4);
619+
ASSERT_THAT(tdot3dOp0.getWarpOrder(), tmfma3d.getWarpOrder());
620+
ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder());
621+
}
622+
591623
} // anonymous namespace
592624
} // namespace mlir::triton::gpu
593625

0 commit comments

Comments
 (0)