Skip to content

Commit 58ae6f0

Browse files
[Backend] Use byte permutes in intra-warp layout conversion (#7809)
### Issue We patch an oversight in #7558 where reindexing sub-32-bit elements before or after unpacking them from vectors can cause LLVM’s InstCombine to materialize `shufflevector`s in real kernels that lower to byte permute instructions, which are not optimized away. This was believed to cause a small regression in #7574. In the context of that PR, one has 8-bit elements packed into registers and a layout conversion described by the permutation `(r1 r2 l1 l0)` of register (`r*`) and lane (`l*`) basis vectors. Due to register packing, `r1` corresponds to an intra-register index bit. The current algorithm interprets ``` (r1 r2 l1 l0) = (r2 r1) * (r2 l1)(l0 l1), ``` implements `(r2 l1)(l0 l1)` using a `select-shuffle-select` pattern, and then applies `(r2 r1)` by reindexing the elements after extraction. As the elements are immediately repacked, InstCombine produces `shufflevector` instructions from the extract-insert pattern, resulting in one `prmt` per packed register. ### Fix It is possible to fuse the effects of these intra-register index bit permutations to the first and/or third stages of the `select-shuffle-select` pattern of the conversion algorithm. In most cases, this happens when in the cycle decomposition of the layout conversion, the intra-register index bit is adjacent to a lane index bit within a cycle, as in the above example. ### Future work To the best of my knowledge, this PR handles all cases where the above fusion is possible. However, there are cases where it is not possible which have potential for further optimization due to InstCombine’s lack of coverage: Suppose we have four input `v4i8`s whose elements are rearranged via extraction and insertion into four output `v4i8`s in a manner such that each output vector contains one element from each of the four input vectors. In this case, LLVM generates 4 chains of 3 `prmt`s to build the output vectors, but it is possible to carry this out using two stages of 4 independent `prmt`s, thus reducing depth and instruction count. This pattern can also exist in layout conversions that take the `transferWithinWarp` path, but as it is truly an intra-thread pattern, this optimization should be implemented in `transferWithinThread` and invoked within `transferWithinWarp` in a future PR. --------- Co-authored-by: apgoucher <[email protected]>
1 parent de4376e commit 58ae6f0

File tree

12 files changed

+451
-157
lines changed

12 files changed

+451
-157
lines changed

include/triton/Analysis/Utility.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,19 @@ class GatherLoweringHelper {
181181
// `pReg` and `pLane` are square layouts each with only one input and output
182182
// dimension. `mixedTranspositions` holds pairs of integers (i, j)
183183
// corresponding to the transposition (r_i l_j) of the i-th register basis
184-
// vector with the j-th lane basis vector.
184+
// vector with the j-th lane basis vector along with 16-bit selectors for byte
185+
// permute instructions (where each of the four nybbles is in the range [0, 7]).
185186
struct DecomposedWarpConversion {
187+
struct TranspositionInfo {
188+
std::pair<int, int> transposition;
189+
uint16_t topPreSel = 0x3210;
190+
uint16_t botPreSel = 0x7654;
191+
uint16_t topPostSel = 0x3210;
192+
uint16_t botPostSel = 0x7654;
193+
};
194+
186195
triton::LinearLayout pReg, pLane;
187-
SmallVector<std::pair<int, int>> mixedTranspositions;
196+
SmallVector<TranspositionInfo> mixedTranspositions;
188197
};
189198

190199
// Produces a decomposition of a permutation describing a warp-local layout
@@ -196,7 +205,7 @@ struct DecomposedWarpConversion {
196205
// represented as a permutation.
197206
DecomposedWarpConversion
198207
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
199-
RankedTensorType dstTy);
208+
RankedTensorType dstTy, int bitwidth);
200209

201210
// Decomposes a reshape into simpler pieces.
202211
//

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class TargetInfoBase {
4848
virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
4949
Value i) const = 0;
5050

51+
virtual Value permute(RewriterBase &rewriter, Location loc, Value a, Value b,
52+
Value selector) const = 0;
53+
5154
virtual Value programId(RewriterBase &rewriter, Location loc,
5255
ModuleOp moduleOp, ProgramIDDim axis) const = 0;
5356

lib/Analysis/Utility.cpp

Lines changed: 237 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "triton/Tools/LayoutUtils.h"
1717
#include "triton/Tools/LinearLayout.h"
1818
#include "triton/Tools/Sys/GetEnv.hpp"
19+
#include "llvm/ADT/SmallSet.h"
1920

2021
namespace mlir {
2122

@@ -247,9 +248,14 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() {
247248
return elementSizeInBytes * getScratchSizeInElems();
248249
}
249250

251+
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
252+
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
253+
std::vector<std::vector<int32_t>> &regBases,
254+
int bitwidth);
255+
250256
DecomposedWarpConversion
251257
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
252-
RankedTensorType dstTy) {
258+
RankedTensorType dstTy, int bitwidth) {
253259
// Two layouts, ll_src and ll_dst, representing the same tensor can be
254260
// viewed as surjections of GF(2) vector spaces:
255261
//
@@ -278,11 +284,12 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
278284
// subsequences of consecutive lane bits from cycles involving both bit types.
279285
// Further explanation of this method is below.
280286
//
281-
// The decomposition is performed in two stages. First, we compute the
287+
// The decomposition is performed in three stages. First, we compute the
282288
// permutation matrix `P` by using `invertAndCompose` to generate a skeleton
283289
// and then fill in any zero columns. Second, we walk the cycles of `P` to
284290
// factor out mixed transpositions to build `mixedTranspositions`, `pReg`, and
285-
// `pLane`.
291+
// `pLane`. Finally, we determine any selectors needed for byte permute
292+
// instructions in place of `selp` instructions when packing registers.
286293

287294
// We remove any broadcasting in the register dimensions of the layouts before
288295
// forming the permutation `P` as the components of the decomposition directly
@@ -336,19 +343,14 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
336343
T = padWithZeros(T);
337344
}
338345

