Commit 755077c

[AMD] Always swap operands of mfma and use mfma.transposed layout (#4767)
This helps improve the writeout to use `global_store_dwordx2`. Along the way, this PR:

- fixed the issue with getOrder for the mfma layout;
- fixed the issue with reduceOp when dealing with the mfma.transposed layout.

In general, getOrder and getThreadOrder can return different values, and this is the case for the mfma.transposed layout. Therefore, we shouldn't assume order and threadOrder are always the same.
1 parent 6af74b2 commit 755077c
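
The order/threadOrder distinction is easy to see in isolation. Below is a minimal standalone sketch (plain C++, not Triton's actual API; `MfmaLayout` and both free functions are hypothetical stand-ins) modeling the two queries for the mfma layouts touched by this commit.

```cpp
#include <array>
#include <cstdio>

// Hypothetical stand-in for the mfma encoding; not Triton's real API.
struct MfmaLayout {
  bool isTransposed;
};

// Memory order of the 2D result tensor, minor (fast-varying) dim first.
// After this commit it is [1, 0] regardless of transposition.
std::array<unsigned, 2> getOrder(const MfmaLayout &) { return {1, 0}; }

// Order in which thread IDs are distributed within a warp; swapped for the
// transposed mfma layout.
std::array<unsigned, 2> getThreadOrder(const MfmaLayout &l) {
  return l.isTransposed ? std::array<unsigned, 2>{0, 1}
                        : std::array<unsigned, 2>{1, 0};
}

int main() {
  MfmaLayout t{/*isTransposed=*/true};
  auto o = getOrder(t);
  auto to = getThreadOrder(t);
  std::printf("order = [%u, %u], threadOrder = [%u, %u]\n",
              o[0], o[1], to[0], to[1]);
  // Prints: order = [1, 0], threadOrder = [0, 1] -- the two differ, so
  // callers must not substitute one for the other.
}
```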

File tree

8 files changed, +63 −38 lines


include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 24 additions & 1 deletion
@@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout,
 SmallVector<unsigned>
 getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
 
+// Returns the dimensions of the tensor from minor (fast-varying) to
+// major (slow-varying). For blocked, mma, and dotOperand layouts,
+// though the elements are in registers, the order refers to the memory
+// layout of the original tensor in global memory.
+// For shared layouts, the order refers to which dimension of the original
+// tensor is contiguous in shared memory.
+SmallVector<unsigned> getOrder(Attribute layout);
+
+// Returns the dimensions along which warp IDs are distributed.
+// warpsPerCTA only tells the warp layout in the CTA, e.g., warpsPerCTA =
+// [2, 4] means there are 2 warps along dim0 and 4 warps along dim1.
+// warpOrder tells the specific order when distributing warp IDs.
+// E.g., warpOrder = [0, 1] means the warp IDs are distributed as follows:
+//   [warp0 warp2 warp4 warp6]
+//   [warp1 warp3 warp5 warp7]
+// Note that in most cases getWarpOrder and getOrder return the same result,
+// but this is not guaranteed.
 SmallVector<unsigned> getWarpOrder(Attribute layout);
 
-SmallVector<unsigned> getOrder(Attribute layout);
+// Returns the dimensions along which thread IDs are distributed.
+// Similar to warpOrder, threadOrder is necessary to tell the specific thread
+// distribution within the warp.
+// Note that in most cases getThreadOrder and getOrder return the same
+// result, but this is not guaranteed. One exception is the mfma.transposed
+// layout, for which getOrder returns [1, 0] but getThreadOrder returns [0, 1].
+SmallVector<unsigned> getThreadOrder(Attribute layout);
 
 CTALayoutAttr getCTALayout(Attribute layout);
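
To make the new warpOrder comment concrete, here is a small self-contained sketch (assumed semantics, plain C++ rather than Triton code) that delinearizes warp IDs for warpsPerCTA = [2, 4] with warpOrder = [0, 1] and reproduces the grid from the comment.

```cpp
#include <cstdio>

int main() {
  const unsigned warpsPerCTA[2] = {2, 4};
  const unsigned warpOrder[2] = {0, 1}; // dim0 varies fastest
  for (unsigned id = 0; id < 8; ++id) {
    unsigned coord[2];
    unsigned rem = id;
    for (unsigned i = 0; i < 2; ++i) { // peel dims in warpOrder
      unsigned dim = warpOrder[i];
      coord[dim] = rem % warpsPerCTA[dim];
      rem /= warpsPerCTA[dim];
    }
    std::printf("warp%u -> (dim0=%u, dim1=%u)\n", id, coord[0], coord[1]);
  }
  // warp0->(0,0), warp1->(1,0), warp2->(0,1), warp3->(1,1), ... i.e. the
  // [warp0 warp2 warp4 warp6] / [warp1 warp3 warp5 warp7] grid above.
}
```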

lib/Analysis/Utility.cpp

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
     return getParentOrder(sliceEncoding.getParent());
   }
-  return getOrder(layout);
+  return getThreadOrder(layout);
 }
 
 } // namespace
@@ -75,7 +75,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
     threadOffset = threadsPerWarp[sliceLayout.getDim()];
   } else {
     auto threadsPerWarp = getThreadsPerWarp(srcLayout);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     for (unsigned i = 0; i < order.size(); i++) {
       if (order[i] == axis)
         break;

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
@@ -9,6 +9,7 @@ using namespace mlir::triton;
 using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::getOrder;
+using ::mlir::triton::gpu::getThreadOrder;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 
 namespace {
@@ -271,7 +272,7 @@ struct ReduceOpConversion
 
     auto threadsPerWarp =
         triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
         delinearize(rewriter, loc, laneId, threadsPerWarp, order);
     Value laneIdAxis = multiDimLaneId[axis];
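
The effect of this fix can be sketched in isolation. The following standalone C++ (assumed semantics: `delinearize` here is a simplified stand-in for `mlir::LLVM::delinearize`, and the threadsPerWarp values are illustrative) shows how the lane coordinates change with the order used:

```cpp
#include <cstdio>
#include <vector>

// Delinearize `linear` into coordinates for `shape`, peeling dimensions in
// `order` (fastest-varying first). A simplified stand-in for the real helper.
std::vector<unsigned> delinearize(unsigned linear,
                                  const std::vector<unsigned> &shape,
                                  const std::vector<unsigned> &order) {
  std::vector<unsigned> coord(shape.size());
  for (unsigned dim : order) {
    coord[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return coord;
}

int main() {
  std::vector<unsigned> threadsPerWarp = {2, 32}; // illustrative only
  unsigned laneId = 5;
  auto a = delinearize(laneId, threadsPerWarp, {1, 0}); // order
  auto b = delinearize(laneId, threadsPerWarp, {0, 1}); // threadOrder
  std::printf("with order:       (%u, %u)\n", a[0], a[1]); // (0, 5)
  std::printf("with threadOrder: (%u, %u)\n", b[0], b[1]); // (1, 2)
  // The coordinates differ, so a reduction that indexes
  // multiDimLaneId[axis] reads the wrong lane unless threadOrder is used.
}
```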

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 12 additions & 9 deletions
@@ -256,14 +256,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     auto rank = distributedLayout.getWarpsPerCTA().size();
     SmallVector<unsigned> order(rank);
     std::iota(order.rbegin(), order.rend(), 0);
-    auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout);
-    if (!mfmaLayout)
-      return order;
-    // For transposed MFMA layouts, we swap M and N dimensions, which is
-    // always the first two in order; as we can have an optional batch
-    // dimension following them.
-    if (mfmaLayout.getIsTransposed())
-      std::swap(order[0], order[1]);
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
@@ -290,6 +282,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   return {};
 };
 
+SmallVector<unsigned> getThreadOrder(Attribute layout) {
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getThreadOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
+  return {};
+};
+
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
@@ -1536,7 +1536,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto order = ::getOrder(*this);
+  if (getIsTransposed())
+    std::swap(order[0], order[1]);
+  return order;
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
   unsigned rows, cols;
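
A compact sketch of the new behavior (standalone C++ mirroring, but not using, the real AMDMfmaEncodingAttr): start from the default minor-to-major order and swap the two fastest dims when transposed; an optional batch dim keeps its position.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Mirrors the new AMDMfmaEncodingAttr::getThreadOrder logic in isolation.
std::vector<unsigned> mfmaThreadOrder(unsigned rank, bool isTransposed) {
  std::vector<unsigned> order(rank);
  for (unsigned i = 0; i < rank; ++i)
    order[i] = rank - 1 - i; // default order: [1, 0] or [2, 1, 0]
  if (isTransposed)
    std::swap(order[0], order[1]); // swap M and N, batch dim unaffected
  return order;
}

int main() {
  for (bool transposed : {false, true})
    for (unsigned rank : {2u, 3u}) {
      auto order = mfmaThreadOrder(rank, transposed);
      std::printf("rank=%u transposed=%d: [", rank, transposed);
      for (unsigned i = 0; i < rank; ++i)
        std::printf(i ? ", %u" : "%u", order[i]);
      std::printf("]\n");
    }
  // rank=2: [1, 0] vs. [0, 1]; rank=3: [2, 1, 0] vs. [1, 2, 0], matching the
  // updated unit tests below.
}
```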

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 14 additions & 0 deletions
@@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, the element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   } else {
     assert(getMDim() == 16);
     // For mfma with 16x16 output, each of the 64 threads holds 4 elements.
@@ -521,6 +528,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, the element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   }
   if (hasBatchDim) {
     assert(order[2] == 0);
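
How the swapped bases "transpose" element ownership can be checked by hand: in a linear layout, an index's coordinates are the XOR of the basis vectors for its set bits. The sketch below (standalone C++; `Basis` and `apply` are illustrative helpers, not Triton's LinearLayout API) applies the mfma16 lane bases from the hunk above in both variants.

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Illustrative helpers; not Triton's LinearLayout API.
struct Basis {
  unsigned d0, d1; // coordinate contribution in the two output dims
};

// Coordinates of index `idx`: XOR together the basis vectors of its set bits.
std::pair<unsigned, unsigned> apply(const std::vector<Basis> &bases,
                                    unsigned idx) {
  unsigned d0 = 0, d1 = 0;
  for (size_t b = 0; b < bases.size(); ++b)
    if ((idx >> b) & 1) {
      d0 ^= bases[b].d0;
      d1 ^= bases[b].d1;
    }
  return {d0, d1};
}

int main() {
  // mfma16 lane bases from the hunk above: standard vs. transposed.
  std::vector<Basis> lane = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}};
  std::vector<Basis> laneT = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}};
  for (unsigned l : {1u, 5u, 17u}) {
    auto a = apply(lane, l);
    auto b = apply(laneT, l);
    std::printf("lane %2u: standard=(%u,%u) transposed=(%u,%u)\n", l, a.first,
                a.second, b.first, b.second);
  }
  // Every transposed coordinate is the flipped (d1, d0) of the standard one:
  // element ownership is transposed within the warp.
}
```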

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 3 additions & 19 deletions
@@ -269,23 +269,6 @@ class BlockedToMFMA : public RewritePattern {
       : RewritePattern(tt::DotOp::getOperationName(), 2, context),
         mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {}
 
-  bool isChainDot(tt::DotOp &dotOp) const {
-    auto filter = [&dotOp](Operation *op) {
-      return op->getParentRegion() == dotOp->getParentRegion();
-    };
-    ForwardSliceOptions fwdOpt;
-    fwdOpt.filter = filter;
-    BackwardSliceOptions bwdOpt;
-    bwdOpt.omitBlockArguments = true;
-    bwdOpt.filter = filter;
-    auto slices = getSlice(dotOp, bwdOpt, fwdOpt);
-    for (Operation *op : slices) {
-      if (isa<tt::DotOp>(op) && (op != dotOp))
-        return true;
-    }
-    return false;
-  }
-
   bool isSecondDot(tt::DotOp &dotOp) const {
     auto filter = [&dotOp](Operation *op) {
       return op->getParentRegion() == dotOp->getParentRegion();
@@ -400,11 +383,12 @@ class BlockedToMFMA : public RewritePattern {
     auto warpsPerTile =
         warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim});
 
-    bool isTransposed = isChainDot(dotOp);
+    // Always use the transposed mfma layout. This enables larger
+    // vectorization for global store instructions.
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
         /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
-        /*instrShape*/ mDim, nDim, isTransposed, CTALayout);
+        /*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout);
 
     Type mfmaAccType;
     if (oldRetType.getElementType().isIntOrIndex())
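
Why this improves the writeout, in a simplified model (the stride and register counts below are illustrative assumptions, not values taken from the pass): with the transposed layout, a thread's consecutive accumulator registers land on consecutive addresses along the tensor's contiguous dim, so the backend can merge 32-bit stores into wider ones such as `global_store_dwordx2`.

```cpp
#include <cstdio>

// Assumed model: a thread can vectorize its stores only if its consecutive
// registers hit consecutive addresses along the tensor's fast dim.
unsigned storeVecWidth(unsigned regStrideAlongFastDim, unsigned numRegs) {
  if (regStrideAlongFastDim != 1)
    return 1; // registers are strided in memory: scalar stores only
  unsigned w = 1;
  while (2 * w <= numRegs && 2 * w <= 4)
    w *= 2; // cap at dwordx4
  return w;
}

int main() {
  // Illustrative strides for fp32 accumulators: the standard layout steps a
  // thread's registers along the slow dim, the transposed layout along the
  // fast dim.
  std::printf("standard:   %u dword(s) per store\n", storeVecWidth(32, 2));
  std::printf("transposed: %u dword(s) per store\n", storeVecWidth(1, 2));
  // 1 vs. 2 dwords: the transposed layout lets the writeout use
  // global_store_dwordx2 instead of two scalar global_store_dword.
}
```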

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 4 additions & 4 deletions
@@ -559,15 +559,15 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {
 
   auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
   ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
-  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
+  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));
 
   auto mfma3d = createMFMA(32, 32, {2, 4, 1});
   ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
   ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 
   auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
   ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
-  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
+  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 }
 
 TEST_F(AMDMfmaLayoutTest, mfma16) {
@@ -577,15 +577,15 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {
 
   auto tmfma2d = createTransposedMFMA(16, 16, {2, 4});
   ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
-  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
+  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));
 
   auto mfma3d = createMFMA(16, 16, {2, 4, 1});
   ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
   ASSERT_THAT(mfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 
   auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1});
   ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
-  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
+  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 }
 
 } // anonymous namespace

unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp

Lines changed: 2 additions & 2 deletions
@@ -529,14 +529,14 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
             LinearLayout(
                 {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}},
                  {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
-                 {S("warp"), {{32, 0}, {0, 0}, {0, 0}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {32, 0}}},
                  {S("block"), {}}},
                 {S("dim0"), S("dim1")}));
   EXPECT_EQ(toLinearLayout({128, 128}, mfmaT),
             LinearLayout(
                 {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}},
                  {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
-                 {S("warp"), {{32, 0}, {0, 32}, {0, 64}}},
+                 {S("warp"), {{0, 32}, {0, 64}, {32, 0}}},
                  {S("block"), {}}},
                 {S("dim0"), S("dim1")}));
 }
