Skip to content

Commit a7b1123

Browse files
Merge commit '58ae6f0e9a35c5527c198d6e8aaf8a57f05c13c8'
2 parents 6effefb + 58ae6f0 commit a7b1123

File tree

12 files changed

+451
-157
lines changed

12 files changed

+451
-157
lines changed

include/triton/Analysis/Utility.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,19 @@ class GatherLoweringHelper {
181181
// `pReg` and `pLane` are square layouts each with only one input and output
182182
// dimension. `mixedTranspositions` holds pairs of integers (i, j)
183183
// corresponding to the transposition (r_i l_j) of the i-th register basis
184-
// vector with the j-th lane basis vector.
184+
// vector with the j-th lane basis vector along with 16-bit selectors for byte
185+
// permute instructions (where each of the four nybbles is in the range [0, 7]).
185186
struct DecomposedWarpConversion {
187+
struct TranspositionInfo {
188+
std::pair<int, int> transposition;
189+
uint16_t topPreSel = 0x3210;
190+
uint16_t botPreSel = 0x7654;
191+
uint16_t topPostSel = 0x3210;
192+
uint16_t botPostSel = 0x7654;
193+
};
194+
186195
triton::LinearLayout pReg, pLane;
187-
SmallVector<std::pair<int, int>> mixedTranspositions;
196+
SmallVector<TranspositionInfo> mixedTranspositions;
188197
};
189198

190199
// Produces a decomposition of a permutation describing a warp-local layout
@@ -196,7 +205,7 @@ struct DecomposedWarpConversion {
196205
// represented as a permutation.
197206
DecomposedWarpConversion
198207
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
199-
RankedTensorType dstTy);
208+
RankedTensorType dstTy, int bitwidth);
200209

201210
// Decomposes a reshape into simpler pieces.
202211
//

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ class TargetInfoBase {
4848
virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
4949
Value i) const = 0;
5050

51+
virtual Value permute(RewriterBase &rewriter, Location loc, Value a, Value b,
52+
Value selector) const = 0;
53+
5154
virtual Value programId(RewriterBase &rewriter, Location loc,
5255
ModuleOp moduleOp, ProgramIDDim axis) const = 0;
5356

lib/Analysis/Utility.cpp

Lines changed: 237 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "triton/Tools/LayoutUtils.h"
1818
#include "triton/Tools/LinearLayout.h"
1919
#include "triton/Tools/Sys/GetEnv.hpp"
20+
#include "llvm/ADT/SmallSet.h"
2021

2122
namespace mlir {
2223

@@ -253,9 +254,14 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() {
253254
return elementSizeInBytes * getScratchSizeInElems();
254255
}
255256

257+
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
258+
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
259+
std::vector<std::vector<int32_t>> &regBases,
260+
int bitwidth);
261+
256262
DecomposedWarpConversion
257263
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
258-
RankedTensorType dstTy) {
264+
RankedTensorType dstTy, int bitwidth) {
259265
// Two layouts, ll_src and ll_dst, representing the same tensor can be
260266
// viewed as surjections of GF(2) vector spaces:
261267
//
@@ -284,11 +290,12 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
284290
// subsequences of consecutive lane bits from cycles involving both bit types.
285291
// Further explanation of this method is below.
286292
//
287-
// The decomposition is performed in two stages. First, we compute the
293+
// The decomposition is performed in three stages. First, we compute the
288294
// permutation matrix `P` by using `invertAndCompose` to generate a skeleton
289295
// and then fill in any zero columns. Second, we walk the cycles of `P` to
290296
// factor out mixed transpositions to build `mixedTranspositions`, `pReg`, and
291-
// `pLane`.
297+
// `pLane`. Finally, we determine any selectors needed for byte permute
298+
// instructions in place of `selp` instructions when packing registers.
292299

293300
// We remove any broadcasting in the register dimensions of the layouts before
294301
// forming the permutation `P` as the components of the decomposition directly
@@ -342,19 +349,14 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
342349
T = padWithZeros(T);
343350
}
344351

345-
// Flatten outs for ease of building `P`, and reorder outs as flattening
346-
// depends on output dimension order.
347-
if (outDimNames != llvm::to_vector(T.getOutDimNames()))
348-
T = T.transposeOuts(outDimNames);
349-
S = S.flattenOuts();
350-
T = T.flattenOuts();
351-
352352
// We compute T^transpose \circ S, which serves as a skeleton for `P`, then
353353
// fill in zero columns, prioritizing producing fixed points. As we only need
354354
// the basis vectors of `P`, we never actually produce the LinearLayout.
355355
auto pBases = S.invertAndCompose(T).getBases();
356356

357357
// Find the common and uncommon zeros of S and T
358+
S = S.flattenOuts();
359+
T = T.flattenOuts();
358360
SmallVector<std::pair<int32_t, int32_t>> srcFreeZeros;
359361
SmallVector<std::pair<int32_t, int32_t>> dstFreeZeros;
360362
for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) {
@@ -467,11 +469,234 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
467469
}
468470
assert(visited.all() && "Cycle walk incomplete");
469471

