Skip to content

Commit 6130c2b

Browse files
authored
[BACKEND][NVIDIA] Remove NvidiaMma::getTotalElemsPerThreadForOperand (#5105)
Fixes triton-lang/triton#5102. The logic in `getTotalElemsPerThreadForOperand` should now directly match that in `SharedToDotOperandMMAv2OrV3`.
1 parent 7873637 commit 6130c2b

File tree

2 files changed

+13
-34
lines changed

2 files changed

+13
-34
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -772,14 +772,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
772772
"int":$kWidth,
773773
"int":$opIdx)>,
774774

775-
InterfaceMethod<"Return total element size per thread for dot operands.",
776-
"unsigned",
777-
"getTotalElemsPerThreadForOperand",
778-
(ins "ArrayRef<int64_t>":$tensorShape,
779-
"Type":$eltTy,
780-
"int":$kWidth,
781-
"int":$opIdx)>,
782-
783775
InterfaceMethod<"Return size per thread for dot operands.",
784776
"SmallVector<unsigned>",
785777
"getSizePerThreadForOperand",
@@ -1143,7 +1135,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
11431135
};
11441136
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
11451137
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
1146-
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
11471138

11481139
SmallVector<unsigned> getContigPerThread() {
11491140
assert(isAmpere() || isHopper());

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -938,11 +938,11 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
938938
elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
939939
return elemsPerThread;
940940
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
941-
if (mma.isAmpere()) {
941+
if (mma.isAmpere() || mma.isHopper()) {
942942
auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
943943
auto rep = mma.getRepForOperand(shape, bitwidth, idx);
944944
auto sizePerThread = getSizePerThread();
945-
auto elemsPerKRep = 32 / bitwidth * 2;
945+
auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
946946
if (rank == 3)
947947
elemsPerThread[0] = rep[0];
948948
elemsPerThread[rank - 2] =
@@ -964,12 +964,18 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
964964
unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
965965
Type eltTy) const {
966966
if (auto mmaParent = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
967-
if (auto nvidiaMmaParent = mlir::dyn_cast<NvidiaMmaEncodingAttr>(mmaParent);
968-
nvidiaMmaParent && nvidiaMmaParent.isAmpere()) {
967+
if (auto nvidiaMmaParent =
968+
mlir::dyn_cast<NvidiaMmaEncodingAttr>(mmaParent)) {
969969
return product<unsigned>(getElemsPerThread(shape, eltTy));
970970
}
971-
return mmaParent.getTotalElemsPerThreadForOperand(shape, eltTy, getKWidth(),
972-
getOpIdx());
971+
if (auto amdMfmaParent = mlir::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
972+
return amdMfmaParent.getTotalElemsPerThreadForOperand(
973+
shape, eltTy, getKWidth(), getOpIdx());
974+
}
975+
if (auto amdWmmaParent = mlir::dyn_cast<AMDWmmaEncodingAttr>(getParent())) {
976+
return amdWmmaParent.getTotalElemsPerThreadForOperand(
977+
shape, eltTy, getKWidth(), getOpIdx());
978+
}
973979
}
974980
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
975981
auto shapePerCTA = getShapePerCTA(*this, shape);
@@ -1981,26 +1987,9 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
19811987
}
19821988
}
19831989

1984-
unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
1985-
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
1986-
auto shapePerCTA = getShapePerCTA(*this, shape);
1987-
int warpsPerCTAM = getWarpsPerCTA()[0];
1988-
int warpsPerCTAN = getWarpsPerCTA()[1];
1989-
// H100
1990-
if (isHopper()) {
1991-
assert(opIdx == 0);
1992-
auto instrMNK = getInstrShape();
1993-
int repM = ceil<unsigned>(shapePerCTA[0], instrMNK[0] * warpsPerCTAM);
1994-
int repK = ceil<unsigned>(shapePerCTA[1], instrMNK[2]);
1995-
// For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds
1996-
// kWidth elements for each quadrant. WGMMA is repeated repM * repK times.
1997-
return 4 * kWidth * repM * repK;
1998-
}
1999-
llvm_unreachable("unknown mma layout");
2000-
}
20011990
SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
20021991
ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
2003-
assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
1992+
assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
20041993
auto shapePerCTATile = getShapePerCTATile(shape);
20051994
auto rank = shapePerCTATile.size();
20061995
auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
@@ -2010,7 +1999,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
20101999
}
20112000
SmallVector<unsigned>
20122001
NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
2013-
assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
20142002
auto rank = getWarpsPerCTA().size();
20152003
auto sizePerThread = SmallVector<unsigned>(rank, 1);
20162004
if (opIdx == 0) {

0 commit comments

Comments (0)