Skip to content

Commit 83fbc0e

Browse files
authored
Revert byte permutes in intra-warp layout conversion (#7899)
@FrederickVu, sorry, but I have to revert these 3 PRs: triton-lang/triton#7809, triton-lang/triton#7825, and triton-lang/triton#7861. There is a functional regression caused by triton-lang/triton#7809, but the other two PRs depend heavily on it, so I was not able to revert it cleanly on its own, and I couldn't manage to do a partial revert either. The following convert layout miscompiles with that PR applied:
```
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}>
%4 = ttg.convert_layout %3 : tensor<16x8xf8E5M2, #mma> -> tensor<16x8xf8E5M2, #blocked2>
```
It can be reproduced on an Ampere, Hopper, or Blackwell GPU (I would expect any NVIDIA GPU to show the problem). I can try to get a reproducer that I can share later, but it would be nice to make a unit test for this convert layout in Gluon anyway. Happy to land those back when the bug is fixed, or, if you manage to partially revert only the NVIDIA permute part, that works too.
1 parent 2a86177 commit 83fbc0e

File tree

16 files changed

+249
-831
lines changed

16 files changed

+249
-831
lines changed

.github/workflows/integration-tests-amd.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,7 @@ jobs:
120120
pytest --capture=tee-sys -rfs -n 8 python/test/gluon/
121121
122122
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
123-
pytest --capture=tee-sys -rfs -n 8 third_party/amd/python/test/ \
124-
--ignore=third_party/amd/python/test/test_scalarize_packed_fops.py \
125-
--ignore=third_party/amd/python/test/test_address_sanitizer.py
123+
pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
126124
TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py
127125
cd python/test/unit
128126
pytest --capture=tee-sys -rfs -n 12 \

include/triton/Analysis/Utility.h

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -181,22 +181,10 @@ class GatherLoweringHelper {
181181
// `pReg` and `pLane` are square layouts each with only one input and output
182182
// dimension. `mixedTranspositions` holds pairs of integers (i, j)
183183
// corresponding to the transposition (r_i l_j) of the i-th register basis
184-
// vector with the j-th lane basis vector along with 16-bit selectors for byte
185-
// permute instructions (where each of the four nybbles is in the range [0, 7]).
186-
// `nPack` gives the number of basis vectors that can be used for register
187-
// packing while ensuring packed elements arrive at the same destination lane.
184+
// vector with the j-th lane basis vector.
188185
struct DecomposedWarpConversion {
189-
struct TranspositionInfo {
190-
std::pair<int, int> transposition;
191-
uint16_t topPreSel = 0x3210;
192-
uint16_t botPreSel = 0x7654;
193-
uint16_t topPostSel = 0x3210;
194-
uint16_t botPostSel = 0x7654;
195-
};
196-
197186
triton::LinearLayout pReg, pLane;
198-
SmallVector<TranspositionInfo> mixedTranspositions;
199-
int nPack;
187+
SmallVector<std::pair<int, int>> mixedTranspositions;
200188
};
201189

202190
// Produces a decomposition of a permutation describing a warp-local layout
@@ -208,7 +196,7 @@ struct DecomposedWarpConversion {
208196
// represented as a permutation.
209197
DecomposedWarpConversion
210198
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
211-
RankedTensorType dstTy, int bitwidth);
199+
RankedTensorType dstTy);
212200

213201
// Decomposes a reshape into simpler pieces.
214202
//

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ class TargetInfoBase {
5555
virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
5656
Value i) const = 0;
5757

58-
virtual Value permute(RewriterBase &rewriter, Location loc, Value a, Value b,
59-
Value selector) const = 0;
60-
6158
virtual Value programId(RewriterBase &rewriter, Location loc,
6259
ModuleOp moduleOp, ProgramIDDim axis) const = 0;
6360

lib/Analysis/Utility.cpp

Lines changed: 12 additions & 238 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "triton/Tools/LayoutUtils.h"
1717
#include "triton/Tools/LinearLayout.h"
1818
#include "triton/Tools/Sys/GetEnv.hpp"
19-
#include "llvm/ADT/SmallSet.h"
2019

2120
namespace mlir {
2221

@@ -248,14 +247,9 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() {
248247
return elementSizeInBytes * getScratchSizeInElems();
249248
}
250249

251-
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
252-
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
253-
std::vector<std::vector<int32_t>> &regBases,
254-
int bitwidth);
255-
256250
DecomposedWarpConversion
257251
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
258-
RankedTensorType dstTy, int bitwidth) {
252+
RankedTensorType dstTy) {
259253
// Two layouts, ll_src and ll_dst, representing the same tensor can be
260254
// viewed as surjections of GF(2) vector spaces:
261255
//
@@ -284,12 +278,11 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
284278
// subsequences of consecutive lane bits from cycles involving both bit types.
285279
// Further explanation of this method is below.
286280
//
287-
// The decomposition is performed in three stages. First, we compute the
281+
// The decomposition is performed in two stages. First, we compute the
288282
// permutation matrix `P` by using `invertAndCompose` to generate a skeleton
289283
// and then fill in any zero columns. Second, we walk the cycles of `P` to
290284
// factor out mixed transpositions to build `mixedTranspositions`, `pReg`, and
291-
// `pLane`. Finally, we determine any selectors needed for byte permute
292-
// instructions in place of `selp` instructions when packing registers.
285+
// `pLane`.
293286

294287
// We remove any broadcasting in the register dimensions of the layouts before
295288
// forming the permutation `P` as the components of the decomposition directly
@@ -343,14 +336,19 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
343336
T = padWithZeros(T);
344337
}
345338