472+
auto processedTranspos =
473+
getTranspositionSelectors(mixedTranspositions, regBases, bitwidth);
474+
470475
auto pReg = LinearLayout(std::move(pRegBases), {{kReg, 1 << nRegBases}},
471476
/*requireSurjective=*/true);
472477
auto pLane = LinearLayout(std::move(pLaneBases), {{kLane, 1 << nLaneBases}},
473478
/*requireSurjective=*/true);
474-
return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)};
479+
return {std::move(pReg), std::move(pLane), std::move(processedTranspos)};
480+
}
481+
482+
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
483+
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
484+
std::vector<std::vector<int32_t>> &regBases,
485+
int bitwidth) {
486+
// When possible, we fuse permutations of 'low' register bits together
487+
// with a mixed transposition, resulting in byte permute instructions instead
488+
// of `select` instructions. After processing, no low register bits appear in
489+
// the returned list of mixed transpositions.
490+
int m = mixedTranspositions.size();
491+
int nRegBases = regBases.size();
492+
int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4));
493+
int nPack = std::min(nPackPrelim, nRegBases - m);
494+
495+
SmallVector<DecomposedWarpConversion::TranspositionInfo> ret;
496+
ret.reserve(mixedTranspositions.size());
497+
if (nPack == 0) {
498+
for (auto &t : mixedTranspositions)
499+
ret.push_back(DecomposedWarpConversion::TranspositionInfo{t});
500+
return ret;
501+
}
502+
// Consider for example the cycle
503+
//
504+
// (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3)
505+
// = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2)
506+
//
507+
// with `nPack` = 2 so that r0 and r1 are considered low bits. We want to
508+
// factor out any low bits from `pReg` and to incorporate them into the data
509+
// of the mixed transposition. After processing, the contribution to `pReg`
510+
// is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with
511+
// the effects of (r3 r0) and (r3 r1) encoded in the returned selectors.
512+
// In general, low bits occurring immediately before l_j modify the selectors
513+
// of the `prmt` before the shuffle, while low bits occurring immediately
514+
// after l_k modify the selectors of the `prmt` after the shuffle. Unmodified
515+
// selectors correspond to `select` instructions.
516+
// Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is
517+
// not used in another mixed transposition and conjugating out a low bit:
518+
//
519+
// (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1)
520+
// = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1).
521+
//
522+
// Conjugation does not affect `pReg`. However, the set of fused mixed and
523+
// low-bit transpositions is noncommutative in cases where there are no
524+
// intervening high bits in between distinct sequences of lane bits as the
525+
// paired low bit is used in modifying the selectors of both factors:
526+
//
527+
// (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0).
528+
//
529+
// The `*` is standard composition of permutations. The groupings correspond
530+
// to different `TranspositionInfo` objects. For example, the permutation
531+
// `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with
532+
// pre- and post-shuffle selectors determined by the `r0` bit.
533+
// Processing of mixed transpositions is performed by determining the `head`
534+
// and `tail` of an excision of bits in cycles of `pReg` and building lists
535+
// of low bits acting as selector modifiers. In the noncommutative cases, we
536+
// opt to restrict the number of post-shuffle modifiers to one.
537+
538+
auto permuteSelector = [nPack](uint16_t sel, int bitIdx) {
539+
int lo = bitIdx + (2 - nPack);
540+
uint16_t maskHi = 0x4444;
541+
uint16_t maskLo = 0x1111 << lo;
542+
uint16_t fixed = sel & ~maskHi & ~maskLo;
543+
int shift = 2 - lo;
544+
return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift);
545+
};
546+
auto generateSelectors = [&](int head, int tail, auto &&lowBits) {
547+
uint16_t topSel = 0x3210;
548+
uint16_t botSel = 0x7654;
549+
for (auto lowBit : lowBits) {
550+
topSel = permuteSelector(topSel, lowBit);
551+
botSel = permuteSelector(botSel, lowBit);
552+
if (lowBit != head && lowBit != tail)
553+
regBases[lowBit][0] = 1 << lowBit;
554+
}
555+
return std::pair{topSel, botSel};
556+
};
557+
558+
llvm::SmallSet<int32_t, 6> pairedRegBits;
559+
for (auto [rBit, lBit] : mixedTranspositions)
560+
pairedRegBits.insert(rBit);
561+
562+
// A low bit in a mixed transposition must be replaced by a high bit. The
563+
// choice of high bit can affect instruction count. If the first high bit
564+
// found when walking along `pReg` is unpaired, then that bit is the best
565+
// choice. We reorder the transpositions to guarantee this during processing.
566+
auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); };
567+
auto nextHighFree = [&](auto p) {
568+
int curr = p.first;
569+
do {
570+
if (curr >= nPack)
571+
return curr == p.first || !pairedRegBits.contains(curr);
572+
curr = next(curr);
573+
} while (curr != p.first);
574+
return false;
575+
};
576+
std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(),
577+
nextHighFree);
578+
// If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low
579+
// bit to an open high bit, then the high bit should be used as the partner.
580+
auto prev = [&](int b) {
581+
int tail = b;
582+
int curr = next(b);
583+
while (curr != b) {
584+
tail = curr;
585+
curr = next(curr);
586+
}
587+
return tail;
588+
};
589+
auto findPartner = [&](int lowBit, auto &preShufLoBits) {
590+
if (nPack == 2) {
591+
int otherLow = 1 - lowBit;
592+
int b = next(otherLow);
593+
if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) &&
594+
!pairedRegBits.contains(otherLow)) {
595+
preShufLoBits.push_back(otherLow);
596+
regBases[prev(otherLow)][0] = 1 << b;
597+
pairedRegBits.insert(b);
598+
return b;
599+
}
600+
}
601+
int potentialPartner = nPack;
602+
while (pairedRegBits.contains(potentialPartner))
603+
++potentialPartner;
604+
pairedRegBits.insert(potentialPartner);
605+
return potentialPartner;
606+
};
607+
608+
for (auto p : mixedTranspositions) {
609+
int rBit = p.first;
610+
int lBit = p.second;
611+
SmallVector<int> cycle;
612+
int currBit = rBit;
613+
do {
614+
cycle.push_back(currBit);
615+
currBit = next(currBit);
616+
} while (currBit != rBit);
617+
618+
// Find any low register bits adjacent to the excised lane bits which aren't
619+
// used in other mixed transpositions.
620+
auto isBoundary = [&](int bit) {
621+
return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit);
622+
};
623+
auto forwardEnd = llvm::find_if(cycle, isBoundary);
624+
auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary);
625+
SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd);
626+
SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd);
627+
int head;
628+
int tail;
629+
int partnerBit = -1;
630+
631+
// Case work to determine what to conjugate out.
632+
if (forwardEnd != cycle.end()) {
633+
if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) {
634+
// End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2)
635+
// No conjugation needed.
636+
head = partnerBit = *forwardEnd;
637+
} else {
638+
// End at different paired bit. E.g. (l0 r0 r1 l1 r2)
639+
// Non-leading factor in a noncommutative case.
640+
// Conjugate by first low bit in forward walk.
641+
head = postShufLoBits.front();
642+
preShufLoBits.push_back(head);
643+
postShufLoBits.resize(1);
644+
pairedRegBits.erase(head);
645+
}
646+
tail = *backwardEnd;
647+
if (tail < nPack && pairedRegBits.contains(tail)) {
648+
// Non-terminal factor in a noncommutative case.
649+
preShufLoBits.insert(preShufLoBits.begin(), tail);
650+
}
651+
} else {
652+
if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) {
653+
// Symmetric noncommutative case. E.g. (l0 r0 l1 r1)
654+
preShufLoBits.erase(preShufLoBits.begin());
655+
postShufLoBits.pop_back();
656+
pairedRegBits.erase(postShufLoBits.front());
657+
head = rBit;
658+
tail = next(rBit);
659+
} else {
660+
// Isolated low bits with single mixed transposition. E.g. (l0 r0 r1)
661+
if (postShufLoBits.size() == 2)
662+
postShufLoBits.pop_back();
663+
head = tail = preShufLoBits.front();
664+
}
665+
}
666+
667+
if (partnerBit < 0)
668+
partnerBit = findPartner(head, preShufLoBits);
669+
auto [topPostSel, botPostSel] =
670+
generateSelectors(head, tail, llvm::reverse(postShufLoBits));
671+
auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits);
672+
regBases[tail][0] = 1 << head;
673+
674+
DecomposedWarpConversion::TranspositionInfo info;
675+
info.transposition = {partnerBit, lBit};
676+
info.topPreSel = topPreSel;
677+
info.botPreSel = botPreSel;
678+
info.topPostSel = topPostSel;
679+
info.botPostSel = botPostSel;
680+
681+
// In noncommutative cases, post-shuffle selectors of non-leading terms come
682+
// from a single low bit by design, so we can determine where to insert a
683+
// non-terminal factor by examining processed selectors.
684+
if (!preShufLoBits.empty()) {
685+
uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410;
686+
auto it =
687+
llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; });
688+
ret.insert(it, info);
689+
} else {
690+
ret.push_back(info);
691+
}
692+
}
693+
if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) {
694+
// If (r0 r1) was originally in `P`, fold it into a mixed transposition.
695+
auto &t = ret.back();
696+
t.topPostSel = 0x3120;
697+
t.botPostSel = 0x7564;
698+
}
699+
return ret;
475700
}
476701

477702
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>>
@@ -769,7 +994,7 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
769994
auto kLane = StringAttr::get(ctx, "lane");
770995
if (to_vector(layout.getOutDimNames()) ==
771996
SmallVector<StringAttr, 2>{kRegister, kLane}) {
772-
auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy);
997+
auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, 32);
773998
return (factors.mixedTranspositions.size() < 2);
774999
}
7751000
return false;

0 commit comments

Comments
 (0)