@@ -781,22 +781,24 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
781781
782782 InterfaceMethod<"Return shape per CTA.",
783783 "SmallVector<unsigned>",
784- "getShapePerCTATileForDotOperands ",
784+ "getShapePerCTATileForOperand ",
785785 (ins "ArrayRef<int64_t>":$tensorShape,
786- "unsigned":$opIdx)>,
786+ "int":$kWidth,
787+ "int":$opIdx)>,
787788
788789 InterfaceMethod<"Return total element size per thread for dot operands.",
789790 "unsigned",
790- "getTotalElemsPerThreadForOperands ",
791+ "getTotalElemsPerThreadForOperand ",
791792 (ins "ArrayRef<int64_t>":$tensorShape,
792793 "Type":$eltTy,
793- "unsigned ":$kWidth,
794- "unsigned ":$opIdx)>,
794+ "int ":$kWidth,
795+ "int ":$opIdx)>,
795796
796797 InterfaceMethod<"Return size per thread for dot operands.",
797798 "SmallVector<unsigned>",
798- "getSizePerThreadForOperands",
799- (ins "unsigned":$opIdx)>,
799+ "getSizePerThreadForOperand",
800+ (ins "int":$opIdx,
801+ "int":$kWidth)>,
800802 ];
801803}
802804
@@ -914,11 +916,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
914916 bool supportReduction() const {
915917 return true;
916918 }
917- SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
918- SmallVector<unsigned> getShapePerCTATileForDotOperands (ArrayRef<int64_t> shape, int opIdx) const;
919- unsigned getTotalElemsPerThreadForOperands (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
920- SmallVector<int64_t> getMFMAInstrShapeForOperands (int kWidth, int opIdx) const;
921- SmallVector<int64_t> getMFMARepForOperands (ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
919+ SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
920+ SmallVector<unsigned> getShapePerCTATileForOperand (ArrayRef<int64_t> shape, int kWidth , int opIdx) const;
921+ unsigned getTotalElemsPerThreadForOperand (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
922+ SmallVector<int64_t> getInstrShapeForOperand (int kWidth, int opIdx) const;
923+ SmallVector<int64_t> getRepForOperand (ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
922924
923925 SmallVector<unsigned> getContigPerThread() {
924926 auto rank = getWarpsPerCTA().size();
@@ -1021,12 +1023,12 @@ Row | warp 0 warp 2
10211023 bool supportReduction() const {
10221024 return true;
10231025 }
1024- SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1025- SmallVector<unsigned> getShapePerCTATileForDotOperands (ArrayRef<int64_t> shape, int opIdx) const;
1026- unsigned getTotalElemsPerThreadForOperands (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1026+ SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1027+ SmallVector<unsigned> getShapePerCTATileForOperand (ArrayRef<int64_t> shape, int kWidth , int opIdx) const;
1028+ unsigned getTotalElemsPerThreadForOperand (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
10271029 SmallVector<int64_t> getElemsPerInstrForOperands() const;
1028- SmallVector<int64_t> getRepForOperands (ArrayRef<int64_t> operandShape,
1029- Type elemType, int kWidth, int opIdx) const;
1030+ SmallVector<int64_t> getRepForOperand (ArrayRef<int64_t> operandShape,
1031+ Type elemType, int kWidth, int opIdx) const;
10301032 static SmallVector<unsigned> getMNKDimPerInstr();
10311033
10321034 SmallVector<unsigned> getContigPerThread() {
@@ -1222,18 +1224,18 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12221224 SmallVector<int> getMMAv1Rep(int opIdx) const;
12231225 SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
12241226 int getMMAv1Vec(int opIdx) const;
1225- SmallVector<int64_t> getMMAv2Rep (ArrayRef<int64_t> shape,
1226- int bitwidth, int opIdx) const;
1227+ SmallVector<int64_t> getMMAv2RepForOperand (ArrayRef<int64_t> shape,
1228+ int bitwidth, int kWidth , int opIdx) const;
12271229
12281230 bool supportReduction() const {
12291231 if (isAmpere() || isHopper()) {
12301232 return true;
12311233 }
12321234 return false;
12331235 };
1234- SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1235- SmallVector<unsigned> getShapePerCTATileForDotOperands (ArrayRef<int64_t> shape, int opIdx) const;
1236- unsigned getTotalElemsPerThreadForOperands (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1236+ SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1237+ SmallVector<unsigned> getShapePerCTATileForOperand (ArrayRef<int64_t> shape, int kWidth , int opIdx) const;
1238+ unsigned getTotalElemsPerThreadForOperand (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
12371239
12381240 SmallVector<unsigned> getContigPerThread() {
12391241 assert(isVolta() || isAmpere() || isHopper());
@@ -1344,7 +1346,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
13441346 let genVerifyDecl = 1;
13451347 let extraClassDeclaration = extraDistributedDeclaration # [{
13461348 SmallVector<unsigned> getContigPerThread() {
1347- return getSizePerThread();
1349+ auto rank = getWarpsPerCTA().size();
1350+ assert(rank == 2 || rank == 3);
1351+ SmallVector<unsigned> contigPerThread(rank, 1);
1352+ auto kWidth = getKWidth();
1353+ assert(kWidth != 0 && "Do not support kWidth=0");
1354+ if (getOpIdx() == 0)
1355+ contigPerThread[rank - 1] = kWidth;
1356+ else
1357+ contigPerThread[rank - 2] = kWidth;
1358+ return contigPerThread;
13481359 };
13491360 }];
13501361}
0 commit comments