
Commit a5b948c
[BACKEND] Fix vectorisation for convert_layout with ldmatrix and stmatrix (#8655)
The previous code was a bit too eager when adding `reps` in this case, so much so that afterwards there were not enough register bases left to fully vectorise the ldmatrix/stmatrix instructions. Fixes the regression reported in triton-lang/triton#8328.
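For context, an illustrative sketch (not code from this commit): ldmatrix/stmatrix come in .x1/.x2/.x4 variants that move one, two, or four 8x8 tiles per instruction, and every basis split off into `reps` is a register basis the instruction no longer owns. All names below are hypothetical.

#include <algorithm>

// Toy model of the trade-off: how wide an ldmatrix/stmatrix variant can be,
// given how many register bases were split off into `reps`.
int matrixWidth(int regBases, int repsBases) {
  int left = regBases - repsBases;    // bases the instruction still owns
  return 1 << std::clamp(left, 0, 2); // 1 -> .x1, 2 -> .x2, 4 -> .x4
}

// matrixWidth(2, 0) == 4: keeping both bases allows the .x4 variant.
// matrixWidth(2, 2) == 1: the over-eager reps split forced .x1 (the bug).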
1 parent 698bc5f commit a5b948c

File tree: 3 files changed, +70 -14 lines

include/triton/Tools/LinearLayout.h

Lines changed: 9 additions & 0 deletions
@@ -459,6 +459,15 @@ class LinearLayout {
   auto getOutDimSizes() const { return llvm::make_second_range(outDims); }
 
   // Relevant for reshaping
+
+  SmallVector<std::pair<StringAttr, int32_t>> getInDims() const {
+    SmallVector<std::pair<StringAttr, int32_t>> inDims;
+    inDims.reserve(bases.size());
+    for (auto [inDim, inDimBases] : bases) {
+      inDims.push_back({inDim, getInDimSize(inDim)});
+    }
+    return inDims;
+  }
   SmallVector<std::pair<StringAttr, int32_t>> getOutDims() const {
     return to_vector(outDims);
   }
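A minimal usage sketch for the new helper, assuming an already-constructed `LinearLayout ll` (the variable is hypothetical; `getInDims` mirrors the existing `getOutDims`):

// Iterate (input dimension, size) pairs, e.g. to rebuild or reshape a layout.
for (auto [inDim, size] : ll.getInDims()) {
  llvm::errs() << inDim.str() << " -> " << size << "\n";
}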

lib/Tools/GenericSwizzling.cpp

Lines changed: 33 additions & 14 deletions
@@ -100,7 +100,8 @@ SmallVector<int32_t> nullspaceBasis(ArrayRef<int32_t> vectors, int32_t dim) {
 // without sacrificing vectorisation and split it into its own
 // `reps` dimension
 LinearLayout buildReps(MLIRContext *ctx, const LinearLayout &src,
-                       const LinearLayout &dst, const LinearLayout &smem) {
+                       const LinearLayout &dst, const LinearLayout &smem,
+                       int32_t leaveReps) {
   auto kVec = StringAttr::get(ctx, "vector");
   auto kBank = StringAttr::get(ctx, "bank");
   auto kSegment = StringAttr::get(ctx, "segment");
@@ -116,8 +117,16 @@ LinearLayout buildReps(MLIRContext *ctx, const LinearLayout &src,
   SetVector<int32_t> segment;
   SetVector<int32_t> reps;
   for (auto s : smemSegment) {
+    // Do not move the first leaveReps bases from segment to reps,
+    // as we need them to vectorise the instructions (think .x2 and .x4 in
+    // ldmatrix)
     if (srcRegs.contains(s) && dstRegs.contains(s)) {
-      reps.insert(s);
+      if (leaveReps > 0) {
+        leaveReps--;
+        segment.insert(s);
+      } else {
+        reps.insert(s);
+      }
     } else {
       segment.insert(s);
     }
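The partition above in isolation, as a self-contained sketch (the surrounding buildReps plumbing is omitted; `smemSegment`, `srcRegs`, and `dstRegs` stand in for values the real function computes):

#include "llvm/ADT/SetVector.h"
#include <cstdint>

// Bases shared by src and dst registers are pure repetitions and normally go
// to `reps`; the first `leaveReps` of them stay in `segment` so that
// ldmatrix/stmatrix keep enough register bases to emit .x2/.x4 variants.
static void splitSegment(const llvm::SetVector<int32_t> &smemSegment,
                         const llvm::SetVector<int32_t> &srcRegs,
                         const llvm::SetVector<int32_t> &dstRegs,
                         int32_t leaveReps, llvm::SetVector<int32_t> &reps,
                         llvm::SetVector<int32_t> &segment) {
  for (int32_t s : smemSegment) {
    if (srcRegs.contains(s) && dstRegs.contains(s)) {
      if (leaveReps > 0) {
        leaveReps--;          // reserve this basis for vectorisation
        segment.insert(s);
      } else {
        reps.insert(s);       // a pure repetition: hoist it out
      }
    } else {
      segment.insert(s);      // used by only one side: stays in the segment
    }
  }
}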
@@ -376,11 +385,12 @@ std::optional<SmallVector<int32_t>> optimalSwizzlingTile(
   return vbasis;
 }
 
-LinearLayout
-optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
-                 int32_t bitwidth, ArrayRef<int32_t> vbasis,
-                 ArrayRef<int32_t> tileSrc, ArrayRef<int32_t> tileDst,
-                 ArrayRef<std::pair<StringAttr, int32_t>> outDims) {
+LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
+                              int32_t bitwidth, ArrayRef<int32_t> vbasis,
+                              ArrayRef<int32_t> tileSrc,
+                              ArrayRef<int32_t> tileDst,
+                              ArrayRef<std::pair<StringAttr, int32_t>> outDims,
+                              int32_t leaveReps = 0) {
   // We work on the flattened tensors as the tensor dimensions are not relevant
   assert(src.getNumOutDims() == 1 && dst.getNumOutDims() == 1 &&
          "src and dst must have a single output dimension");

@@ -439,7 +449,7 @@ optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
                             {bankAttr, unflatten(bbasis)},
                             {segAttr, unflatten(sbasis)}},
                            src.getOutDims(), /*requireSurjective=*/true);
-  basis1D = buildReps(ctx, src, dst, basis1D);
+  basis1D = buildReps(ctx, src, dst, basis1D, leaveReps);
 
   return basis1D.reshapeOuts(outDims);
 }
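Design note: the new parameter is defaulted (`leaveReps = 0`), so existing callers of optimalSwizzling keep their behaviour; only the ldmatrix/stmatrix path below passes a non-zero value. A trivial illustration of the pattern (toy names, not the real signatures):

#include <cstdint>

// Defaulted trailing parameters extend an API without touching old call sites.
static int32_t swizzleToy(int32_t bitwidth, int32_t leaveReps = 0) {
  return bitwidth + leaveReps; // stand-in for the real swizzling computation
}

static void callers() {
  swizzleToy(16);    // pre-existing callers: leaveReps == 0, unchanged
  swizzleToy(16, 2); // new ldmatrix/stmatrix path reserves two reps bases
}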
@@ -649,7 +659,7 @@ optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
 
   // Get the associated src/dst tiles for each instruction if they exist
   SmallVector<std::tuple<std::pair<int32_t, int32_t>, SmallVector<int32_t>,
-                         SmallVector<int32_t>, SmallVector<int32_t>>>
+                         SmallVector<int32_t>, SmallVector<int32_t>, int32_t>>
       tiles;
   for (auto [instrs, vbasis] : instr) {
     auto maybeTileSrc =

@@ -659,22 +669,31 @@ optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
     if (!maybeTileSrc.has_value() || !maybeTileDst.has_value()) {
       continue;
     }
+    // Register bases missing to get full vectorisation
+    auto regsMissing = [](const LocalMemOpTile &instr) {
+      return instr.laneContig.size() + instr.laneAddr.size() - 3;
+    };
+    // We leave 2 reps for combinations of ldmatrix/stmatrix instructions
+    // to be able to fully vectorise them
+    int32_t leaveReps = std::min(regsMissing(srcTiles[instrs.first]),
+                                 regsMissing(dstTiles[instrs.second]));
+    assert((leaveReps == 0 || leaveReps == 2) && "leaveReps must be 0 or 2");
     tiles.push_back({instrs, std::move(vbasis), std::move(*maybeTileSrc),
-                     std::move(*maybeTileDst)});
+                     std::move(*maybeTileDst), leaveReps});
   }
 
   if (tiles.empty()) {
     // We lower to an ld / st, but can't use LDS128/STS128
     auto smem = optimalSwizzlingLdSt(src, dst, bitwidth);
     return {smem, {0, 0}};
   } else {
-    // We choose the pair of instructions that minimises the total bank
-    // conflicts
     SmallVector<std::tuple<int, LinearLayout, std::pair<int32_t, int32_t>>>
         smems;
-    for (auto [instrs, vbasis, tileSrc, tileDst] : tiles) {
+    // We choose the pair of instructions that minimises the total bank
+    // conflicts
+    for (auto [instrs, vbasis, tileSrc, tileDst, leaveReps] : tiles) {
       auto smem = optimalSwizzling(srcFlat, dstFlat, bitwidth, vbasis, tileSrc,
-                                   tileDst, src.getOutDims());
+                                   tileDst, src.getOutDims(), leaveReps);
       auto [read, write] = bankConflicts(tileSrc, tileDst, smem);
       smems.push_back({read + write, smem, {instrs.first, instrs.second}});
     }
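Why the assert can insist on 0 or 2, as a back-of-the-envelope sketch (the tile sizes below are assumptions inferred from the lambda above, not verified against LocalMemOpTile's definition):

#include <algorithm>
#include <cstdint>

// Mirror of the regsMissing lambda: laneContig + laneAddr bases, minus the 3
// that a plain ld/st tile is assumed to already cover.
static int32_t regsMissingToy(int32_t laneContig, int32_t laneAddr) {
  return laneContig + laneAddr - 3;
}

// Assumed: an ld/st tile uses 3 lane bases (missing 0), an ldmatrix/stmatrix
// tile uses 5 (missing 2). leaveReps is the min over the src/dst pair, so
// mixing ld/st with ldmatrix yields 0, and ldmatrix with stmatrix yields 2.
static int32_t leaveRepsToy(int32_t srcMissing, int32_t dstMissing) {
  return std::min(srcMissing, dstMissing);
}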

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 28 additions & 0 deletions
@@ -248,6 +248,34 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 32]], warp = [[32, 0], [64, 0], [16, 0]], block = []}>
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  tt.func @convert_mma_to_blocked(%a: tensor<128x64xbf16, #linear>) {
+    // CHECK: llvm.store {{.*}} : vector<4xi32>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load {{.*}} -> vector<4xi32>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.store {{.*}} : vector<4xi32>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load {{.*}} -> vector<4xi32>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.store {{.*}} : vector<4xi32>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load {{.*}} -> vector<4xi32>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.store {{.*}} : vector<4xi32>
+    // CHECK: nvvm.barrier0
+    // CHECK: llvm.load {{.*}} -> vector<4xi32>
+    // CHECK-NOT: llvm.store
+    // CHECK-NOT: llvm.load
+    %b = ttg.convert_layout %a: tensor<128x64xbf16, #linear> -> tensor<128x64xbf16, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
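As a sanity check on the CHECK pattern (plain arithmetic, not part of the test): four barrier-separated rounds of 128-bit accesses are exactly what a 128x64 bf16 tensor needs with 8 warps, assuming each thread issues one 16-byte access per round.

#include <cstdint>

constexpr int64_t tensorBytes = 128 * 64 * 2;  // bf16 is 2 bytes
constexpr int64_t threads = 8 * 32;            // 8 warps x 32 lanes
constexpr int64_t bytesPerAccess = 4 * 4;      // vector<4xi32> = 16 bytes
constexpr int64_t rounds = tensorBytes / (threads * bytesPerAccess);
static_assert(rounds == 4, "four store/load pairs, as the CHECK lines expect");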