339+
// Flatten outs for ease of building `P`, and reorder outs as flattening
340+
// depends on output dimension order.
341+
if (outDimNames != llvm::to_vector(T.getOutDimNames()))
342+
T = T.transposeOuts(outDimNames);
343+
S = S.flattenOuts();
344+
T = T.flattenOuts();
345+
346346
// We compute T^transpose \circ S, which serves as a skeleton for `P`, then
347347
// fill in zero columns, prioritizing producing fixed points. As we only need
348348
// the basis vectors of `P`, we never actually produce the LinearLayout.
349349
auto pBases = S.invertAndCompose(T).getBases();
350350

351351
// Find the common and uncommon zeros of S and T
352-
S = S.flattenOuts();
353-
T = T.flattenOuts();
354352
SmallVector<std::pair<int32_t, int32_t>> srcFreeZeros;
355353
SmallVector<std::pair<int32_t, int32_t>> dstFreeZeros;
356354
for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) {
@@ -463,235 +461,11 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
463461
}
464462
assert(visited.all() && "Cycle walk incomplete");
465463

466-
// Determine degree of packing and selectors.
467-
int m = mixedTranspositions.size();
468-
int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4));
469-
int nPack = std::min(nPackPrelim, nRegBases - m);
470-
auto processedTranspos =
471-
getTranspositionSelectors(mixedTranspositions, regBases, nPack);
472-
473464
auto pReg = LinearLayout(std::move(pRegBases), {{kReg, 1 << nRegBases}},
474465
/*requireSurjective=*/true);
475466
auto pLane = LinearLayout(std::move(pLaneBases), {{kLane, 1 << nLaneBases}},
476467
/*requireSurjective=*/true);
477-
return {std::move(pReg), std::move(pLane), std::move(processedTranspos),
478-
nPack};
479-
}
480-
481-
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
482-
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
483-
std::vector<std::vector<int32_t>> &regBases,
484-
int nPack) {
485-
// When possible, we fuse permutations of 'low' register bits together
486-
// with a mixed transposition, resulting in byte permute instructions instead
487-
// of `select` instructions. After processing, no low register bits appear in
488-
// the returned list of mixed transpositions.
489-
490-
SmallVector<DecomposedWarpConversion::TranspositionInfo> ret;
491-
ret.reserve(mixedTranspositions.size());
492-
if (nPack == 0) {
493-
for (auto &t : mixedTranspositions)
494-
ret.push_back(DecomposedWarpConversion::TranspositionInfo{t});
495-
return ret;
496-
}
497-
// Consider for example the cycle
498-
//
499-
// (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3)
500-
// = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2)
501-
//
502-
// with `nPack` = 2 so that r0 and r1 are considered low bits. We want to
503-
// factor out any low bits from `pReg` and to incorporate them into the data
504-
// of the mixed transposition. After processing, the contribution to `pReg`
505-
// is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with
506-
// the effects of (r3 r0) and (r3 r1) encoded in the returned selectors.
507-
// In general, low bits occurring immediately before l_j modify the selectors
508-
// of the `prmt` before the shuffle, while low bits occurring immediately
509-
// after l_k modify the selectors of the `prmt` after the shuffle. Unmodified
510-
// selectors correspond to `select` instructions.
511-
// Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is
512-
// not used in another mixed transposition and conjugating out a low bit:
513-
//
514-
// (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1)
515-
// = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1).
516-
//
517-
// Conjugation does not affect `pReg`. However, the set of fused mixed and
518-
// low-bit transpositions is noncommutative in cases where there are no
519-
// intervening high bits in between distinct sequences of lane bits as the
520-
// paired low bit is used in modifying the selectors of both factors:
521-
//
522-
// (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0).
523-
//
524-
// The `*` is standard composition of permutations. The groupings correspond
525-
// to different `TranspositionInfo` objects. For example, the permutation
526-
// `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with
527-
// pre- and post-shuffle selectors determined by the `r0` bit.
528-
// Processing of mixed transpositions is performed by determining the `head`
529-
// and `tail` of an excision of bits in cycles of `pReg` and building lists
530-
// of low bits acting as selector modifiers. In the noncommutative cases, we
531-
// opt to restrict the number of post-shuffle modifiers to one.
532-
533-
auto permuteSelector = [nPack](uint16_t sel, int bitIdx) {
534-
int lo = bitIdx + (2 - nPack);
535-
uint16_t maskHi = 0x4444;
536-
uint16_t maskLo = 0x1111 << lo;
537-
uint16_t fixed = sel & ~maskHi & ~maskLo;
538-
int shift = 2 - lo;
539-
return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift);
540-
};
541-
auto generateSelectors = [&](int head, int tail, auto &&lowBits) {
542-
uint16_t topSel = 0x3210;
543-
uint16_t botSel = 0x7654;
544-
for (auto lowBit : lowBits) {
545-
topSel = permuteSelector(topSel, lowBit);
546-
botSel = permuteSelector(botSel, lowBit);
547-
if (lowBit != head && lowBit != tail)
548-
regBases[lowBit][0] = 1 << lowBit;
549-
}
550-
return std::pair{topSel, botSel};
551-
};
552-
553-
llvm::SmallSet<int32_t, 6> pairedRegBits;
554-
for (auto [rBit, lBit] : mixedTranspositions)
555-
pairedRegBits.insert(rBit);
556-
557-
// A low bit in a mixed transposition must be replaced by a high bit. The
558-
// choice of high bit can affect instruction count. If the first high bit
559-
// found when walking along `pReg` is unpaired, then that bit is the best
560-
// choice. We reorder the transpositions to guarantee this during processing.
561-
auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); };
562-
auto nextHighFree = [&](auto p) {
563-
int curr = p.first;
564-
do {
565-
if (curr >= nPack)
566-
return curr == p.first || !pairedRegBits.contains(curr);
567-
curr = next(curr);
568-
} while (curr != p.first);
569-
return false;
570-
};
571-
std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(),
572-
nextHighFree);
573-
// If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low
574-
// bit to an open high bit, then the high bit should be used as the partner.
575-
auto prev = [&](int b) {
576-
int tail = b;
577-
int curr = next(b);
578-
while (curr != b) {
579-
tail = curr;
580-
curr = next(curr);
581-
}
582-
return tail;
583-
};
584-
auto findPartner = [&](int lowBit, auto &preShufLoBits) {
585-
if (nPack == 2) {
586-
int otherLow = 1 - lowBit;
587-
int b = next(otherLow);
588-
if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) &&
589-
!pairedRegBits.contains(otherLow)) {
590-
preShufLoBits.push_back(otherLow);
591-
regBases[prev(otherLow)][0] = 1 << b;
592-
pairedRegBits.insert(b);
593-
return b;
594-
}
595-
}
596-
int potentialPartner = nPack;
597-
while (pairedRegBits.contains(potentialPartner))
598-
++potentialPartner;
599-
pairedRegBits.insert(potentialPartner);
600-
return potentialPartner;
601-
};
602-
603-
for (auto p : mixedTranspositions) {
604-
int rBit = p.first;
605-
int lBit = p.second;
606-
SmallVector<int> cycle;
607-
int currBit = rBit;
608-
do {
609-
cycle.push_back(currBit);
610-
currBit = next(currBit);
611-
} while (currBit != rBit);
612-
613-
// Find any low register bits adjacent to the excised lane bits which aren't
614-
// used in other mixed transpositions.
615-
auto isBoundary = [&](int bit) {
616-
return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit);
617-
};
618-
auto forwardEnd = llvm::find_if(cycle, isBoundary);
619-
auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary);
620-
SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd);
621-
SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd);
622-
int head;
623-
int tail;
624-
int partnerBit = -1;
625-
626-
// Case work to determine what to conjugate out.
627-
if (forwardEnd != cycle.end()) {
628-
if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) {
629-
// End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2)
630-
// No conjugation needed.
631-
head = partnerBit = *forwardEnd;
632-
} else {
633-
// End at different paired bit. E.g. (l0 r0 r1 l1 r2)
634-
// Non-leading factor in a noncommutative case.
635-
// Conjugate by first low bit in forward walk.
636-
head = postShufLoBits.front();
637-
preShufLoBits.push_back(head);
638-
postShufLoBits.resize(1);
639-
pairedRegBits.erase(head);
640-
}
641-
tail = *backwardEnd;
642-
if (tail < nPack && pairedRegBits.contains(tail)) {
643-
// Non-terminal factor in a noncommutative case.
644-
preShufLoBits.insert(preShufLoBits.begin(), tail);
645-
}
646-
} else {
647-
if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) {
648-
// Symmetric noncommutative case. E.g. (l0 r0 l1 r1)
649-
preShufLoBits.erase(preShufLoBits.begin());
650-
postShufLoBits.pop_back();
651-
pairedRegBits.erase(postShufLoBits.front());
652-
head = rBit;
653-
tail = next(rBit);
654-
} else {
655-
// Isolated low bits with single mixed transposition. E.g. (l0 r0 r1)
656-
if (postShufLoBits.size() == 2)
657-
postShufLoBits.pop_back();
658-
head = tail = preShufLoBits.front();
659-
}
660-
}
661-
662-
if (partnerBit < 0)
663-
partnerBit = findPartner(head, preShufLoBits);
664-
auto [topPostSel, botPostSel] =
665-
generateSelectors(head, tail, llvm::reverse(postShufLoBits));
666-
auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits);
667-
regBases[tail][0] = 1 << head;
668-
669-
DecomposedWarpConversion::TranspositionInfo info;
670-
info.transposition = {partnerBit, lBit};
671-
info.topPreSel = topPreSel;
672-
info.botPreSel = botPreSel;
673-
info.topPostSel = topPostSel;
674-
info.botPostSel = botPostSel;
675-
676-
// In noncommutative cases, post-shuffle selectors of non-leading terms come
677-
// from a single low bit by design, so we can determine where to insert a
678-
// non-terminal factor by examining processed selectors.
679-
if (!preShufLoBits.empty()) {
680-
uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410;
681-
auto it =
682-
llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; });
683-
ret.insert(it, info);
684-
} else {
685-
ret.push_back(info);
686-
}
687-
}
688-
if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) {
689-
// If (r0 r1) was originally in `P`, fold it into a mixed transposition.
690-
auto &t = ret.back();
691-
t.topPostSel = 0x3120;
692-
t.botPostSel = 0x7564;
693-
}
694-
return ret;
468+
return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)};
695469
}
696470

697471
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>>
@@ -989,7 +763,7 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
989763
auto kLane = StringAttr::get(ctx, "lane");
990764
if (to_vector(layout.getOutDimNames()) ==
991765
SmallVector<StringAttr, 2>{kRegister, kLane}) {
992-
auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, 32);
766+
auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy);
993767
return (factors.mixedTranspositions.size() < 2);
994768
}
995769
return false;

0 commit comments

Comments
 (0)