Skip to content

Commit d16ef81

Browse files
Revert "[Linear Layouts] Implement LL conversion for DotOperand(version=2) (#4891)"
This reverts commit ec0bd4a.
1 parent 6d95f24 commit d16ef81

File tree

4 files changed

+11
-162
lines changed

4 files changed

+11
-162
lines changed

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -250,13 +250,6 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
250250
ArrayRef<unsigned> repShape,
251251
ArrayRef<unsigned> paddedRepShape,
252252
ArrayRef<unsigned> order, int swizzleByteSize);
253-
254-
// FIXME
255-
// Exposing to use it in LinearLayoutConversionsTest.cpp
256-
// Remove it once we fully activate the DotOperand conversion via LLs
257-
class DotOperandEncodingAttr;
258-
LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
259-
DotOperandEncodingAttr dot);
260253
} // namespace mlir::triton::gpu
261254

262255
#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,12 +1044,16 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
10441044
return res;
10451045
}
10461046
SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
1047-
auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
1048-
auto warps = distributedLayout.getWarpsPerCTA();
1049-
auto rank = warps.size();
1050-
auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
1051-
warps[kDim] = 1;
1052-
return warps;
1047+
auto parentLayout = getParent();
1048+
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
1049+
if (auto distributedLayout =
1050+
mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
1051+
return distributedLayout.getWarpsPerCTA();
1052+
} else {
1053+
llvm::report_fatal_error(
1054+
"DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
1055+
"supported yet");
1056+
}
10531057
}
10541058
SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
10551059
return ::getWarpOrder(*this);

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
66
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
77
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
8-
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
98
#include "triton/Tools/LinearLayout.h"
109
#include "triton/Tools/StrUtil.h"
1110
#include "llvm/ADT/DenseMap.h"
@@ -823,81 +822,16 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
823822
return ret;
824823
}
825824

826-
LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
827-
DotOperandEncodingAttr dot) {
828-
// TODO,BE. Implement ampereMMA in terms of this one
829-
int rank = shape.size();
830-
auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
831-
int kWidth = dot.getKWidth();
832-
bool isA = dot.getOpIdx() == 0;
833-
834-
assert(mma.isAmpere());
835-
assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
836-
(rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
837-
838-
MLIRContext *ctx = mma.getContext();
839-
SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
840-
841-
// Implement A. For B transpose in the end
842-
std::vector<std::vector<int32_t>> registers;
843-
std::vector<std::vector<int32_t>> lanes;
844-
int32_t i = 1;
845-
// kWidth contiguous elements
846-
while (i < kWidth) {
847-
registers.push_back({i, 0});
848-
i *= 2;
849-
}
850-
// 4 threads per chunk
851-
for (int j = 0; j < 2; j++) {
852-
lanes.push_back({i, 0});
853-
i *= 2;
854-
}
855-
// 8 threads going down
856-
lanes.push_back({0, 1});
857-
lanes.push_back({0, 2});
858-
lanes.push_back({0, 4});
859-
// 2 tiles in column-major order
860-
// Just one if it's the B operand
861-
if (isA) {
862-
registers.push_back({0, 8});
863-
}
864-
registers.push_back({i, 0});
865-
866-
if (!isA) {
867-
for (auto &r : registers) {
868-
std::swap(r[0], r[1]);
869-
}
870-
for (auto &l : lanes) {
871-
std::swap(l[0], l[1]);
872-
}
873-
}
874-
875-
LinearLayout ctaLayout(
876-
{{S("register"), registers}, {S("lane"), lanes}},
877-
llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
878-
879-
auto order = dot.getCTAOrder();
880-
assert(order[0] == 1 && order[1] == 0);
881-
ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
882-
883-
return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
884-
}
885-
886825
std::optional<LinearLayout>
887826
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
827+
888828
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
889829
return dotOperandMfmaToLinearLayout(*this, shape);
890830
}
891831
if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
892832
return dotOperandDpasToLinearLayout(*this, shape);
893833
}
894834

895-
// TODO Activate in a follow-up PR
896-
// else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
897-
// if (mma.isAmpere()) {
898-
// return ampereDotToLinearLayout(shape, *this);
899-
// }
900-
//}
901835
return std::nullopt;
902836
}
903837

unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
55
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
66
#include "triton/Tools/StrUtil.h"
7-
#include "llvm/ADT/ArrayRef.h"
87
#include "llvm/Support/Signals.h"
98
#include <gmock/gmock.h>
109
#include <gtest/gtest.h>
@@ -41,12 +40,6 @@ class LinearLayoutConversionsTest : public ::testing::Test {
4140
CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape);
4241
}
4342

44-
DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef<unsigned> warps,
45-
ArrayRef<unsigned> order) {
46-
auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order);
47-
return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth);
48-
}
49-
5043
AMDMfmaEncodingAttr mfma(ArrayRef<unsigned> warps, unsigned mDim,
5144
unsigned nDim, bool isTransposed) {
5245
SmallVector<unsigned> cpg(warps.size(), 1u);
@@ -501,81 +494,6 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) {
501494
{S("dim0"), S("dim1")}));
502495
}
503496

504-
TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
505-
EXPECT_EQ(ampereDotToLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})),
506-
LinearLayout(
507-
{
508-
{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
509-
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
510-
{S("warp"), {}},
511-
{S("block"), {}},
512-
},
513-
{S("dim0"), S("dim1")}));
514-
EXPECT_EQ(ampereDotToLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})),
515-
LinearLayout(
516-
{
517-
{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
518-
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
519-
{S("warp"), {}},
520-
{S("block"), {}},
521-
},
522-
{S("dim0"), S("dim1")}));
523-
}
524-
525-
TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
526-
EXPECT_EQ(
527-
ampereDotToLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})),
528-
LinearLayout(
529-
{
530-
{S("register"),
531-
{{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}},
532-
{S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
533-
{S("warp"), {{16, 0}, {32, 0}}},
534-
{S("block"), {}},
535-
},
536-
{S("dim0"), S("dim1")}));
537-
EXPECT_EQ(ampereDotToLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
538-
LinearLayout(
539-
{
540-
{S("register"),
541-
{{1, 0},
542-
{2, 0},
543-
{4, 0},
544-
{32, 0},
545-
{0, 8},
546-
{0, 16},
547-
{0, 32},
548-
{64, 0}}},
549-
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
550-
{
551-
S("warp"),
552-
{},
553-
},
554-
{S("block"), {}},
555-
},
556-
{S("dim0"), S("dim1")}));
557-
EXPECT_EQ(ampereDotToLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
558-
LinearLayout(
559-
{
560-
{S("register"),
561-
{{1, 0},
562-
{2, 0},
563-
{4, 0},
564-
{32, 0},
565-
{0, 8},
566-
{0, 16},
567-
{0, 32},
568-
{0, 64}}},
569-
{S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
570-
{
571-
S("warp"),
572-
{},
573-
},
574-
{S("block"), {}},
575-
},
576-
{S("dim0"), S("dim1")}));
577-
}
578-
579497
TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
580498
auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
581499
/*isTransposed=*/false);

0 commit comments

Comments
 (0)