
Commit 6d95f24

Merge commit 'ec0bd4ac8d5d35896193393e19614c3a7dade5ae'
2 parents: 25a7cba + ec0bd4a

File tree: 6 files changed, +168 −16 lines

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 5 additions & 4 deletions
@@ -88,10 +88,11 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
       // encoding not available
       return resultVals;
     Attribute baseEncoding = encoding;
-    if (isa<AMDMfmaEncodingAttr>(baseEncoding))
-      // TODO: this logic seems incorrect for mfma layout. Skip for now.
-      // We saw mismatches for some flash-attention tests on AMD backend.
-      // Note that this logic works for sliced layout whose parent is
+    if (isa<AMDMfmaEncodingAttr>(baseEncoding) ||
+        isa<AMDWmmaEncodingAttr>(baseEncoding))
+      // TODO: this logic seems incorrect for mfma and wmma layout. Skip for
+      // now. We saw mismatches for some flash-attention and dot tests on AMD
+      // backend. Note that this logic works for sliced layout whose parent is
       // mfma layout. Therefore, this is not combined with the following check.
       return resultVals;
     while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 7 additions & 0 deletions
@@ -250,6 +250,13 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
                      ArrayRef<unsigned> repShape,
                      ArrayRef<unsigned> paddedRepShape,
                      ArrayRef<unsigned> order, int swizzleByteSize);
+
+// FIXME
+// Exposing to use it in LinearLayoutConversionsTest.cpp
+// Remove it once we fully activate the DotOperand conversion via LLs
+class DotOperandEncodingAttr;
+LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
+                                     DotOperandEncodingAttr dot);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 10 deletions
@@ -1044,16 +1044,12 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
   return res;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto distributedLayout =
-          mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
-    return distributedLayout.getWarpsPerCTA();
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
-        "supported yet");
-  }
+  auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
+  auto warps = distributedLayout.getWarpsPerCTA();
+  auto rank = warps.size();
+  auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
+  warps[kDim] = 1;
+  return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
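In words: a dot operand is not partitioned across warps along K, since every warp in the K direction reads the same slice of the operand, so the parent layout's warp count is clamped to 1 there. A minimal standalone sketch of the new rule, using plain STL types instead of MLIR attributes (dotOperandWarpsPerCTA is a hypothetical name, not a Triton API):

#include <cassert>
#include <cstdio>
#include <vector>

// K is the last dim for operand A (opIdx 0) and the second-to-last dim for
// operand B (opIdx 1); the warp count along K is clamped to 1.
std::vector<unsigned> dotOperandWarpsPerCTA(std::vector<unsigned> warps,
                                            int opIdx) {
  int rank = warps.size();
  assert(rank >= 2 && (opIdx == 0 || opIdx == 1));
  warps[opIdx == 0 ? rank - 1 : rank - 2] = 1;
  return warps;
}

int main() {
  // Parent MMA layout with warpsPerCTA = {4, 2}:
  auto a = dotOperandWarpsPerCTA({4, 2}, /*opIdx=*/0); // A -> {4, 1}
  auto b = dotOperandWarpsPerCTA({4, 2}, /*opIdx=*/1); // B -> {1, 2}
  std::printf("A: {%u, %u}  B: {%u, %u}\n", a[0], a[1], b[0], b[1]);
}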

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 67 additions & 1 deletion
@@ -5,6 +5,7 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/DenseMap.h"
@@ -822,16 +823,81 @@ SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return ret;
 }

+LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
+                                     DotOperandEncodingAttr dot) {
+  // TODO,BE. Implement ampereMMA in terms of this one
+  int rank = shape.size();
+  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
+  int kWidth = dot.getKWidth();
+  bool isA = dot.getOpIdx() == 0;
+
+  assert(mma.isAmpere());
+  assert((rank == 2 && mma.getInstrShape() == ArrayRef<unsigned>({16, 8})) ||
+         (rank == 3 && mma.getInstrShape() == ArrayRef<unsigned>({1, 16, 8})));
+
+  MLIRContext *ctx = mma.getContext();
+  SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);
+
+  // Implement A. For B transpose in the end
+  std::vector<std::vector<int32_t>> registers;
+  std::vector<std::vector<int32_t>> lanes;
+  int32_t i = 1;
+  // kWidth contiguous elements
+  while (i < kWidth) {
+    registers.push_back({i, 0});
+    i *= 2;
+  }
+  // 4 threads per chunk
+  for (int j = 0; j < 2; j++) {
+    lanes.push_back({i, 0});
+    i *= 2;
+  }
+  // 8 threads going down
+  lanes.push_back({0, 1});
+  lanes.push_back({0, 2});
+  lanes.push_back({0, 4});
+  // 2 tiles in column-major order
+  // Just one if it's the B operand
+  if (isA) {
+    registers.push_back({0, 8});
+  }
+  registers.push_back({i, 0});
+
+  if (!isA) {
+    for (auto &r : registers) {
+      std::swap(r[0], r[1]);
+    }
+    for (auto &l : lanes) {
+      std::swap(l[0], l[1]);
+    }
+  }
+
+  LinearLayout ctaLayout(
+      {{S("register"), registers}, {S("lane"), lanes}},
+      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
+
+  auto order = dot.getCTAOrder();
+  assert(order[0] == 1 && order[1] == 0);
+  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
+}
+
 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
-
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
   }
   if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
     return dotOperandDpasToLinearLayout(*this, shape);
   }

+  // TODO Activate in a follow-up PR
+  // else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+  //   if (mma.isAmpere()) {
+  //     return ampereDotToLinearLayout(shape, *this);
+  //   }
+  //}
   return std::nullopt;
 }
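A note on reading the layout this function builds: in a linear layout, an input index maps to an output coordinate by XOR-ing together the basis vectors selected by its set bits. The sketch below models that rule with plain STL types (it is not Triton's LinearLayout class) and replays the operand-A, kWidth=8 bases that the unit tests below expect for shape {16, 64}; the interpretation comments follow the comments in ampereDotToLinearLayout itself:

#include <array>
#include <cstdio>
#include <vector>

using Vec2 = std::array<int, 2>;

// XOR together the basis vectors selected by the set bits of `index`.
Vec2 apply(const std::vector<Vec2> &bases, unsigned index) {
  Vec2 out = {0, 0};
  for (size_t b = 0; b < bases.size(); ++b)
    if ((index >> b) & 1) {
      out[0] ^= bases[b][0];
      out[1] ^= bases[b][1];
    }
  return out;
}

int main() {
  // Operand A, kWidth = 8, shape {16, 64}, in (dim0 = row, dim1 = K) order.
  // Register bits 0-2 walk the kWidth contiguous K elements, bit 3 jumps to
  // the second 8-row half of the 16-row tile, bit 4 to the next K repetition;
  // lane bits 0-1 select the position along K, bits 2-4 the row.
  std::vector<Vec2> regs = {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}};
  std::vector<Vec2> lanes = {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}};
  for (unsigned lane : {0u, 5u})
    for (unsigned reg : {0u, 3u, 8u}) {
      Vec2 l = apply(lanes, lane), r = apply(regs, reg);
      std::printf("lane %u reg %u -> (%d, %d)\n", lane, reg, l[0] ^ r[0],
                  l[1] ^ r[1]);
    }
}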

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def __post_init__(self):
         default_libdir = Path(__file__).parent / 'lib'
         extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
         # Ignore user-defined warp size for gfx9
-        warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch else 64
+        warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
         object.__setattr__(self, 'warp_size', warp_size)
         libs = ["ocml", "ockl"]
         for lib in libs:
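The rule this one-liner encodes, restated as a pure function: gfx10 and gfx11 already selected a 32-wide warp, this commit adds gfx12 to that set, and every other target (e.g. gfx9) keeps 64. A hedged C++ restatement (warpSizeForArch is an illustrative name, not part of the backend):

#include <cstdio>
#include <string>

// Mirrors the compiler.py selection: wave32 for gfx10/gfx11/gfx12, else 64.
int warpSizeForArch(const std::string &arch) {
  bool wave32 = arch.find("gfx10") != std::string::npos ||
                arch.find("gfx11") != std::string::npos ||
                arch.find("gfx12") != std::string::npos;
  return wave32 ? 32 : 64;
}

int main() {
  std::printf("%d %d\n", warpSizeForArch("gfx1200"), // 32 (new with gfx12)
              warpSizeForArch("gfx90a"));            // 64
}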

unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp

Lines changed: 82 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Tools/StrUtil.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/Signals.h"
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -40,6 +41,12 @@ class LinearLayoutConversionsTest : public ::testing::Test {
         CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape);
   }

+  DotOperandEncodingAttr dotMMAv2(int idx, int kWidth, ArrayRef<unsigned> warps,
+                                  ArrayRef<unsigned> order) {
+    auto mmaLayout = mma(2, 0, {16, 8}, warps, {1, 1}, {1, 1}, order);
+    return DotOperandEncodingAttr::get(&ctx, idx, mmaLayout, /*kWidth=*/kWidth);
+  }
+
   AMDMfmaEncodingAttr mfma(ArrayRef<unsigned> warps, unsigned mDim,
                            unsigned nDim, bool isTransposed) {
     SmallVector<unsigned> cpg(warps.size(), 1u);
@@ -494,6 +501,81 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) {
                         {S("dim0"), S("dim1")}));
 }

+TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
+  EXPECT_EQ(ampereDotToLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
+                    {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+                    {S("warp"), {}},
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(ampereDotToLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {S("warp"), {}},
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+}
+
+TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
+  EXPECT_EQ(
+      ampereDotToLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})),
+      LinearLayout(
+          {
+              {S("register"),
+               {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {64, 0}}},
+              {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+              {S("warp"), {{16, 0}, {32, 0}}},
+              {S("block"), {}},
+          },
+          {S("dim0"), S("dim1")}));
+  EXPECT_EQ(ampereDotToLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"),
+                     {{1, 0},
+                      {2, 0},
+                      {4, 0},
+                      {32, 0},
+                      {0, 8},
+                      {0, 16},
+                      {0, 32},
+                      {64, 0}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {
+                        S("warp"),
+                        {},
+                    },
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(ampereDotToLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+            LinearLayout(
+                {
+                    {S("register"),
+                     {{1, 0},
+                      {2, 0},
+                      {4, 0},
+                      {32, 0},
+                      {0, 8},
+                      {0, 16},
+                      {0, 32},
+                      {0, 64}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {
+                        S("warp"),
+                        {},
+                    },
+                    {S("block"), {}},
+                },
+                {S("dim0"), S("dim1")}));
+}
+
 TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
   auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
                      /*isTransposed=*/false);
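A reading note on the {4, 1}-warp expectations above: for operand A (opIdx 0) the K dimension is dim1, so the parent's {4, 1} warps survive and contribute the bases {16, 0} and {32, 0}, two warp bits striding the 16-row tile; for operand B (opIdx 1) the K dimension is dim0, so the new DotOperandEncodingAttr::getWarpsPerCTA() clamps {4, 1} to {1, 1} and the warp bases come out empty. A small sketch of that clamp-then-stride derivation (warpBases is a modeled helper, not Triton's identityND):

#include <array>
#include <cstdio>
#include <vector>

// Clamp the parent's warpsPerCTA along K, then give each remaining warp bit
// a stride equal to the warp-tile extent in its dimension (rank-2 case).
std::vector<std::array<int, 2>> warpBases(std::array<int, 2> warps, int opIdx,
                                          std::array<int, 2> warpTile) {
  warps[opIdx == 0 ? 1 : 0] = 1; // K is dim1 for A, dim0 for B
  std::vector<std::array<int, 2>> bases;
  for (int dim : {0, 1})
    for (int stride = warpTile[dim]; stride < warpTile[dim] * warps[dim];
         stride *= 2)
      bases.push_back(dim == 0 ? std::array<int, 2>{stride, 0}
                               : std::array<int, 2>{0, stride});
  return bases;
}

int main() {
  // Operand A, warps {4, 1}, 16-row warp tile: prints {16, 0} {32, 0},
  // matching DotMMAv2_large_warp4_kwidth8.
  for (auto &b : warpBases({4, 1}, /*opIdx=*/0, {16, 64}))
    std::printf("{%d, %d} ", b[0], b[1]);
  // Operand B: K is dim0, so {4, 1} clamps to {1, 1}; no warp bases remain.
  std::printf("\nB warp bases: %zu\n", warpBases({4, 1}, 1, {64, 8}).size());
}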
