Skip to content

Commit 9f23f73

Browse files
Revert partial "[Backend] Use byte permutes in intra-warp layout conversion (#7809)"
1 parent 0e679c1 commit 9f23f73

File tree

4 files changed

+155
-407
lines changed

4 files changed

+155
-407
lines changed

include/triton/Analysis/Utility.h

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -181,19 +181,10 @@ class GatherLoweringHelper {
181181
// `pReg` and `pLane` are square layouts each with only one input and output
182182
// dimension. `mixedTranspositions` holds pairs of integers (i, j)
183183
// corresponding to the transposition (r_i l_j) of the i-th register basis
184-
// vector with the j-th lane basis vector along with 16-bit selectors for byte
185-
// permute instructions (where each of the four nybbles is in the range [0, 7]).
184+
// vector with the j-th lane basis vector.
186185
struct DecomposedWarpConversion {
187-
struct TranspositionInfo {
188-
std::pair<int, int> transposition;
189-
uint16_t topPreSel = 0x3210;
190-
uint16_t botPreSel = 0x7654;
191-
uint16_t topPostSel = 0x3210;
192-
uint16_t botPostSel = 0x7654;
193-
};
194-
195186
triton::LinearLayout pReg, pLane;
196-
SmallVector<TranspositionInfo> mixedTranspositions;
187+
SmallVector<std::pair<int, int>> mixedTranspositions;
197188
};
198189

199190
// Produces a decomposition of a permutation describing a warp-local layout
@@ -205,7 +196,7 @@ struct DecomposedWarpConversion {
205196
// represented as a permutation.
206197
DecomposedWarpConversion
207198
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
208-
RankedTensorType dstTy, int bitwidth);
199+
RankedTensorType dstTy);
209200

210201
// Decomposes a reshape into simpler pieces.
211202
//

lib/Analysis/Utility.cpp

Lines changed: 12 additions & 237 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "triton/Tools/LayoutUtils.h"
1818
#include "triton/Tools/LinearLayout.h"
1919
#include "triton/Tools/Sys/GetEnv.hpp"
20-
#include "llvm/ADT/SmallSet.h"
2120

2221
namespace mlir {
2322

@@ -254,14 +253,9 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() {
254253
return elementSizeInBytes * getScratchSizeInElems();
255254
}
256255

257-
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
258-
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
259-
std::vector<std::vector<int32_t>> &regBases,
260-
int bitwidth);
261-
262256
DecomposedWarpConversion
263257
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
264-
RankedTensorType dstTy, int bitwidth) {
258+
RankedTensorType dstTy) {
265259
// Two layouts, ll_src and ll_dst, representing the same tensor can be
266260
// viewed as surjections of GF(2) vector spaces:
267261
//
@@ -290,12 +284,11 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
290284
// subsequences of consecutive lane bits from cycles involving both bit types.
291285
// Further explanation of this method is below.
292286
//
293-
// The decomposition is performed in three stages. First, we compute the
287+
// The decomposition is performed in two stages. First, we compute the
294288
// permutation matrix `P` by using `invertAndCompose` to generate a skeleton
295289
// and then fill in any zero columns. Second, we walk the cycles of `P` to
296290
// factor out mixed transpositions to build `mixedTranspositions`, `pReg`, and
297-
// `pLane`. Finally, we determine any selectors needed for byte permute
298-
// instructions in place of `selp` instructions when packing registers.
291+
// `pLane`.
299292

300293
// We remove any broadcasting in the register dimensions of the layouts before
301294
// forming the permutation `P` as the components of the decomposition directly
@@ -349,14 +342,19 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
349342
T = padWithZeros(T);
350343
}
351344

345+
// Reorder T's outs to match `outDimNames` first, since flattening depends on
346+
// output dimension order; then flatten S and T for ease of building `P`.
347+
if (outDimNames != llvm::to_vector(T.getOutDimNames()))
348+
T = T.transposeOuts(outDimNames);
349+
S = S.flattenOuts();
350+
T = T.flattenOuts();
351+
352352
// We compute T^transpose \circ S, which serves as a skeleton for `P`, then
353353
// fill in zero columns, prioritizing producing fixed points. As we only need
354354
// the basis vectors of `P`, we never actually produce the LinearLayout.
355355
auto pBases = S.invertAndCompose(T).getBases();
356356

357357
// Find the common and uncommon zeros of S and T
358-
S = S.flattenOuts();
359-
T = T.flattenOuts();
360358
SmallVector<std::pair<int32_t, int32_t>> srcFreeZeros;
361359
SmallVector<std::pair<int32_t, int32_t>> dstFreeZeros;
362360
for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) {
@@ -469,234 +467,11 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
469467
}
470468
assert(visited.all() && "Cycle walk incomplete");
471469

