Commit 2beff56
Revert "Revert "[Linear Layouts] Implement LL conversion for DotOperand(version=2) (#4891)""
This reverts commit d16ef81.
1 parent a3adef5 · commit 2beff56

3 files changed: +155 −11 lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 10 deletions
@@ -1044,16 +1044,12 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
   return res;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto distributedLayout =
-          mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
-    return distributedLayout.getWarpsPerCTA();
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
-        "supported yet");
-  }
+  auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
+  auto warps = distributedLayout.getWarpsPerCTA();
+  auto rank = warps.size();
+  auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
+  warps[kDim] = 1;
+  return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
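Note on the change above: the rewritten getWarpsPerCTA no longer forwards the parent layout's warp shape verbatim; it forces the warp count along the K dimension to 1, since each warp of a dot operand covers the full K extent of its tile. A minimal standalone sketch of that logic (editor's illustration with a hypothetical free function, not code from the commit):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Collapse the K dimension of the parent's warpsPerCTA to a single warp.
    // For operand A (opIdx == 0), K is the last dimension; for operand B
    // (opIdx == 1), it is the second-to-last.
    std::vector<unsigned> dotOperandWarpsPerCTA(std::vector<unsigned> warps,
                                                int opIdx) {
      std::size_t rank = warps.size();
      std::size_t kDim = opIdx == 0 ? rank - 1 : rank - 2;
      warps[kDim] = 1;
      return warps;
    }

    int main() {
      // With 4 warps along dim0 and 1 along dim1:
      assert(dotOperandWarpsPerCTA({4, 1}, 0) == (std::vector<unsigned>{4, 1}));
      assert(dotOperandWarpsPerCTA({4, 1}, 1) == (std::vector<unsigned>{1, 1}));
    }

With warps = {4, 1}, A keeps {4, 1} while B collapses to {1, 1}; this is why the B-operand tests in the last file of this commit expect empty warp bases.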

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 67 additions & 1 deletion
@@ -5,6 +5,7 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/DenseMap.h"
@@ -822,16 +823,81 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return ret;
 }

+LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
+                                     DotOperandEncodingAttr dot) {
+  // TODO,BE. Implement ampereMMA in terms of this one
+  int rank = shape.size();
+  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
+  int kWidth = dot.getKWidth();
+  bool isA = dot.getOpIdx() == 0;
+
+  assert(mma.isAmpere());
+  assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
+         (rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
+
+  MLIRContext *ctx = mma.getContext();
+  SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+
+  // Implement A. For B transpose in the end
+  std::vector<std::vector<int32_t>> registers;
+  std::vector<std::vector<int32_t>> lanes;
+  int32_t i = 1;
+  // kWidth contiguous elements
+  while (i < kWidth) {
+    registers.push_back({i, 0});
+    i *= 2;
+  }
+  // 4 threads per chunk
+  for (int j = 0; j < 2; j++) {
+    lanes.push_back({i, 0});
+    i *= 2;
+  }
+  // 8 threads going down
+  lanes.push_back({0, 1});
+  lanes.push_back({0, 2});
+  lanes.push_back({0, 4});
+  // 2 tiles in column-major order
+  // Just one if it's the B operand
+  if (isA) {
+    registers.push_back({0, 8});
+  }
+  registers.push_back({i, 0});
+
+  if (!isA) {
+    for (auto &r : registers) {
+      std::swap(r[0], r[1]);
+    }
+    for (auto &l : lanes) {
+      std::swap(l[0], l[1]);
+    }
+  }
+
+  LinearLayout ctaLayout(
+      {{S("register"), registers}, {S("lane"), lanes}},
+      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
+
+  auto order = dot.getCTAOrder();
+  assert(order[0] == 1 && order[1] == 0);
+  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
+}
+
 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
-
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
   }
   if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
     return dotOperandDpasToLinearLayout(*this, shape);
   }

+  // TODO Activate in a follow-up PR
+  // else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+  //   if (mma.isAmpere()) {
+  //     return ampereDotToLinearLayout(shape, *this);
+  //   }
+  // }
   return std::nullopt;
 }
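To see what the register/lane loop in ampereDotToLinearLayout produces, here is a standalone driver (editor's sketch, not part of the commit) that re-runs the A-operand construction for kWidth = 8. Basis vectors are in the function's internal order, which for a rank-2 A operand is (dim1, dim0), i.e. (K, M); the real code reverses the output dims when building the LinearLayout:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const int kWidth = 8;
      std::vector<std::vector<int32_t>> registers, lanes;
      int32_t i = 1;
      while (i < kWidth) { // kWidth contiguous elements per thread along K
        registers.push_back({i, 0});
        i *= 2;
      } // registers: {1,0} {2,0} {4,0}
      for (int j = 0; j < 2; j++) { // 4 threads per chunk along K
        lanes.push_back({i, 0});
        i *= 2;
      } // lanes: {8,0} {16,0}; i is now 32
      lanes.push_back({0, 1}); // 8 threads going down M
      lanes.push_back({0, 2});
      lanes.push_back({0, 4});
      registers.push_back({0, 8}); // second 8-row tile (A operand only)
      registers.push_back({i, 0}); // next K chunk: {32,0}
      // Transposed to (dim0, dim1), this is exactly the DotMMAv2_tile_kwidth8
      // expectation below: register = {0,1},{0,2},{0,4},{8,0},{0,32} and
      // lane = {0,8},{0,16},{1,0},{2,0},{4,0}.
      for (auto &r : registers)
        std::cout << "register {" << r[0] << ", " << r[1] << "}\n";
      for (auto &l : lanes)
        std::cout << "lane {" << l[0] << ", " << l[1] << "}\n";
    }

For the B operand, the same bases are built and then every coordinate pair is swapped, which is the transpose step the "Implement A. For B transpose in the end" comment refers to.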

unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp

Lines changed: 82 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Tools/StrUtil.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Signals.h"
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -40,6 +41,12 @@ class LinearLayoutConversionsTest : public ::testing::Test {
         CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape);
   }

+  DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef<unsigned> warps,
+                                  ArrayRef<unsigned> order) {
+    auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order);
+    return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth);
+  }
+
   AMDMfmaEncodingAttr mfma(ArrayRef<unsigned> warps, unsigned mDim,
                            unsigned nDim, bool isTransposed) {
     SmallVector<unsigned> cpg(warps.size(), 1u);
@@ -494,6 +501,81 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) {
                 {S("dim0"), S("dim1")}));
 }

+TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
+  EXPECT_EQ(ampereDotToLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
+                    {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+                    {S("warp"), {}},
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(ampereDotToLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {S("warp"), {}},
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+}
+
+TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
+  EXPECT_EQ(
+      ampereDotToLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})),
+      LinearLayout(
+          {
+              {S("register"),
+               {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}},
+              {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+              {S("warp"), {{16, 0}, {32, 0}}},
+              {S("block"), {}},
+          },
+          {S("dim0"), S("dim1")}));
+  EXPECT_EQ(ampereDotToLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"),
+                     {{1, 0},
+                      {2, 0},
+                      {4, 0},
+                      {32, 0},
+                      {0, 8},
+                      {0, 16},
+                      {0, 32},
+                      {64, 0}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {
+                        S("warp"),
+                        {},
+                    },
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(ampereDotToLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"),
+                     {{1, 0},
+                      {2, 0},
+                      {4, 0},
+                      {32, 0},
+                      {0, 8},
+                      {0, 16},
+                      {0, 32},
+                      {0, 64}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {
+                        S("warp"),
+                        {},
+                    },
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+}
+
 TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
   auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
                      /*isTransposed=*/false);