Skip to content

Commit 25a7cba

Browse files
Revert "[BACKEND] Small fixes for dot operand properties (#4895)"
This reverts commit f9688ab.
1 parent: 076a001 · commit: 25a7cba

File tree

10 files changed

+132
-165
lines changed

10 files changed

+132
-165
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 23 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -781,24 +781,22 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
781781

782782
InterfaceMethod<"Return shape per CTA.",
783783
"SmallVector<unsigned>",
784-
"getShapePerCTATileForOperand",
784+
"getShapePerCTATileForDotOperands",
785785
(ins "ArrayRef<int64_t>":$tensorShape,
786-
"int":$kWidth,
787-
"int":$opIdx)>,
786+
"unsigned":$opIdx)>,
788787

789788
InterfaceMethod<"Return total element size per thread for dot operands.",
790789
"unsigned",
791-
"getTotalElemsPerThreadForOperand",
790+
"getTotalElemsPerThreadForOperands",
792791
(ins "ArrayRef<int64_t>":$tensorShape,
793792
"Type":$eltTy,
794-
"int":$kWidth,
795-
"int":$opIdx)>,
793+
"unsigned":$kWidth,
794+
"unsigned":$opIdx)>,
796795

797796
InterfaceMethod<"Return size per thread for dot operands.",
798797
"SmallVector<unsigned>",
799-
"getSizePerThreadForOperand",
800-
(ins "int":$opIdx,
801-
"int":$kWidth)>,
798+
"getSizePerThreadForOperands",
799+
(ins "unsigned":$opIdx)>,
802800

803801
InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
804802
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
@@ -921,11 +919,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
921919
bool supportReduction() const {
922920
return true;
923921
}
924-
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
925-
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
926-
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
927-
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
928-
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
922+
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
923+
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
924+
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
925+
SmallVector<int64_t> getMFMAInstrShapeForOperands(int kWidth, int opIdx) const;
926+
SmallVector<int64_t> getMFMARepForOperands(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
929927

930928
SmallVector<unsigned> getContigPerThread() {
931929
auto rank = getWarpsPerCTA().size();
@@ -1032,12 +1030,12 @@ Row | warp 0 warp 2
10321030
bool supportReduction() const {
10331031
return true;
10341032
}
1035-
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1036-
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
1037-
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1033+
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1034+
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
1035+
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
10381036
SmallVector<int64_t> getElemsPerInstrForOperands() const;
1039-
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
1040-
Type elemType, int kWidth, int opIdx) const;
1037+
SmallVector<int64_t> getRepForOperands(ArrayRef<int64_t> operandShape,
1038+
Type elemType, int kWidth, int opIdx) const;
10411039
static SmallVector<unsigned> getMNKDimPerInstr();
10421040

10431041
SmallVector<unsigned> getContigPerThread() {
@@ -1237,18 +1235,18 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12371235
SmallVector<int> getMMAv1Rep(int opIdx) const;
12381236
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
12391237
int getMMAv1Vec(int opIdx) const;
1240-
SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
1241-
int bitwidth, int kWidth, int opIdx) const;
1238+
SmallVector<int64_t> getMMAv2Rep(ArrayRef<int64_t> shape,
1239+
int bitwidth, int opIdx) const;
12421240

12431241
bool supportReduction() const {
12441242
if (isAmpere() || isHopper()) {
12451243
return true;
12461244
}
12471245
return false;
12481246
};
1249-
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1250-
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
1251-
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1247+
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1248+
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
1249+
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
12521250

12531251
SmallVector<unsigned> getContigPerThread() {
12541252
assert(isVolta() || isAmpere() || isHopper());
@@ -1363,16 +1361,7 @@ elements along the K dim, or they use all elements of the tensor along the K dim
13631361
let genVerifyDecl = 1;
13641362
let extraClassDeclaration = extraDistributedDeclaration # [{
13651363
SmallVector<unsigned> getContigPerThread() {
1366-
auto rank = getWarpsPerCTA().size();
1367-
assert(rank == 2 || rank == 3);
1368-
SmallVector<unsigned> contigPerThread(rank, 1);
1369-
auto kWidth = getKWidth();
1370-
assert(kWidth != 0 && "Do not support kWidth=0");
1371-
if (getOpIdx() == 0)
1372-
contigPerThread[rank - 1] = kWidth;
1373-
else
1374-
contigPerThread[rank - 2] = kWidth;
1375-
return contigPerThread;
1364+
return getSizePerThread();
13761365
};
13771366
}];
13781367
}

0 commit comments

Comments
 (0)