@@ -1038,23 +1038,18 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
10381038 elemsPerThread[rank - 1 ] = (idx == 0 ) ? rep[2 ] * kWidth : rep[2 ];
10391039 return elemsPerThread;
10401040 } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
1041- if (mma.isAmpere () || mma.isHopper ()) {
1042- auto bitwidth = getPointeeType (eltTy).getIntOrFloatBitWidth ();
1043- auto rep = mma.getRepForOperand (shape, bitwidth, kWidth , idx);
1044- auto sizePerThread = getSizePerThread ();
1045- auto elemsPerKRep = mma.isHopper () ? (kWidth * 2 ) : (32 / bitwidth * 2 );
1046- if (rank == 3 )
1047- elemsPerThread[0 ] = rep[0 ];
1048- elemsPerThread[rank - 2 ] =
1049- (idx == 0 )
1050- ? rep[1 ] * sizePerThread[rank - 2 ]
1051- : std::max<int >(rep[1 ] * elemsPerKRep, sizePerThread[rank - 2 ]);
1052- elemsPerThread[rank - 1 ] =
1053- (idx == 0 )
1054- ? std::max<int >(rep[2 ] * elemsPerKRep, sizePerThread[rank - 1 ])
1055- : rep[2 ] * sizePerThread[rank - 1 ];
1056- return elemsPerThread;
1041+ assert (getCTALayout (*this ) ==
1042+ CTALayoutAttr::getDefault (getContext (), rank) &&
1043+ " NYI" );
1044+ auto sizePerThread = getSizePerThread ();
1045+ auto threadsPerWarp = getThreadsPerWarp ();
1046+ auto warpsPerCTA = getWarpsPerCTA ();
1047+ SmallVector<unsigned > regs;
1048+ for (auto [n, nsize, nThread, nWarp] :
1049+ llvm::zip (shape, sizePerThread, threadsPerWarp, warpsPerCTA)) {
1050+ regs.push_back (std::max<int64_t >(nsize, n / (nThread * nWarp)));
10571051 }
1052+ return regs;
10581053 }
10591054
10601055 llvm_unreachable (" getElemsPerThread is not supported for dot operand" );
@@ -2341,35 +2336,41 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
23412336SmallVector<int64_t >
23422337NvidiaMmaEncodingAttr::getRepForOperand (ArrayRef<int64_t > shape, int bitwidth,
23432338 int kWidth , int opIdx) const {
2339+ assert (
2340+ kWidth >= 32 / bitwidth &&
2341+ " kWidth must be >= 32 / bitwidth for this function to be well-defined" );
23442342 auto rank = shape.size ();
2343+ // Broadcast long K
23452344 auto warpsPerCTA = getWarpsPerCTA ();
2345+ auto kDim = opIdx == 0 ? rank - 1 : rank - 2 ;
2346+ warpsPerCTA[kDim ] = 1 ;
23462347
2347- // {batch, m, n, k}
2348- // Hopper path never uses the n value, since this method is only invoked
2349- // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
2350- // TODO: rep per operand is not accurate for Hopper. It is currently done that
2351- // way to allow us to get the correct total number of elements. this will be
2352- // fixed when moving to linear layout.
2353- SmallVector<int > shapePerWarp = {
2354- 1 , 16 , 8 , isHopper () ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
2355- int numRepBatch =
2356- rank == 3
2357- ? std::max<int64_t >(1 , shape[0 ] / (shapePerWarp[0 ] * warpsPerCTA[0 ]))
2358- : 1 ;
2359-
2348+ SmallVector<int > tileSize;
2349+ if (rank == 3 ) {
2350+ tileSize.push_back (1 );
2351+ }
23602352 if (opIdx == 0 ) {
2361- return {numRepBatch,
2362- std::max<int64_t >(1 , /* repM=*/ shape[rank - 2 ] /
2363- (shapePerWarp[1 ] * warpsPerCTA[rank - 2 ])),
2364- std::max<int64_t >(1 , /* repK=*/ shape[rank - 1 ] / shapePerWarp[3 ])};
2353+ // m x k
2354+ tileSize.push_back (16 );
2355+ tileSize.push_back (4 * 64 / bitwidth);
23652356 } else {
2366- assert (opIdx == 1 );
2367- return {
2368- numRepBatch,
2369- std::max<int64_t >(1 , /* repK=*/ shape[rank - 2 ] / shapePerWarp[3 ]),
2370- std::max<int64_t >(1 , /* repN=*/ shape[rank - 1 ] /
2371- (shapePerWarp[2 ] * warpsPerCTA[rank - 1 ]))};
2357+ // k x n
2358+ // Hopper path never uses the n value, since this method is only invoked
2359+ // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
2360+ // so it's fine if the n is incorrect here
2361+ tileSize.push_back (4 * 64 / bitwidth);
2362+ tileSize.push_back (8 );
2363+ }
2364+
2365+ SmallVector<int64_t > numRep;
2366+ // Lezcano: This is odd. Why do we always return a vector of size 3?
2367+ if (rank != 3 ) {
2368+ numRep.push_back (1 );
2369+ }
2370+ for (auto [s, size, warp] : llvm::zip (shape, tileSize, warpsPerCTA)) {
2371+ numRep.push_back (std::max<int64_t >(1 , s / (size * warp)));
23722372 }
2373+ return numRep;
23732374}
23742375
23752376SmallVector<unsigned >
0 commit comments