339-
// Flatten outs for ease of building `P`, and reorder outs as flattening
340-
// depends on output dimension order.
341-
if (outDimNames != llvm::to_vector(T.getOutDimNames()))
342-
T = T.transposeOuts(outDimNames);
343-
S = S.flattenOuts();
344-
T = T.flattenOuts();
345-
346346
// We compute T^transpose \circ S, which serves as a skeleton for `P`, then
347347
// fill in zero columns, prioritizing producing fixed points. As we only need
348348
// the basis vectors of `P`, we never actually produce the LinearLayout.
349349
auto pBases = S.invertAndCompose(T).getBases();
350350

351351
// Find the common and uncommon zeros of S and T
352+
S = S.flattenOuts();
353+
T = T.flattenOuts();
352354
SmallVector<std::pair<int32_t, int32_t>> srcFreeZeros;
353355
SmallVector<std::pair<int32_t, int32_t>> dstFreeZeros;
354356
for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) {
@@ -461,11 +463,234 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
461463
}
462464
assert(visited.all() && "Cycle walk incomplete");
463465

466+
auto processedTranspos =
467+
getTranspositionSelectors(mixedTranspositions, regBases, bitwidth);
468+
464469
auto pReg = LinearLayout(std::move(pRegBases), {{kReg, 1 << nRegBases}},
465470
/*requireSurjective=*/true);
466471
auto pLane = LinearLayout(std::move(pLaneBases), {{kLane, 1 << nLaneBases}},
467472
/*requireSurjective=*/true);
468-
return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)};
473+
return {std::move(pReg), std::move(pLane), std::move(processedTranspos)};
474+
}
475+
476+
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
477+
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
478+
std::vector<std::vector<int32_t>> &regBases,
479+
int bitwidth) {
480+
// When possible, we fuse permutations of 'low' register bits together
481+
// with a mixed transposition, resulting in byte permute instructions instead
482+
// of `select` instructions. After processing, no low register bits appear in
483+
// the returned list of mixed transpositions.
484+
int m = mixedTranspositions.size();
485+
int nRegBases = regBases.size();
486+
int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4));
487+
int nPack = std::min(nPackPrelim, nRegBases - m);
488+
489+
SmallVector<DecomposedWarpConversion::TranspositionInfo> ret;
490+
ret.reserve(mixedTranspositions.size());
491+
if (nPack == 0) {
492+
for (auto &t : mixedTranspositions)
493+
ret.push_back(DecomposedWarpConversion::TranspositionInfo{t});
494+
return ret;
495+
}
496+
// Consider for example the cycle
497+
//
498+
// (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3)
499+
// = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2)
500+
//
501+
// with `nPack` = 2 so that r0 and r1 are considered low bits. We want to
502+
// factor out any low bits from `pReg` and to incorporate them into the data
503+
// of the mixed transposition. After processing, the contribution to `pReg`
504+
// is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with
505+
// the effects of (r3 r0) and (r3 r1) encoded in the returned selectors.
506+
// In general, low bits occurring immediately before l_j modify the selectors
507+
// of the `prmt` before the shuffle, while low bits occurring immediately
508+
// after l_k modify the selectors of the `prmt` after the shuffle. Unmodified
509+
// selectors correspond to `select` instructions.
510+
// Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is
511+
// not used in another mixed transposition and conjugating out a low bit:
512+
//
513+
// (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1)
514+
// = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1).
515+
//
516+
// Conjugation does not affect `pReg`. However, the set of fused mixed and
517+
// low-bit transpositions is noncommutative in cases where there are no
518+
// intervening high bits in between distinct sequences of lane bits as the
519+
// paired low bit is used in modifying the selectors of both factors:
520+
//
521+
// (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0).
522+
//
523+
// The `*` is standard composition of permutations. The groupings correspond
524+
// to different `TranspositionInfo` objects. For example, the permutation
525+
// `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with
526+
// pre- and post-shuffle selectors determined by the `r0` bit.
527+
// Processing of mixed transpositions is performed by determining the `head`
528+
// and `tail` of an excision of bits in cycles of `pReg` and building lists
529+
// of low bits acting as selector modifiers. In the noncommutative cases, we
530+
// opt to restrict the number of post-shuffle modifiers to one.
531+
532+
auto permuteSelector = [nPack](uint16_t sel, int bitIdx) {
533+
int lo = bitIdx + (2 - nPack);
534+
uint16_t maskHi = 0x4444;
535+
uint16_t maskLo = 0x1111 << lo;
536+
uint16_t fixed = sel & ~maskHi & ~maskLo;
537+
int shift = 2 - lo;
538+
return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift);
539+
};
540+
auto generateSelectors = [&](int head, int tail, auto &&lowBits) {
541+
uint16_t topSel = 0x3210;
542+
uint16_t botSel = 0x7654;
543+
for (auto lowBit : lowBits) {
544+
topSel = permuteSelector(topSel, lowBit);
545+
botSel = permuteSelector(botSel, lowBit);
546+
if (lowBit != head && lowBit != tail)
547+
regBases[lowBit][0] = 1 << lowBit;
548+
}
549+
return std::pair{topSel, botSel};
550+
};
551+
552+
llvm::SmallSet<int32_t, 6> pairedRegBits;
553+
for (auto [rBit, lBit] : mixedTranspositions)
554+
pairedRegBits.insert(rBit);
555+
556+
// A low bit in a mixed transposition must be replaced by a high bit. The
557+
// choice of high bit can affect instruction count. If the first high bit
558+
// found when walking along `pReg` is unpaired, then that bit is the best
559+
// choice. We reorder the transpositions to guarantee this during processing.
560+
auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); };
561+
auto nextHighFree = [&](auto p) {
562+
int curr = p.first;
563+
do {
564+
if (curr >= nPack)
565+
return curr == p.first || !pairedRegBits.contains(curr);
566+
curr = next(curr);
567+
} while (curr != p.first);
568+
return false;
569+
};
570+
std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(),
571+
nextHighFree);
572+
// If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low
573+
// bit to an open high bit, then the high bit should be used as the partner.
574+
auto prev = [&](int b) {
575+
int tail = b;
576+
int curr = next(b);
577+
while (curr != b) {
578+
tail = curr;
579+
curr = next(curr);
580+
}
581+
return tail;
582+
};
583+
auto findPartner = [&](int lowBit, auto &preShufLoBits) {
584+
if (nPack == 2) {
585+
int otherLow = 1 - lowBit;
586+
int b = next(otherLow);
587+
if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) &&
588+
!pairedRegBits.contains(otherLow)) {
589+
preShufLoBits.push_back(otherLow);
590+
regBases[prev(otherLow)][0] = 1 << b;
591+
pairedRegBits.insert(b);
592+
return b;
593+
}
594+
}
595+
int potentialPartner = nPack;
596+
while (pairedRegBits.contains(potentialPartner))
597+
++potentialPartner;
598+
pairedRegBits.insert(potentialPartner);
599+
return potentialPartner;
600+
};
601+
602+
for (auto p : mixedTranspositions) {
603+
int rBit = p.first;
604+
int lBit = p.second;
605+
SmallVector<int> cycle;
606+
int currBit = rBit;
607+
do {
608+
cycle.push_back(currBit);
609+
currBit = next(currBit);
610+
} while (currBit != rBit);
611+
612+
// Find any low register bits adjacent to the excised lane bits which aren't
613+
// used in other mixed transpositions.
614+
auto isBoundary = [&](int bit) {
615+
return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit);
616+
};
617+
auto forwardEnd = llvm::find_if(cycle, isBoundary);
618+
auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary);
619+
SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd);
620+
SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd);
621+
int head;
622+
int tail;
623+
int partnerBit = -1;
624+
625+
// Case work to determine what to conjugate out.
626+
if (forwardEnd != cycle.end()) {
627+
if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) {
628+
// End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2)
629+
// No conjugation needed.
630+
head = partnerBit = *forwardEnd;
631+
} else {
632+
// End at different paired bit. E.g. (l0 r0 r1 l1 r2)
633+
// Non-leading factor in a noncommutative case.
634+
// Conjugate by first low bit in forward walk.
635+
head = postShufLoBits.front();
636+
preShufLoBits.push_back(head);
637+
postShufLoBits.resize(1);
638+
pairedRegBits.erase(head);
639+
}
640+
tail = *backwardEnd;
641+
if (tail < nPack && pairedRegBits.contains(tail)) {
642+
// Non-terminal factor in a noncommutative case.
643+
preShufLoBits.insert(preShufLoBits.begin(), tail);
644+
}
645+
} else {
646+
if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) {
647+
// Symmetric noncommutative case. E.g. (l0 r0 l1 r1)
648+
preShufLoBits.erase(preShufLoBits.begin());
649+
postShufLoBits.pop_back();
650+
pairedRegBits.erase(postShufLoBits.front());
651+
head = rBit;
652+
tail = next(rBit);
653+
} else {
654+
// Isolated low bits with single mixed transposition. E.g. (l0 r0 r1)
655+
if (postShufLoBits.size() == 2)
656+
postShufLoBits.pop_back();
657+
head = tail = preShufLoBits.front();
658+
}
659+
}
660+
661+
if (partnerBit < 0)
662+
partnerBit = findPartner(head, preShufLoBits);
663+
auto [topPostSel, botPostSel] =
664+
generateSelectors(head, tail, llvm::reverse(postShufLoBits));
665+
auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits);
666+
regBases[tail][0] = 1 << head;
667+
668+
DecomposedWarpConversion::TranspositionInfo info;
669+
info.transposition = {partnerBit, lBit};
670+
info.topPreSel = topPreSel;
671+
info.botPreSel = botPreSel;
672+
info.topPostSel = topPostSel;
673+
info.botPostSel = botPostSel;
674+
675+
// In noncommutative cases, post-shuffle selectors of non-leading terms come
676+
// from a single low bit by design, so we can determine where to insert a
677+
// non-terminal factor by examining processed selectors.
678+
if (!preShufLoBits.empty()) {
679+
uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410;
680+
auto it =
681+
llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; });
682+
ret.insert(it, info);
683+
} else {
684+
ret.push_back(info);
685+
}
686+
}
687+
if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) {
688+
// If (r0 r1) was originally in `P`, fold it into a mixed transposition.
689+
auto &t = ret.back();
690+
t.topPostSel = 0x3120;
691+
t.botPostSel = 0x7564;
692+
}
693+
return ret;
469694
}
470695

471696
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>>
@@ -763,7 +988,7 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
763988
auto kLane = StringAttr::get(ctx, "lane");
764989
if (to_vector(layout.getOutDimNames()) ==
765990
SmallVector<StringAttr, 2>{kRegister, kLane}) {
766-
auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy);
991+
auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, 32);
767992
return (factors.mixedTranspositions.size() < 2);
768993
}
769994
return false;

0 commit comments

Comments
 (0)