@@ -781,22 +781,24 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
781781
782782 InterfaceMethod<"Return shape per CTA.",
783783 "SmallVector<unsigned>",
784- "getShapePerCTATileForDotOperands ",
784+ "getShapePerCTATileForOperand ",
785785 (ins "ArrayRef<int64_t>":$tensorShape,
786- "unsigned":$opIdx)>,
786+ "int":$kWidth,
787+ "int":$opIdx)>,
787788
788789 InterfaceMethod<"Return total element size per thread for dot operands.",
789790 "unsigned",
790- "getTotalElemsPerThreadForOperands ",
791+ "getTotalElemsPerThreadForOperand ",
791792 (ins "ArrayRef<int64_t>":$tensorShape,
792793 "Type":$eltTy,
793- "unsigned ":$kWidth,
794- "unsigned ":$opIdx)>,
794+ "int ":$kWidth,
795+ "int ":$opIdx)>,
795796
796797 InterfaceMethod<"Return size per thread for dot operands.",
797798 "SmallVector<unsigned>",
798- "getSizePerThreadForOperands",
799- (ins "unsigned":$opIdx)>,
799+ "getSizePerThreadForOperand",
800+ (ins "int":$kWidth, // kWidth first, opIdx second: same order as the attribute-side declarations and the sibling interface methods
801+ "int":$opIdx)>,
800802
801803 InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
802804 "getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
@@ -919,11 +921,11 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
919921 bool supportReduction() const {
920922 return true;
921923 }
922- SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
923- SmallVector<unsigned> getShapePerCTATileForDotOperands (ArrayRef<int64_t> shape, int opIdx) const;
924- unsigned getTotalElemsPerThreadForOperands (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
925- SmallVector<int64_t> getMFMAInstrShapeForOperands (int kWidth, int opIdx) const;
926- SmallVector<int64_t> getMFMARepForOperands (ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
924+ SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
925+ SmallVector<unsigned> getShapePerCTATileForOperand (ArrayRef<int64_t> shape, int kWidth , int opIdx) const;
926+ unsigned getTotalElemsPerThreadForOperand (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
927+ SmallVector<int64_t> getInstrShapeForOperand (int kWidth, int opIdx) const;
928+ SmallVector<int64_t> getRepForOperand (ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
927929
928930 SmallVector<unsigned> getContigPerThread() {
929931 auto rank = getWarpsPerCTA().size();
@@ -1030,12 +1032,12 @@ Row | warp 0 warp 2
10301032 bool supportReduction() const {
10311033 return true;
10321034 }
1033- SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1034- SmallVector<unsigned> getShapePerCTATileForDotOperands (ArrayRef<int64_t> shape, int opIdx) const;
1035- unsigned getTotalElemsPerThreadForOperands (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1035+ SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1036+ SmallVector<unsigned> getShapePerCTATileForOperand (ArrayRef<int64_t> shape, int kWidth , int opIdx) const;
1037+ unsigned getTotalElemsPerThreadForOperand (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
10361038 SmallVector<int64_t> getElemsPerInstrForOperands() const;
1037- SmallVector<int64_t> getRepForOperands (ArrayRef<int64_t> operandShape,
1038- Type elemType, int kWidth, int opIdx) const;
1039+ SmallVector<int64_t> getRepForOperand (ArrayRef<int64_t> operandShape,
1040+ Type elemType, int kWidth, int opIdx) const;
10391041 static SmallVector<unsigned> getMNKDimPerInstr();
10401042
10411043 SmallVector<unsigned> getContigPerThread() {
@@ -1235,18 +1237,18 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12351237 SmallVector<int> getMMAv1Rep(int opIdx) const;
12361238 SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
12371239 int getMMAv1Vec(int opIdx) const;
1238- SmallVector<int64_t> getMMAv2Rep (ArrayRef<int64_t> shape,
1239- int bitwidth, int opIdx) const;
1240+ SmallVector<int64_t> getMMAv2RepForOperand (ArrayRef<int64_t> shape,
1241+ int bitwidth, int kWidth , int opIdx) const;
12401242
12411243 bool supportReduction() const {
12421244 if (isAmpere() || isHopper()) {
12431245 return true;
12441246 }
12451247 return false;
12461248 };
1247- SmallVector<unsigned> getSizePerThreadForOperands(unsigned opIdx) const;
1248- SmallVector<unsigned> getShapePerCTATileForDotOperands (ArrayRef<int64_t> shape, int opIdx) const;
1249- unsigned getTotalElemsPerThreadForOperands (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
1249+ SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1250+ SmallVector<unsigned> getShapePerCTATileForOperand (ArrayRef<int64_t> shape, int kWidth , int opIdx) const;
1251+ unsigned getTotalElemsPerThreadForOperand (ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
12501252
12511253 SmallVector<unsigned> getContigPerThread() {
12521254 assert(isVolta() || isAmpere() || isHopper());
@@ -1361,7 +1363,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
13611363 let genVerifyDecl = 1;
13621364 let extraClassDeclaration = extraDistributedDeclaration # [{
13631365 SmallVector<unsigned> getContigPerThread() {
1364- return getSizePerThread();
1366+ auto rank = getWarpsPerCTA().size(); // result has one entry per warpsPerCTA dimension
1367+ assert(rank == 2 || rank == 3);
1368+ SmallVector<unsigned> contigPerThread(rank, 1); // all entries start at 1; exactly one is overwritten with kWidth below
1369+ auto kWidth = getKWidth();
1370+ assert(kWidth != 0 && "Do not support kWidth=0");
1371+ if (getOpIdx() == 0)
1372+ contigPerThread[rank - 1] = kWidth; // opIdx 0: kWidth goes in the last dimension (presumably K for this operand -- confirm)
1373+ else
1374+ contigPerThread[rank - 2] = kWidth; // any other opIdx: kWidth goes in the second-to-last dimension
1375+ return contigPerThread;
13651376 };
13661377 }];
13671378}
0 commit comments