@@ -938,11 +938,11 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
     elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
     return elemsPerThread;
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-    if (mma.isAmpere()) {
+    if (mma.isAmpere() || mma.isHopper()) {
       auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
       auto rep = mma.getRepForOperand(shape, bitwidth, idx);
       auto sizePerThread = getSizePerThread();
-      auto elemsPerKRep = 32 / bitwidth * 2;
+      auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
       if (rank == 3)
         elemsPerThread[0] = rep[0];
       elemsPerThread[rank - 2] =
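
A minimal standalone sketch of the two elemsPerKRep paths in the hunk above; the bitwidth, kWidth, and architecture flag below are illustrative assumptions, not values taken from the patch:

#include <cstdio>

// Hedged sketch: reproduces only the elemsPerKRep arithmetic from the hunk
// above. The bitwidth/kWidth values are illustrative assumptions.
int main() {
  int bitwidth = 16;    // e.g. an fp16 operand
  int kWidth = 8;       // hypothetical kWidth for the Hopper path
  bool isHopper = true; // flip to false to take the Ampere path

  int elemsPerKRep = isHopper ? (kWidth * 2)         // Hopper: derived from kWidth
                              : (32 / bitwidth * 2); // Ampere: derived from bitwidth
  std::printf("elemsPerKRep = %d\n", elemsPerKRep);  // 16 for the values above
  return 0;
}
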
@@ -964,12 +964,18 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
 unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
                                                         Type eltTy) const {
   if (auto mmaParent = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
-    if (auto nvidiaMmaParent = mlir::dyn_cast<NvidiaMmaEncodingAttr>(mmaParent);
-        nvidiaMmaParent && nvidiaMmaParent.isAmpere()) {
+    if (auto nvidiaMmaParent =
+            mlir::dyn_cast<NvidiaMmaEncodingAttr>(mmaParent)) {
       return product<unsigned>(getElemsPerThread(shape, eltTy));
     }
-    return mmaParent.getTotalElemsPerThreadForOperand(shape, eltTy, getKWidth(),
-                                                      getOpIdx());
+    if (auto amdMfmaParent = mlir::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
+      return amdMfmaParent.getTotalElemsPerThreadForOperand(
+          shape, eltTy, getKWidth(), getOpIdx());
+    }
+    if (auto amdWmmaParent = mlir::dyn_cast<AMDWmmaEncodingAttr>(getParent())) {
+      return amdWmmaParent.getTotalElemsPerThreadForOperand(
+          shape, eltTy, getKWidth(), getOpIdx());
+    }
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
@@ -1981,26 +1987,9 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
   }
 }
 
-unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
-    ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
-  auto shapePerCTA = getShapePerCTA(*this, shape);
-  int warpsPerCTAM = getWarpsPerCTA()[0];
-  int warpsPerCTAN = getWarpsPerCTA()[1];
-  // H100
-  if (isHopper()) {
-    assert(opIdx == 0);
-    auto instrMNK = getInstrShape();
-    int repM = ceil<unsigned>(shapePerCTA[0], instrMNK[0] * warpsPerCTAM);
-    int repK = ceil<unsigned>(shapePerCTA[1], instrMNK[2]);
-    // For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds
-    // kWidth elements for each quadrant. WGMMA is repeated repM * repK times.
-    return 4 * kWidth * repM * repK;
-  }
-  llvm_unreachable("unknown mma layout");
-}
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
     ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
-  assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
+  assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
   auto shapePerCTATile = getShapePerCTATile(shape);
   auto rank = shapePerCTATile.size();
   auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
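
The deleted Hopper helper above returned 4 * kWidth * repM * repK (a 2x2 fragment per WGMMA, kWidth elements per quadrant per thread, repeated repM * repK times); a small numeric sketch of that formula, with all shapes and layout parameters assumed for illustration:

#include <cstdio>

// Hedged sketch of the removed Hopper formula. All shapes below are
// illustrative assumptions, not values from the patch.
unsigned ceilDiv(unsigned a, unsigned b) { return (a + b - 1) / b; }

int main() {
  unsigned shapeM = 128, shapeK = 64;    // operand A tile per CTA (assumed)
  unsigned instrM = 16, instrK = 16;     // WGMMA instruction shape in M/K (assumed)
  unsigned warpsPerCTAM = 4, kWidth = 8; // assumed layout parameters

  unsigned repM = ceilDiv(shapeM, instrM * warpsPerCTAM);
  unsigned repK = ceilDiv(shapeK, instrK);
  unsigned total = 4 * kWidth * repM * repK;
  std::printf("repM=%u repK=%u total=%u\n", repM, repK, total); // repM=2 repK=4 total=256
  return 0;
}
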
@@ -2010,7 +1999,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
 }
 SmallVector<unsigned>
 NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
-  assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
   auto rank = getWarpsPerCTA().size();
   auto sizePerThread = SmallVector<unsigned>(rank, 1);
   if (opIdx == 0) {