Skip to content

Commit f9688ab

Browse files
Jokeren and lezcano authored
[BACKEND] Small fixes for dot operand properties (#4895)
Co-authored-by: Mario Lezcano Casado <[email protected]>
1 parent fa229d1 commit f9688ab

File tree

10 files changed

+165
-132
lines changed

10 files changed

+165
-132
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 34 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -781,22 +781,24 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
781781

782782
InterfaceMethod<"Return shape per CTA.",
783783
"SmallVector<unsigned>",
784-
"getShapePerCTATileForDotOperands",
784+
"getShapePerCTATileForOperand",
785785
(ins "ArrayRef<int64_t>":$tensorShape,
786-
"unsigned":$opIdx)>,
786+
"int":$kWidth,
787+
"int":$opIdx)>,
787788

788789
InterfaceMethod<"Return total element size per thread for dot operands.",
789790
"unsigned",
790-
"getTotalElemsPerThreadForOperands",
791+
"getTotalElemsPerThreadForOperand",
791792
(ins "ArrayRef<int64_t>":$tensorShape,
792793
"Type":$eltTy,
793-
"unsigned":$kWidth,
794-
"unsigned":$opIdx)>,
794+
"int":$kWidth,
795+
"int":$opIdx)>,
795796

796797
InterfaceMethod<"Return size per thread for dot operands.",
797798
"SmallVector<unsigned>",
798-
"getSizePerThreadForOperands",
799-
(ins "unsigned":$opIdx)>,
799+
"getSizePerThreadForOperand",
800+
(ins "int":$opIdx,
801+
"int":$kWidth)>,
800802
];
801803
}
802804

@@ -914,11 +916,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
914916
bool supportReduction() const {
915917
return true;
916918
}
917-
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
918-
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
919-
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
920-
SmallVector<int64_t> getMFMAInstrShapeForOperands(int kWidth, int opIdx) const;
921-
SmallVector<int64_t> getMFMARepForOperands(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
919+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
920+
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
921+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
922+
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
923+
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
922924

923925
SmallVector<unsigned> getContigPerThread() {
924926
auto rank = getWarpsPerCTA().size();
@@ -1021,12 +1023,12 @@ Row | warp 0 warp 2
10211023
bool supportReduction() const {
10221024
return true;
10231025
}
1024-
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1025-
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
1026-
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1026+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1027+
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
1028+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
10271029
SmallVector<int64_t> getElemsPerInstrForOperands() const;
1028-
SmallVector<int64_t> getRepForOperands(ArrayRef<int64_t> operandShape,
1029-
Type elemType, int kWidth, int opIdx) const;
1030+
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
1031+
Type elemType, int kWidth, int opIdx) const;
10301032
static SmallVector<unsigned> getMNKDimPerInstr();
10311033

10321034
SmallVector<unsigned> getContigPerThread() {
@@ -1222,18 +1224,18 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12221224
SmallVector<int> getMMAv1Rep(int opIdx) const;
12231225
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
12241226
int getMMAv1Vec(int opIdx) const;
1225-
SmallVector<int64_t> getMMAv2Rep(ArrayRef<int64_t> shape,
1226-
int bitwidth, int opIdx) const;
1227+
SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
1228+
int bitwidth, int kWidth, int opIdx) const;
12271229

12281230
bool supportReduction() const {
12291231
if (isAmpere() || isHopper()) {
12301232
return true;
12311233
}
12321234
return false;
12331235
};
1234-
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1235-
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
1236-
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1236+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1237+
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
1238+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
12371239

12381240
SmallVector<unsigned> getContigPerThread() {
12391241
assert(isVolta() || isAmpere() || isHopper());
@@ -1344,7 +1346,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
13441346
let genVerifyDecl = 1;
13451347
let extraClassDeclaration = extraDistributedDeclaration # [{
13461348
SmallVector<unsigned> getContigPerThread() {
1347-
return getSizePerThread();
1349+
auto rank = getWarpsPerCTA().size();
1350+
assert(rank == 2 || rank == 3);
1351+
SmallVector<unsigned> contigPerThread(rank, 1);
1352+
auto kWidth = getKWidth();
1353+
assert(kWidth != 0 && "Do not support kWidth=0");
1354+
if (getOpIdx() == 0)
1355+
contigPerThread[rank - 1] = kWidth;
1356+
else
1357+
contigPerThread[rank - 2] = kWidth;
1358+
return contigPerThread;
13481359
};
13491360
}];
13501361
}

0 commit comments

Comments
 (0)