Skip to content

Commit 66e8629

Browse files
oplavsic and Ognjen Plavsic authored
[AMD] Implement RepOrder for AMD MMA layouts (#5126)
Implement the RepOrder methods for the MFMA and WMMA layouts. Both layouts use a row-major rep order. The isTranspose flag in the MFMA layout does not affect RepOrder, so the rep order is row-major whether or not the flag is set. Co-authored-by: Ognjen Plavsic <[email protected]>
1 parent 0bd30a2 commit 66e8629

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
777777
"getSizePerThreadForOperand",
778778
(ins "int":$opIdx,
779779
"int":$kWidth)>,
780+
781+
InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
782+
"SmallVector<unsigned>",
783+
"getRepOrderForOperand",
784+
(ins "int":$opIdx)>,
780785
];
781786
}
782787

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1658,7 +1658,14 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
16581658
}
16591659

16601660
SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
1661-
llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder");
1661+
auto rank = getWarpsPerCTA().size();
1662+
return getMatrixOrder(rank, /*rowMajor*/ true);
1663+
}
1664+
1665+
SmallVector<unsigned>
1666+
AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
1667+
auto rank = getWarpsPerCTA().size();
1668+
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
16621669
}
16631670

16641671
SmallVector<int64_t>
@@ -1745,8 +1752,16 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
17451752
return shapePerCTATile;
17461753
}
17471754
SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
1748-
llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder");
1755+
auto rank = getWarpsPerCTA().size();
1756+
return getMatrixOrder(rank, /*rowMajor*/ true);
17491757
}
1758+
1759+
SmallVector<unsigned>
1760+
AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
1761+
auto rank = getWarpsPerCTA().size();
1762+
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
1763+
}
1764+
17501765
SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
17511766
return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
17521767
}
@@ -2016,7 +2031,7 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
20162031
// DotOperand Encoding
20172032
//===----------------------------------------------------------------------===//
20182033
SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
2019-
if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
2034+
if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
20202035
return mma.getRepOrderForOperand(getOpIdx());
20212036
}
20222037
llvm::report_fatal_error(

0 commit comments

Comments
 (0)