@@ -781,24 +781,22 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
781781
782782 InterfaceMethod<"Return shape per CTA.",
783783 "SmallVector<unsigned>",
784- "getShapePerCTATileForOperand ",
784+ "getShapePerCTATileForDotOperands ",
785785 (ins "ArrayRef<int64_t>":$tensorShape,
786- "int":$kWidth,
787- "int":$opIdx)>,
786+ "unsigned":$opIdx)>,
788787
789788 InterfaceMethod<"Return total element size per thread for dot operands.",
790789 "unsigned",
791- "getTotalElemsPerThreadForOperand ",
790+ "getTotalElemsPerThreadForOperands ",
792791 (ins "ArrayRef<int64_t>":$tensorShape,
793792 "Type":$eltTy,
794- "int ":$kWidth,
795- "int ":$opIdx)>,
793+ "unsigned ":$kWidth,
794+ "unsigned ":$opIdx)>,
796795
797796 InterfaceMethod<"Return size per thread for dot operands.",
798797 "SmallVector<unsigned>",
799- "getSizePerThreadForOperand",
800- (ins "int":$opIdx,
801- "int":$kWidth)>,
798+ "getSizePerThreadForOperands",
799+ (ins "unsigned":$opIdx)>,
802800
803801 InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
804802 "getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
@@ -921,11 +919,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
921919 bool supportReduction() const {
922920 return true;
923921 }
924- SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
925- SmallVector<unsigned> getShapePerCTATileForOperand (ArrayRef<int64_t> shape, int kWidth , int opIdx) const;
926- unsigned getTotalElemsPerThreadForOperand (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
927- SmallVector<int64_t> getInstrShapeForOperand (int kWidth, int opIdx) const;
928- SmallVector<int64_t> getRepForOperand (ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
922+ SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
923+ SmallVector<unsigned> getShapePerCTATileForDotOperands (ArrayRef<int64_t> shape, int opIdx) const;
924+ unsigned getTotalElemsPerThreadForOperands (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
925+ SmallVector<int64_t> getMFMAInstrShapeForOperands (int kWidth, int opIdx) const;
926+ SmallVector<int64_t> getMFMARepForOperands (ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
929927
930928 SmallVector<unsigned> getContigPerThread() {
931929 auto rank = getWarpsPerCTA().size();
@@ -1032,12 +1030,12 @@ Row | warp 0 warp 2
10321030 bool supportReduction() const {
10331031 return true;
10341032 }
1035- SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1036- SmallVector<unsigned> getShapePerCTATileForOperand (ArrayRef<int64_t> shape, int kWidth , int opIdx) const;
1037- unsigned getTotalElemsPerThreadForOperand (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1033+ SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1034+ SmallVector<unsigned> getShapePerCTATileForDotOperands (ArrayRef<int64_t> shape, int opIdx) const;
1035+ unsigned getTotalElemsPerThreadForOperands (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
10381036 SmallVector<int64_t> getElemsPerInstrForOperands() const;
1039- SmallVector<int64_t> getRepForOperand (ArrayRef<int64_t> operandShape,
1040- Type elemType, int kWidth, int opIdx) const;
1037+ SmallVector<int64_t> getRepForOperands (ArrayRef<int64_t> operandShape,
1038+ Type elemType, int kWidth, int opIdx) const;
10411039 static SmallVector<unsigned> getMNKDimPerInstr();
10421040
10431041 SmallVector<unsigned> getContigPerThread() {
@@ -1237,18 +1235,18 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12371235 SmallVector<int> getMMAv1Rep(int opIdx) const;
12381236 SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
12391237 int getMMAv1Vec(int opIdx) const;
1240- SmallVector<int64_t> getMMAv2RepForOperand (ArrayRef<int64_t> shape,
1241- int bitwidth, int kWidth , int opIdx) const;
1238+ SmallVector<int64_t> getMMAv2Rep (ArrayRef<int64_t> shape,
1239+ int bitwidth, int opIdx) const;
12421240
12431241 bool supportReduction() const {
12441242 if (isAmpere() || isHopper()) {
12451243 return true;
12461244 }
12471245 return false;
12481246 };
1249- SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1250- SmallVector<unsigned> getShapePerCTATileForOperand (ArrayRef<int64_t> shape, int kWidth , int opIdx) const;
1251- unsigned getTotalElemsPerThreadForOperand (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1247+ SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1248+ SmallVector<unsigned> getShapePerCTATileForDotOperands (ArrayRef<int64_t> shape, int opIdx) const;
1249+ unsigned getTotalElemsPerThreadForOperands (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
12521250
12531251 SmallVector<unsigned> getContigPerThread() {
12541252 assert(isVolta() || isAmpere() || isHopper());
@@ -1363,16 +1361,7 @@ elements along the K dim, or they use all elements of the tensor along the K dim
13631361 let genVerifyDecl = 1;
13641362 let extraClassDeclaration = extraDistributedDeclaration # [{
13651363 SmallVector<unsigned> getContigPerThread() {
1366- auto rank = getWarpsPerCTA().size();
1367- assert(rank == 2 || rank == 3);
1368- SmallVector<unsigned> contigPerThread(rank, 1);
1369- auto kWidth = getKWidth();
1370- assert(kWidth != 0 && "Do not support kWidth=0");
1371- if (getOpIdx() == 0)
1372- contigPerThread[rank - 1] = kWidth;
1373- else
1374- contigPerThread[rank - 2] = kWidth;
1375- return contigPerThread;
1364+ return getSizePerThread();
13761365 };
13771366 }];
13781367}
0 commit comments