Skip to content

Commit 076a001

Browse files
Merge commit 'f9688abe3d3d9caea7846ce41d5fa1da765f5e16'
2 parents 24e53d2 + f9688ab commit 076a001

File tree

10 files changed

+165
-132
lines changed

10 files changed

+165
-132
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -781,22 +781,24 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
781781

782782
InterfaceMethod<"Return shape per CTA.",
783783
"SmallVector<unsigned>",
784-
"getShapePerCTATileForDotOperands",
784+
"getShapePerCTATileForOperand",
785785
(ins "ArrayRef<int64_t>":$tensorShape,
786-
"unsigned":$opIdx)>,
786+
"int":$kWidth,
787+
"int":$opIdx)>,
787788

788789
InterfaceMethod<"Return total element size per thread for dot operands.",
789790
"unsigned",
790-
"getTotalElemsPerThreadForOperands",
791+
"getTotalElemsPerThreadForOperand",
791792
(ins "ArrayRef<int64_t>":$tensorShape,
792793
"Type":$eltTy,
793-
"unsigned":$kWidth,
794-
"unsigned":$opIdx)>,
794+
"int":$kWidth,
795+
"int":$opIdx)>,
795796

796797
InterfaceMethod<"Return size per thread for dot operands.",
797798
"SmallVector<unsigned>",
798-
"getSizePerThreadForOperands",
799-
(ins "unsigned":$opIdx)>,
799+
"getSizePerThreadForOperand",
800+
(ins "int":$opIdx,
801+
"int":$kWidth)>,
800802

801803
InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
802804
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
@@ -919,11 +921,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
919921
bool supportReduction() const {
920922
return true;
921923
}
922-
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
923-
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
924-
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
925-
SmallVector<int64_t> getMFMAInstrShapeForOperands(int kWidth, int opIdx) const;
926-
SmallVector<int64_t> getMFMARepForOperands(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
924+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
925+
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
926+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
927+
SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
928+
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
927929

928930
SmallVector<unsigned> getContigPerThread() {
929931
auto rank = getWarpsPerCTA().size();
@@ -1030,12 +1032,12 @@ Row | warp 0 warp 2
10301032
bool supportReduction() const {
10311033
return true;
10321034
}
1033-
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1034-
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
1035-
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1035+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1036+
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
1037+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
10361038
SmallVector<int64_t> getElemsPerInstrForOperands() const;
1037-
SmallVector<int64_t> getRepForOperands(ArrayRef<int64_t> operandShape,
1038-
Type elemType, int kWidth, int opIdx) const;
1039+
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
1040+
Type elemType, int kWidth, int opIdx) const;
10391041
static SmallVector<unsigned> getMNKDimPerInstr();
10401042

10411043
SmallVector<unsigned> getContigPerThread() {
@@ -1235,18 +1237,18 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12351237
SmallVector<int> getMMAv1Rep(int opIdx) const;
12361238
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
12371239
int getMMAv1Vec(int opIdx) const;
1238-
SmallVector<int64_t> getMMAv2Rep(ArrayRef<int64_t> shape,
1239-
int bitwidth, int opIdx) const;
1240+
SmallVector<int64_t> getMMAv2RepForOperand(ArrayRef<int64_t> shape,
1241+
int bitwidth, int kWidth, int opIdx) const;
12401242

12411243
bool supportReduction() const {
12421244
if (isAmpere() || isHopper()) {
12431245
return true;
12441246
}
12451247
return false;
12461248
};
1247-
SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1248-
SmallVector<unsigned> getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape, int opIdx) const;
1249-
unsigned getTotalElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1249+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1250+
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
1251+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
12501252

12511253
SmallVector<unsigned> getContigPerThread() {
12521254
assert(isVolta() || isAmpere() || isHopper());
@@ -1361,7 +1363,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
13611363
let genVerifyDecl = 1;
13621364
let extraClassDeclaration = extraDistributedDeclaration # [{
13631365
SmallVector<unsigned> getContigPerThread() {
1364-
return getSizePerThread();
1366+
auto rank = getWarpsPerCTA().size();
1367+
assert(rank == 2 || rank == 3);
1368+
SmallVector<unsigned> contigPerThread(rank, 1);
1369+
auto kWidth = getKWidth();
1370+
assert(kWidth != 0 && "Do not support kWidth=0");
1371+
if (getOpIdx() == 0)
1372+
contigPerThread[rank - 1] = kWidth;
1373+
else
1374+
contigPerThread[rank - 2] = kWidth;
1375+
return contigPerThread;
13651376
};
13661377
}];
13671378
}

0 commit comments

Comments
 (0)