
Commit 4330372

[LAYOUTS] [BE] Simplify Ampere/Hopper paths introduced in #5189 (#5200)
We simplify the implementation of `getElemsPerThread` and strengthen the preconditions of `getRepForOperand`. More generally, we should try to minimise the calls to `isAmpere` and `isHopper` throughout the codebase. I'll do a pass fixing many of these once we land linear layouts (LLs) for `ldmatrix` and Hopper.
1 parent af0649d · commit 4330372

File tree: 1 file changed (+40 −39 lines)

lib/Dialect/TritonGPU/IR/Dialect.cpp

@@ -1038,23 +1038,18 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
     elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
     return elemsPerThread;
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-    if (mma.isAmpere() || mma.isHopper()) {
-      auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
-      auto rep = mma.getRepForOperand(shape, bitwidth, kWidth, idx);
-      auto sizePerThread = getSizePerThread();
-      auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
-      if (rank == 3)
-        elemsPerThread[0] = rep[0];
-      elemsPerThread[rank - 2] =
-          (idx == 0)
-              ? rep[1] * sizePerThread[rank - 2]
-              : std::max<int>(rep[1] * elemsPerKRep, sizePerThread[rank - 2]);
-      elemsPerThread[rank - 1] =
-          (idx == 0)
-              ? std::max<int>(rep[2] * elemsPerKRep, sizePerThread[rank - 1])
-              : rep[2] * sizePerThread[rank - 1];
-      return elemsPerThread;
+    assert(getCTALayout(*this) ==
+               CTALayoutAttr::getDefault(getContext(), rank) &&
+           "NYI");
+    auto sizePerThread = getSizePerThread();
+    auto threadsPerWarp = getThreadsPerWarp();
+    auto warpsPerCTA = getWarpsPerCTA();
+    SmallVector<unsigned> regs;
+    for (auto [n, nsize, nThread, nWarp] :
+         llvm::zip(shape, sizePerThread, threadsPerWarp, warpsPerCTA)) {
+      regs.push_back(std::max<int64_t>(nsize, n / (nThread * nWarp)));
     }
+    return regs;
   }

   llvm_unreachable("getElemsPerThread is not supported for dot operand");
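The rewritten branch above no longer dispatches on `isAmpere()`/`isHopper()`: for every dimension it simply takes `max(sizePerThread[d], shape[d] / (threadsPerWarp[d] * warpsPerCTA[d]))`. Below is a minimal standalone sketch of that rule; the plain `std::vector` arguments stand in for the MLIR attribute getters, and the function name and the numbers in `main` are illustrative, not Triton's API.

```cpp
// Sketch of the per-dimension register count used by the new branch: a thread
// owns at least sizePerThread elements in a dimension, and more when the
// tensor is larger than one CTA tile in that dimension.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<unsigned>
elemsPerThreadSketch(const std::vector<int64_t> &shape,
                     const std::vector<unsigned> &sizePerThread,
                     const std::vector<unsigned> &threadsPerWarp,
                     const std::vector<unsigned> &warpsPerCTA) {
  std::vector<unsigned> regs;
  for (size_t d = 0; d < shape.size(); ++d) {
    // Elements covered by one CTA tile along dimension d.
    int64_t covered = int64_t(threadsPerWarp[d]) * warpsPerCTA[d];
    regs.push_back(std::max<int64_t>(sizePerThread[d], shape[d] / covered));
  }
  return regs;
}

int main() {
  // Illustrative numbers: a 128x64 operand, sizePerThread {4, 4}, an 8x4
  // thread layout per warp, and a 2x2 warp layout:
  // {max(4, 128 / 16), max(4, 64 / 8)} == {8, 8}.
  for (unsigned r : elemsPerThreadSketch({128, 64}, {4, 4}, {8, 4}, {2, 2}))
    std::cout << r << ' ';
  std::cout << '\n';
}
```

Because this formula holds for both the Ampere and the Hopper parent layouts, the layout-specific branch can be dropped, which is the point of the simplification.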
@@ -2341,35 +2336,41 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
 SmallVector<int64_t>
 NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
                                         int kWidth, int opIdx) const {
+  assert(
+      kWidth >= 32 / bitwidth &&
+      "kWidth must be >= 32 / bitwidth for this function to be well-defined");
   auto rank = shape.size();
+  // Broadcast long K
   auto warpsPerCTA = getWarpsPerCTA();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
+  warpsPerCTA[kDim] = 1;

-  // {batch, m, n, k}
-  // Hopper path never uses the n value, since this method is only invoked
-  // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
-  // TODO: rep per operand is not accurate for Hopper. It is currently done that
-  // way to allow us to get the correct total number of elements. this will be
-  // fixed when moving to linear layout.
-  SmallVector<int> shapePerWarp = {
-      1, 16, 8, isHopper() ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
-  int numRepBatch =
-      rank == 3
-          ? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
-          : 1;
-
+  SmallVector<int> tileSize;
+  if (rank == 3) {
+    tileSize.push_back(1);
+  }
   if (opIdx == 0) {
-    return {numRepBatch,
-            std::max<int64_t>(1, /*repM=*/shape[rank - 2] /
-                                     (shapePerWarp[1] * warpsPerCTA[rank - 2])),
-            std::max<int64_t>(1, /*repK=*/shape[rank - 1] / shapePerWarp[3])};
+    // m x k
+    tileSize.push_back(16);
+    tileSize.push_back(4 * 64 / bitwidth);
   } else {
-    assert(opIdx == 1);
-    return {
-        numRepBatch,
-        std::max<int64_t>(1, /*repK=*/shape[rank - 2] / shapePerWarp[3]),
-        std::max<int64_t>(1, /*repN=*/shape[rank - 1] /
-                                 (shapePerWarp[2] * warpsPerCTA[rank - 1]))};
+    // k x n
+    // Hopper path never uses the n value, since this method is only invoked
+    // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
+    // so it's fine if the n is incorrect here
+    tileSize.push_back(4 * 64 / bitwidth);
+    tileSize.push_back(8);
+  }
+
+  SmallVector<int64_t> numRep;
+  // Lezcano: This is odd. Why do we always return a vector of size 3?
+  if (rank != 3) {
+    numRep.push_back(1);
+  }
+  for (auto [s, size, warp] : llvm::zip(shape, tileSize, warpsPerCTA)) {
+    numRep.push_back(std::max<int64_t>(1, s / (size * warp)));
   }
+  return numRep;
 }

 SmallVector<unsigned>
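For readers following the second hunk, here is a standalone sketch of the simplified rep computation under the new precondition `kWidth >= 32 / bitwidth`: warps along K are broadcast, the per-warp tile is 16 x (4 * 64 / bitwidth) for operand A and (4 * 64 / bitwidth) x 8 for operand B, and each rep count is `max(1, dim / (tile * warps))`, padded to a size-3 `{batch, ., .}` result. The function name, the plain `std::vector` signature, and the example in `main` are illustrative only, not Triton's API.

```cpp
// Sketch of the simplified getRepForOperand logic from the diff above.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> repForOperandSketch(std::vector<int64_t> shape,
                                         int bitwidth, int kWidth, int opIdx,
                                         std::vector<int64_t> warpsPerCTA) {
  // Precondition strengthened by this commit.
  assert(kWidth >= 32 / bitwidth);
  size_t rank = shape.size();
  // Broadcast long K: warps never split the K dimension of an operand.
  size_t kDim = opIdx == 0 ? rank - 1 : rank - 2;
  warpsPerCTA[kDim] = 1;

  std::vector<int64_t> tileSize;
  if (rank == 3)
    tileSize.push_back(1); // batch
  if (opIdx == 0) {        // A: m x k
    tileSize.push_back(16);
    tileSize.push_back(4 * 64 / bitwidth);
  } else {                 // B: k x n (the n tile is nominal, as in the diff)
    tileSize.push_back(4 * 64 / bitwidth);
    tileSize.push_back(8);
  }

  std::vector<int64_t> numRep;
  if (rank != 3)
    numRep.push_back(1);   // keep the historical {batch, ., .} result shape
  for (size_t d = 0; d < rank; ++d)
    numRep.push_back(
        std::max<int64_t>(1, shape[d] / (tileSize[d] * warpsPerCTA[d])));
  return numRep;
}

int main() {
  // Illustrative numbers: an A operand (opIdx 0) of shape 128x64, bitwidth 16,
  // kWidth 8, warps laid out {4, 1}: {1, 128 / (16 * 4), 64 / 16} == {1, 2, 4}.
  for (int64_t r : repForOperandSketch({128, 64}, 16, 8, 0, {4, 1}))
    std::cout << r << ' ';
  std::cout << '\n';
}
```

As the in-diff comment notes, the n-side tile is nominal: this method is only invoked for register-file (dotOpEnc) operands, and WGMMA only reads A from registers, so the B-side value is never consumed on Hopper.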
