Skip to content

Commit ec0bd4a

Browse files
authored
[Linear Layouts] Implement LL conversion for DotOperand(version=2) (#4891)
Note that the current implementation uses `DotOperandEncodingAttr::getWarpsPerCTA`, which was buggy for cases where the warps are not of the form `[numWarps, 1]` or `[1, numWarps]`. This PR bundles a fix for this issue. We will activate its use for a subset of `DotOperandEncoding`s in a PR coming soon.
1 parent a60fa8c commit ec0bd4a

File tree

4 files changed

+162
-12
lines changed

4 files changed

+162
-12
lines changed

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,13 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
250250
ArrayRef<unsigned> repShape,
251251
ArrayRef<unsigned> paddedRepShape,
252252
ArrayRef<unsigned> order, int swizzleByteSize);
253+
254+
// FIXME
255+
// Exposing to use it in LinearLayoutConversionsTest.cpp
256+
// Remove it once we fully activate the DotOperand conversion via LLs
257+
class DotOperandEncodingAttr;
258+
LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
259+
DotOperandEncodingAttr dot);
253260
} // namespace mlir::triton::gpu
254261

255262
#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,16 +1037,12 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
10371037
return res;
10381038
}
10391039
SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
  // A dot operand inherits the warp arrangement of its parent distributed
  // layout, except that warps never split the K dimension: every warp needs
  // the whole K slice, so the K entry is forced to 1.
  auto parent = mlir::cast<DistributedEncodingTrait>(getParent());
  SmallVector<unsigned> warpsPerCTA = parent.getWarpsPerCTA();
  size_t rank = warpsPerCTA.size();
  // K is the last dim for operand A (opIdx == 0), second-to-last for B.
  size_t kDim = (getOpIdx() == 0) ? rank - 1 : rank - 2;
  warpsPerCTA[kDim] = 1;
  return warpsPerCTA;
}
10511047
SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
10521048
return ::getWarpOrder(*this);

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
55
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
66
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
7+
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
78
#include "triton/Tools/LinearLayout.h"
89
#include "triton/Tools/StrUtil.h"
910
#include "llvm/ADT/DenseMap.h"
@@ -821,13 +822,77 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
821822
return ret;
822823
}
823824

825+
LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
                                     DotOperandEncodingAttr dot) {
  // Builds the linear layout of an Ampere mma.sync dot operand.
  // TODO,BE. Implement ampereMMA in terms of this one
  int rank = shape.size();
  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
  int kWidth = dot.getKWidth();
  bool isA = dot.getOpIdx() == 0;

  assert(mma.isAmpere());
  assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
         (rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));

  MLIRContext *ctx = mma.getContext();
  SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);

  // Construct the bases for operand A; operand B is obtained by swapping the
  // two components of every basis vector (a transpose) at the end.
  std::vector<std::vector<int32_t>> regBases;
  std::vector<std::vector<int32_t>> laneBases;
  int32_t stride = 1;
  // kWidth contiguous elements held per thread.
  for (; stride < kWidth; stride *= 2)
    regBases.push_back({stride, 0});
  // 4 threads per chunk.
  for (int j = 0; j < 2; ++j) {
    laneBases.push_back({stride, 0});
    stride *= 2;
  }
  // 8 threads going down.
  laneBases.push_back({0, 1});
  laneBases.push_back({0, 2});
  laneBases.push_back({0, 4});
  // 2 tiles in column-major order; just one if it's the B operand.
  if (isA)
    regBases.push_back({0, 8});
  regBases.push_back({stride, 0});

  if (!isA) {
    // Transpose every basis vector to obtain the B-operand layout.
    for (auto &basis : regBases)
      std::swap(basis[0], basis[1]);
    for (auto &basis : laneBases)
      std::swap(basis[0], basis[1]);
  }

  LinearLayout ctaLayout(
      {{S("register"), regBases}, {S("lane"), laneBases}},
      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));

  auto order = dot.getCTAOrder();
  assert(order[0] == 1 && order[1] == 0);
  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);

  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
}
884+
824885
std::optional<LinearLayout>
825886
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
826-
827887
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
828888
return dotOperandMfmaToLinearLayout(*this, shape);
829889
}
830-
890+
// TODO Activate in a follow-up PR
891+
// else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
892+
// if (mma.isAmpere()) {
893+
// return ampereDotToLinearLayout(shape, *this);
894+
// }
895+
//}
831896
return std::nullopt;
832897
}
833898

unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
55
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
66
#include "triton/Tools/StrUtil.h"
7+
#include "llvm/ADT/ArrayRef.h"
78
#include "llvm/Support/Signals.h"
89
#include <gmock/gmock.h>
910
#include <gtest/gtest.h>
@@ -40,6 +41,12 @@ class LinearLayoutConversionsTest : public ::testing::Test {
4041
CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape);
4142
}
4243

44+
DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef<unsigned> warps,
                                ArrayRef<unsigned> order) {
  // Convenience builder for a DotOperand whose parent is an MMAv2 (Ampere)
  // layout with a 16x8 instruction shape and a single CTA.
  // NOTE(review): argument meanings follow the fixture's mma() helper —
  // presumably (version, versionMinor, instrShape, warps, cpg, cSplit, order).
  auto parentMma = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order);
  return DotOperandEncodingAttr::get(&ctx, idx, parentMma, /*kWidth=*/kWidth);
}
49+
4350
AMDMfmaEncodingAttr mfma(ArrayRef<unsigned> warps, unsigned mDim,
4451
unsigned nDim, bool isTransposed) {
4552
SmallVector<unsigned> cpg(warps.size(), 1u);
@@ -494,6 +501,81 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) {
494501
{S("dim0"), S("dim1")}));
495502
}
496503

504+
TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
  // Single warp, kWidth=8: check the bare mma.sync tile for both operands.
  auto dotA = dotMMAv2(/*idx=*/0, /*kWidth=*/8, /*warps=*/{1, 1},
                       /*order=*/{1, 0});
  EXPECT_EQ(ampereDotToLinearLayout({16, 64}, dotA),
            LinearLayout(
                {
                    {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
                    {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
                    {S("warp"), {}},
                    {S("block"), {}},
                },
                {S("dim0"), S("dim1")}));
  auto dotB = dotMMAv2(/*idx=*/1, /*kWidth=*/8, /*warps=*/{1, 1},
                       /*order=*/{1, 0});
  EXPECT_EQ(ampereDotToLinearLayout({64, 8}, dotB),
            LinearLayout(
                {
                    {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
                    {S("warp"), {}},
                    {S("block"), {}},
                },
                {S("dim0"), S("dim1")}));
}
524+
525+
TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
  // Four warps along dim0: operand A broadcasts across the K-split warps,
  // operand B (K along dim0) ends up with an empty warp basis.
  auto dotA = dotMMAv2(/*idx=*/0, /*kWidth=*/8, /*warps=*/{4, 1},
                       /*order=*/{1, 0});
  EXPECT_EQ(
      ampereDotToLinearLayout({128, 128}, dotA),
      LinearLayout(
          {
              {S("register"),
               {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}},
              {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
              {S("warp"), {{16, 0}, {32, 0}}},
              {S("block"), {}},
          },
          {S("dim0"), S("dim1")}));
  auto dotB = dotMMAv2(/*idx=*/1, /*kWidth=*/8, /*warps=*/{4, 1},
                       /*order=*/{1, 0});
  EXPECT_EQ(ampereDotToLinearLayout({128, 64}, dotB),
            LinearLayout(
                {
                    {S("register"),
                     {{1, 0},
                      {2, 0},
                      {4, 0},
                      {32, 0},
                      {0, 8},
                      {0, 16},
                      {0, 32},
                      {64, 0}}},
                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
                    {S("warp"), {}},
                    {S("block"), {}},
                },
                {S("dim0"), S("dim1")}));
  EXPECT_EQ(ampereDotToLinearLayout({64, 128}, dotB),
            LinearLayout(
                {
                    {S("register"),
                     {{1, 0},
                      {2, 0},
                      {4, 0},
                      {32, 0},
                      {0, 8},
                      {0, 16},
                      {0, 32},
                      {0, 64}}},
                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
                    {S("warp"), {}},
                    {S("block"), {}},
                },
                {S("dim0"), S("dim1")}));
}
578+
497579
TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
498580
auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
499581
/*isTransposed=*/false);

0 commit comments

Comments
 (0)