Skip to content

Commit 1031dc7

Browse files
lezcano and apgoucher
authored
[LAYOUTS] Improve the swizzling algorithm when we don't have enough vectorisation (#7524)
We add a few heuristics to improve vectorisation and reduce bank conflicts in the case where the default vectorisation does not cover a whole bank. --------- Co-authored-by: apgoucher <[email protected]>
1 parent 5c9393a commit 1031dc7

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

lib/Tools/GenericSwizzling.cpp

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ SmallVector<int32_t> flatten(const LinearLayout &ll, StringAttr dim) {
4747
return vec;
4848
};
4949

50+
// Returns a copy of `vec` with every zero entry dropped. The relative order of
// the remaining (non-zero) entries is preserved.
SmallVector<int32_t> removeZeros(ArrayRef<int32_t> vec) {
  SmallVector<int32_t> nonZero;
  llvm::copy_if(vec, std::back_inserter(nonZero),
                [](int32_t entry) { return entry != 0; });
  return nonZero;
}
59+
5060
// [1, 2, 4, 8] -> [[1], [2], [4], [8]]
5161
std::vector<std::vector<int32_t>> unflatten(ArrayRef<int32_t> basis) {
5262
std::vector<std::vector<int32_t>> unflattened;
@@ -279,6 +289,7 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
279289
auto *ctx = src.getInDimNames().begin()->getContext();
280290
auto kReg = StringAttr::get(ctx, "register");
281291
auto kLane = StringAttr::get(ctx, "lane");
292+
auto kWarp = StringAttr::get(ctx, "warp");
282293

283294
// We work on the flattened tensors as the tensor dimensions are not relevant
284295
const LinearLayout srcFlat = src.flattenOuts();
@@ -307,6 +318,65 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
307318
if (vbasis.size() > maxVecBases) {
308319
vbasis.resize(maxVecBases);
309320
}
321+
// We fill up vbasis, as best we can, until it covers 32 bits
322+
auto vecFillsBank = (1 << vbasis.size()) * bitwidth >= 32;
323+
if (!vecFillsBank) {
324+
auto warpSrc = removeZeros(flatten(srcFlat, kWarp));
325+
auto warpDst = removeZeros(flatten(dstFlat, kWarp));
326+
auto removeVec = [&vbasis](ArrayRef<int32_t> vec) {
327+
SmallVector<int32_t> result;
328+
for (int32_t r : vec) {
329+
if (!llvm::is_contained(vbasis, r)) {
330+
result.push_back(r);
331+
}
332+
}
333+
return result;
334+
};
335+
auto regSrcWarp = intersectionBasis(removeVec(regSrc), warpDst, dim);
336+
auto regDstWarp = intersectionBasis(removeVec(regDst), warpSrc, dim);
337+
// Maximise vectorisation in the load or the store without creating
338+
// conflicts
339+
SmallVector<int32_t> largest;
340+
if (regSrcWarp.size() == regDstWarp.size() && regSrcWarp.size() > 0) {
341+
// We choose the one with the lowest basis in the hope that it will
342+
// avoid PRMTs. The comparison of the mins will be strict as the sets
343+
// removeVec(regSrc) and removeVec(regDst) don't intersect
344+
if (*llvm::min_element(regSrcWarp) < *llvm::min_element(regDstWarp)) {
345+
largest = regSrcWarp;
346+
} else {
347+
largest = regDstWarp;
348+
}
349+
} else {
350+
largest = regSrcWarp.size() > regDstWarp.size() ? regSrcWarp : regDstWarp;
351+
}
352+
vbasis.append(largest.begin(), largest.end());
353+
if (vbasis.size() > maxVecBases) {
354+
vbasis.resize(maxVecBases);
355+
}
356+
// We allow vbasis.size > Log2_32(32 / bitwidth) at this point, as it is in
357+
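// general good, but one should note that, in this branch, b32Vec below is
// pessimistically set to 0 (see the FIXME next to its definition).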
358+
if (vbasis.size() < llvm::Log2_32(32 / bitwidth)) {
359+
// Pad the vectorisation to 32 bits with warp bases
360+
auto warpSrcWarp = intersectionBasis(warpSrc, warpDst, dim);
361+
vbasis.append(warpSrcWarp.begin(), warpSrcWarp.end());
362+
}
363+
364+
int i = 0;
365+
while (vbasis.size() < llvm::Log2_32(32 / bitwidth) &&
366+
(i < warpSrc.size() || i < warpDst.size())) {
367+
// If we have not filled up a whole bank, we add more warp bases
368+
// until we have 32 bits. They will at least avoid bank conflicts in one
369+
// direction
370+
if (i < warpSrc.size() && !llvm::is_contained(vbasis, warpSrc[i])) {
371+
vbasis.push_back(warpSrc[i]);
372+
}
373+
if (vbasis.size() < llvm::Log2_32(32 / bitwidth) && i < warpDst.size() &&
374+
!llvm::is_contained(vbasis, warpDst[i])) {
375+
vbasis.push_back(warpDst[i]);
376+
}
377+
++i;
378+
}
379+
}
310380

311381
// Bits in a bank segment: 32 banks x 32 bits
312382
constexpr int32_t bankBits = 32 * 32;
@@ -321,8 +391,11 @@ LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
321391
auto bankDst = llvm::to_vector(llvm::concat<int32_t>(vbasis, laneDst));
322392

323393
// Whether we'll use b32.v1 / b32.v2 / b32.v4
324-
auto b32Vec =
325-
llvm::Log2_32(std::max<int32_t>((1 << vbasis.size()) * bitwidth / 32, 1));
394+
// FIXME: With !vecFillsBank we may use b32.v2 or b32.v4 for the load or
395+
// store, but we pessimistically assume we don't.
396+
auto b32Vec = !vecFillsBank ? 0
397+
: llvm::Log2_32(std::max<int32_t>(
398+
(1 << vbasis.size()) * bitwidth / 32, 1));
326399
// Drop the last vec bases of the banks
327400
bankSrc.resize(bankSrc.size() - b32Vec);
328401
bankDst.resize(bankDst.size() - b32Vec);

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
10111011

10121012
// -----
10131013

1014+
// Layout conversion of a 128x128 f8E5M2 tensor between two blocked layouts
// with swapped per-thread shapes: #blocked assigns each thread 16 elements
// along dim 1 (sizePerThread = [1, 16], order = [1, 0]), while #blocked1
// assigns each thread 16 elements along dim 0 (sizePerThread = [16, 1],
// order = [0, 1]). The FileCheck directives below expect 128 stores of
// vector<1xi8>, then a barrier, then 32 loads of vector<4xi8>.
#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
1015+
#blocked1 = #ttg.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
1016+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
1017+
// CHECK: llvm.mlir.global external @global_smem
1018+
// CHECK-LABEL: convert_layout_transpose
1019+
tt.func @convert_layout_transpose(%arg0: tensor<128x128xf8E5M2, #blocked>) {
1020+
// CHECK-COUNT-128: llvm.store {{.*}} vector<1xi8>
1021+
// CHECK: nvvm.barrier0
1022+
// CHECK-COUNT-32: llvm.load {{.*}} vector<4xi8>
1023+
%0 = ttg.convert_layout %arg0 : tensor<128x128xf8E5M2, #blocked> -> tensor<128x128xf8E5M2, #blocked1>
1024+
tt.return
1025+
}
1026+
}
1027+
1028+
// -----
1029+
10141030
#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
10151031
#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
10161032
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

0 commit comments

Comments
 (0)