472-
auto processedTranspos =
473-
getTranspositionSelectors(mixedTranspositions, regBases, bitwidth);
474-
475470
auto pReg = LinearLayout(std::move(pRegBases), {{kReg, 1 << nRegBases}},
476471
/*requireSurjective=*/true);
477472
auto pLane = LinearLayout(std::move(pLaneBases), {{kLane, 1 << nLaneBases}},
478473
/*requireSurjective=*/true);
479-
return {std::move(pReg), std::move(pLane), std::move(processedTranspos)};
480-
}
481-
482-
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
483-
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
484-
std::vector<std::vector<int32_t>> &regBases,
485-
int bitwidth) {
486-
// When possible, we fuse permutations of 'low' register bits together
487-
// with a mixed transposition, resulting in byte permute instructions instead
488-
// of `select` instructions. After processing, no low register bits appear in
489-
// the returned list of mixed transpositions.
490-
int m = mixedTranspositions.size();
491-
int nRegBases = regBases.size();
492-
int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4));
493-
int nPack = std::min(nPackPrelim, nRegBases - m);
494-
495-
SmallVector<DecomposedWarpConversion::TranspositionInfo> ret;
496-
ret.reserve(mixedTranspositions.size());
497-
if (nPack == 0) {
498-
for (auto &t : mixedTranspositions)
499-
ret.push_back(DecomposedWarpConversion::TranspositionInfo{t});
500-
return ret;
501-
}
502-
// Consider for example the cycle
503-
//
504-
// (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3)
505-
// = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2)
506-
//
507-
// with `nPack` = 2 so that r0 and r1 are considered low bits. We want to
508-
// factor out any low bits from `pReg` and to incorporate them into the data
509-
// of the mixed transposition. After processing, the contribution to `pReg`
510-
// is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with
511-
// the effects of (r3 r0) and (r3 r1) encoded in the returned selectors.
512-
// In general, low bits occurring immediately before l_j modify the selectors
513-
// of the `prmt` before the shuffle, while low bits occurring immediately
514-
// after l_k modify the selectors of the `prmt` after the shuffle. Unmodified
515-
// selectors correspond to `select` instructions.
516-
// Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is
517-
// not used in another mixed transposition and conjugating out a low bit:
518-
//
519-
// (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1)
520-
// = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1).
521-
//
522-
// Conjugation does not affect `pReg`. However, the set of fused mixed and
523-
// low-bit transpositions is noncommutative in cases where there are no
524-
// intervening high bits in between distinct sequences of lane bits as the
525-
// paired low bit is used in modifying the selectors of both factors:
526-
//
527-
// (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0).
528-
//
529-
// The `*` is standard composition of permutations. The groupings correspond
530-
// to different `TranspositionInfo` objects. For example, the permutation
531-
// `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with
532-
// pre- and post-shuffle selectors determined by the `r0` bit.
533-
// Processing of mixed transpositions is performed by determining the `head`
534-
// and `tail` of an excision of bits in cycles of `pReg` and building lists
535-
// of low bits acting as selector modifiers. In the noncommutative cases, we
536-
// opt to restrict the number of post-shuffle modifiers to one.
537-
538-
auto permuteSelector = [nPack](uint16_t sel, int bitIdx) {
539-
int lo = bitIdx + (2 - nPack);
540-
uint16_t maskHi = 0x4444;
541-
uint16_t maskLo = 0x1111 << lo;
542-
uint16_t fixed = sel & ~maskHi & ~maskLo;
543-
int shift = 2 - lo;
544-
return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift);
545-
};
546-
auto generateSelectors = [&](int head, int tail, auto &&lowBits) {
547-
uint16_t topSel = 0x3210;
548-
uint16_t botSel = 0x7654;
549-
for (auto lowBit : lowBits) {
550-
topSel = permuteSelector(topSel, lowBit);
551-
botSel = permuteSelector(botSel, lowBit);
552-
if (lowBit != head && lowBit != tail)
553-
regBases[lowBit][0] = 1 << lowBit;
554-
}
555-
return std::pair{topSel, botSel};
556-
};
557-
558-
llvm::SmallSet<int32_t, 6> pairedRegBits;
559-
for (auto [rBit, lBit] : mixedTranspositions)
560-
pairedRegBits.insert(rBit);
561-
562-
// A low bit in a mixed transposition must be replaced by a high bit. The
563-
// choice of high bit can affect instruction count. If the first high bit
564-
// found when walking along `pReg` is unpaired, then that bit is the best
565-
// choice. We reorder the transpositions to guarantee this during processing.
566-
auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); };
567-
auto nextHighFree = [&](auto p) {
568-
int curr = p.first;
569-
do {
570-
if (curr >= nPack)
571-
return curr == p.first || !pairedRegBits.contains(curr);
572-
curr = next(curr);
573-
} while (curr != p.first);
574-
return false;
575-
};
576-
std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(),
577-
nextHighFree);
578-
// If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low
579-
// bit to an open high bit, then the high bit should be used as the partner.
580-
auto prev = [&](int b) {
581-
int tail = b;
582-
int curr = next(b);
583-
while (curr != b) {
584-
tail = curr;
585-
curr = next(curr);
586-
}
587-
return tail;
588-
};
589-
auto findPartner = [&](int lowBit, auto &preShufLoBits) {
590-
if (nPack == 2) {
591-
int otherLow = 1 - lowBit;
592-
int b = next(otherLow);
593-
if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) &&
594-
!pairedRegBits.contains(otherLow)) {
595-
preShufLoBits.push_back(otherLow);
596-
regBases[prev(otherLow)][0] = 1 << b;
597-
pairedRegBits.insert(b);
598-
return b;
599-
}
600-
}
601-
int potentialPartner = nPack;
602-
while (pairedRegBits.contains(potentialPartner))
603-
++potentialPartner;
604-
pairedRegBits.insert(potentialPartner);
605-
return potentialPartner;
606-
};
607-
608-
for (auto p : mixedTranspositions) {
609-
int rBit = p.first;
610-
int lBit = p.second;
611-
SmallVector<int> cycle;
612-
int currBit = rBit;
613-
do {
614-
cycle.push_back(currBit);
615-
currBit = next(currBit);
616-
} while (currBit != rBit);
617-
618-
// Find any low register bits adjacent to the excised lane bits which aren't
619-
// used in other mixed transpositions.
620-
auto isBoundary = [&](int bit) {
621-
return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit);
622-
};
623-
auto forwardEnd = llvm::find_if(cycle, isBoundary);
624-
auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary);
625-
SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd);
626-
SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd);
627-
int head;
628-
int tail;
629-
int partnerBit = -1;
630-
631-
// Case work to determine what to conjugate out.
632-
if (forwardEnd != cycle.end()) {
633-
if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) {
634-
// End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2)
635-
// No conjugation needed.
636-
head = partnerBit = *forwardEnd;
637-
} else {
638-
// End at different paired bit. E.g. (l0 r0 r1 l1 r2)
639-
// Non-leading factor in a noncommutative case.
640-
// Conjugate by first low bit in forward walk.
641-
head = postShufLoBits.front();
642-
preShufLoBits.push_back(head);
643-
postShufLoBits.resize(1);
644-
pairedRegBits.erase(head);
645-
}
646-
tail = *backwardEnd;
647-
if (tail < nPack && pairedRegBits.contains(tail)) {
648-
// Non-terminal factor in a noncommutative case.
649-
preShufLoBits.insert(preShufLoBits.begin(), tail);
650-
}
651-
} else {
652-
if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) {
653-
// Symmetric noncommutative case. E.g. (l0 r0 l1 r1)
654-
preShufLoBits.erase(preShufLoBits.begin());
655-
postShufLoBits.pop_back();
656-
pairedRegBits.erase(postShufLoBits.front());
657-
head = rBit;
658-
tail = next(rBit);
659-
} else {
660-
// Isolated low bits with single mixed transposition. E.g. (l0 r0 r1)
661-
if (postShufLoBits.size() == 2)
662-
postShufLoBits.pop_back();
663-
head = tail = preShufLoBits.front();
664-
}
665-
}
666-
667-
if (partnerBit < 0)
668-
partnerBit = findPartner(head, preShufLoBits);
669-
auto [topPostSel, botPostSel] =
670-
generateSelectors(head, tail, llvm::reverse(postShufLoBits));
671-
auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits);
672-
regBases[tail][0] = 1 << head;
673-
674-
DecomposedWarpConversion::TranspositionInfo info;
675-
info.transposition = {partnerBit, lBit};
676-
info.topPreSel = topPreSel;
677-
info.botPreSel = botPreSel;
678-
info.topPostSel = topPostSel;
679-
info.botPostSel = botPostSel;
680-
681-
// In noncommutative cases, post-shuffle selectors of non-leading terms come
682-
// from a single low bit by design, so we can determine where to insert a
683-
// non-terminal factor by examining processed selectors.
684-
if (!preShufLoBits.empty()) {
685-
uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410;
686-
auto it =
687-
llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; });
688-
ret.insert(it, info);
689-
} else {
690-
ret.push_back(info);
691-
}
692-
}
693-
if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) {
694-
// If (r0 r1) was originally in `P`, fold it into a mixed transposition.
695-
auto &t = ret.back();
696-
t.topPostSel = 0x3120;
697-
t.botPostSel = 0x7564;
698-
}
699-
return ret;
474+
return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)};
700475
}
701476

702477
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>>
@@ -994,7 +769,7 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
994769
auto kLane = StringAttr::get(ctx, "lane");
995770
if (to_vector(layout.getOutDimNames()) ==
996771
SmallVector<StringAttr, 2>{kRegister, kLane}) {
997-
auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, 32);
772+
auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy);
998773
return (factors.mixedTranspositions.size() < 2);
999774
}
1000775
return false;

0 commit comments

Comments
 (0)