Skip to content

Commit 37866dc

Browse files
Merge OpenAI Triton commit 83fbc0e (#5037)
This PR change the Triton base from 2a86177 to 83fbc0e (Aug 18). Pass rate: 98.85%
2 parents 6895cf5 + 3fb537e commit 37866dc

File tree

12 files changed

+3
-291
lines changed

12 files changed

+3
-291
lines changed

include/triton/Analysis/Utility.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -181,17 +181,8 @@ class GatherLoweringHelper {
181181
// `pReg` and `pLane` are square layouts each with only one input and output
182182
// dimension. `mixedTranspositions` holds pairs of integers (i, j)
183183
// corresponding to the transposition (r_i l_j) of the i-th register basis
184-
// vector with the j-th lane basis vector along with 16-bit selectors for byte
185-
// permute instructions (where each of the four nybbles is in the range [0, 7]).
184+
// vector with the j-th lane basis vector.
186185
struct DecomposedWarpConversion {
187-
struct TranspositionInfo {
188-
std::pair<int, int> transposition;
189-
uint16_t topPreSel = 0x3210;
190-
uint16_t botPreSel = 0x7654;
191-
uint16_t topPostSel = 0x3210;
192-
uint16_t botPostSel = 0x7654;
193-
};
194-
195186
triton::LinearLayout pReg, pLane;
196187
SmallVector<std::pair<int, int>> mixedTranspositions;
197188
};

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@ class TargetInfoBase {
5555
virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
5656
Value i) const = 0;
5757

58-
virtual Value permute(RewriterBase &rewriter, Location loc, Value a, Value b,
59-
Value selector) const = 0;
60-
6158
virtual Value programId(RewriterBase &rewriter, Location loc,
6259
ModuleOp moduleOp, ProgramIDDim axis) const = 0;
6360

lib/Analysis/Utility.cpp

Lines changed: 0 additions & 226 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "triton/Tools/LayoutUtils.h"
1818
#include "triton/Tools/LinearLayout.h"
1919
#include "triton/Tools/Sys/GetEnv.hpp"
20-
#include "llvm/ADT/SmallSet.h"
2120

2221
namespace mlir {
2322

@@ -254,11 +253,6 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() {
254253
return elementSizeInBytes * getScratchSizeInElems();
255254
}
256255

257-
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
258-
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
259-
std::vector<std::vector<int32_t>> &regBases,
260-
int bitwidth);
261-
262256
DecomposedWarpConversion
263257
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
264258
RankedTensorType dstTy) {
@@ -480,226 +474,6 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
480474
return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)};
481475
}
482476

483-
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
484-
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
485-
std::vector<std::vector<int32_t>> &regBases,
486-
int bitwidth) {
487-
// When possible, we fuse permutations of 'low' register bits together
488-
// with a mixed transposition, resulting in byte permute instructions instead
489-
// of `select` instructions. After processing, no low register bits appear in
490-
// the returned list of mixed transpositions.
491-
int m = mixedTranspositions.size();
492-
int nRegBases = regBases.size();
493-
int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4));
494-
int nPack = std::min(nPackPrelim, nRegBases - m);
495-
496-
SmallVector<DecomposedWarpConversion::TranspositionInfo> ret;
497-
ret.reserve(mixedTranspositions.size());
498-
if (nPack == 0) {
499-
for (auto &t : mixedTranspositions)
500-
ret.push_back(DecomposedWarpConversion::TranspositionInfo{t});
501-
return ret;
502-
}
503-
// Consider for example the cycle
504-
//
505-
// (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3)
506-
// = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2)
507-
//
508-
// with `nPack` = 2 so that r0 and r1 are considered low bits. We want to
509-
// factor out any low bits from `pReg` and to incorporate them into the data
510-
// of the mixed transposition. After processing, the contribution to `pReg`
511-
// is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with
512-
// the effects of (r3 r0) and (r3 r1) encoded in the returned selectors.
513-
// In general, low bits occurring immediately before l_j modify the selectors
514-
// of the `prmt` before the shuffle, while low bits occurring immediately
515-
// after l_k modify the selectors of the `prmt` after the shuffle. Unmodified
516-
// selectors correspond to `select` instructions.
517-
// Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is
518-
// not used in another mixed transposition and conjugating out a low bit:
519-
//
520-
// (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1)
521-
// = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1).
522-
//
523-
// Conjugation does not affect `pReg`. However, the set of fused mixed and
524-
// low-bit transpositions is noncommutative in cases where there are no
525-
// intervening high bits in between distinct sequences of lane bits as the
526-
// paired low bit is used in modifying the selectors of both factors:
527-
//
528-
// (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0).
529-
//
530-
// The `*` is standard composition of permutations. The groupings correspond
531-
// to different `TranspositionInfo` objects. For example, the permutation
532-
// `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with
533-
// pre- and post-shuffle selectors determined by the `r0` bit.
534-
// Processing of mixed transpositions is performed by determining the `head`
535-
// and `tail` of an excision of bits in cycles of `pReg` and building lists
536-
// of low bits acting as selector modifiers. In the noncommutative cases, we
537-
// opt to restrict the number of post-shuffle modifiers to one.
538-
539-
auto permuteSelector = [nPack](uint16_t sel, int bitIdx) {
540-
int lo = bitIdx + (2 - nPack);
541-
uint16_t maskHi = 0x4444;
542-
uint16_t maskLo = 0x1111 << lo;
543-
uint16_t fixed = sel & ~maskHi & ~maskLo;
544-
int shift = 2 - lo;
545-
return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift);
546-
};
547-
auto generateSelectors = [&](int head, int tail, auto &&lowBits) {
548-
uint16_t topSel = 0x3210;
549-
uint16_t botSel = 0x7654;
550-
for (auto lowBit : lowBits) {
551-
topSel = permuteSelector(topSel, lowBit);
552-
botSel = permuteSelector(botSel, lowBit);
553-
if (lowBit != head && lowBit != tail)
554-
regBases[lowBit][0] = 1 << lowBit;
555-
}
556-
return std::pair{topSel, botSel};
557-
};
558-
559-
llvm::SmallSet<int32_t, 6> pairedRegBits;
560-
for (auto [rBit, lBit] : mixedTranspositions)
561-
pairedRegBits.insert(rBit);
562-
563-
// A low bit in a mixed transposition must be replaced by a high bit. The
564-
// choice of high bit can affect instruction count. If the first high bit
565-
// found when walking along `pReg` is unpaired, then that bit is the best
566-
// choice. We reorder the transpositions to guarantee this during processing.
567-
auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); };
568-
auto nextHighFree = [&](auto p) {
569-
int curr = p.first;
570-
do {
571-
if (curr >= nPack)
572-
return curr == p.first || !pairedRegBits.contains(curr);
573-
curr = next(curr);
574-
} while (curr != p.first);
575-
return false;
576-
};
577-
std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(),
578-
nextHighFree);
579-
// If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low
580-
// bit to an open high bit, then the high bit should be used as the partner.
581-
auto prev = [&](int b) {
582-
int tail = b;
583-
int curr = next(b);
584-
while (curr != b) {
585-
tail = curr;
586-
curr = next(curr);
587-
}
588-
return tail;
589-
};
590-
auto findPartner = [&](int lowBit, auto &preShufLoBits) {
591-
if (nPack == 2) {
592-
int otherLow = 1 - lowBit;
593-
int b = next(otherLow);
594-
if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) &&
595-
!pairedRegBits.contains(otherLow)) {
596-
preShufLoBits.push_back(otherLow);
597-
regBases[prev(otherLow)][0] = 1 << b;
598-
pairedRegBits.insert(b);
599-
return b;
600-
}
601-
}
602-
int potentialPartner = nPack;
603-
while (pairedRegBits.contains(potentialPartner))
604-
++potentialPartner;
605-
pairedRegBits.insert(potentialPartner);
606-
return potentialPartner;
607-
};
608-
609-
for (auto p : mixedTranspositions) {
610-
int rBit = p.first;
611-
int lBit = p.second;
612-
SmallVector<int> cycle;
613-
int currBit = rBit;
614-
do {
615-
cycle.push_back(currBit);
616-
currBit = next(currBit);
617-
} while (currBit != rBit);
618-
619-
// Find any low register bits adjacent to the excised lane bits which aren't
620-
// used in other mixed transpositions.
621-
auto isBoundary = [&](int bit) {
622-
return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit);
623-
};
624-
auto forwardEnd = llvm::find_if(cycle, isBoundary);
625-
auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary);
626-
SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd);
627-
SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd);
628-
int head;
629-
int tail;
630-
int partnerBit = -1;
631-
632-
// Case work to determine what to conjugate out.
633-
if (forwardEnd != cycle.end()) {
634-
if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) {
635-
// End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2)
636-
// No conjugation needed.
637-
head = partnerBit = *forwardEnd;
638-
} else {
639-
// End at different paired bit. E.g. (l0 r0 r1 l1 r2)
640-
// Non-leading factor in a noncommutative case.
641-
// Conjugate by first low bit in forward walk.
642-
head = postShufLoBits.front();
643-
preShufLoBits.push_back(head);
644-
postShufLoBits.resize(1);
645-
pairedRegBits.erase(head);
646-
}
647-
tail = *backwardEnd;
648-
if (tail < nPack && pairedRegBits.contains(tail)) {
649-
// Non-terminal factor in a noncommutative case.
650-
preShufLoBits.insert(preShufLoBits.begin(), tail);
651-
}
652-
} else {
653-
if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) {
654-
// Symmetric noncommutative case. E.g. (l0 r0 l1 r1)
655-
preShufLoBits.erase(preShufLoBits.begin());
656-
postShufLoBits.pop_back();
657-
pairedRegBits.erase(postShufLoBits.front());
658-
head = rBit;
659-
tail = next(rBit);
660-
} else {
661-
// Isolated low bits with single mixed transposition. E.g. (l0 r0 r1)
662-
if (postShufLoBits.size() == 2)
663-
postShufLoBits.pop_back();
664-
head = tail = preShufLoBits.front();
665-
}
666-
}
667-
668-
if (partnerBit < 0)
669-
partnerBit = findPartner(head, preShufLoBits);
670-
auto [topPostSel, botPostSel] =
671-
generateSelectors(head, tail, llvm::reverse(postShufLoBits));
672-
auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits);
673-
regBases[tail][0] = 1 << head;
674-
675-
DecomposedWarpConversion::TranspositionInfo info;
676-
info.transposition = {partnerBit, lBit};
677-
info.topPreSel = topPreSel;
678-
info.botPreSel = botPreSel;
679-
info.topPostSel = topPostSel;
680-
info.botPostSel = botPostSel;
681-
682-
// In noncommutative cases, post-shuffle selectors of non-leading terms come
683-
// from a single low bit by design, so we can determine where to insert a
684-
// non-terminal factor by examining processed selectors.
685-
if (!preShufLoBits.empty()) {
686-
uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410;
687-
auto it =
688-
llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; });
689-
ret.insert(it, info);
690-
} else {
691-
ret.push_back(info);
692-
}
693-
}
694-
if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) {
695-
// If (r0 r1) was originally in `P`, fold it into a mixed transposition.
696-
auto &t = ret.back();
697-
t.topPostSel = 0x3120;
698-
t.botPostSel = 0x7564;
699-
}
700-
return ret;
701-
}
702-
703477
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>>
704478
getReshapeDecomposition(ArrayRef<int64_t> srcShape,
705479
ArrayRef<int64_t> dstShape) {

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,6 @@ Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
148148
return LLVM::AMD::shuffleIdx(loc, rewriter, val, i, getISAFamily());
149149
}
150150

151-
Value TargetInfo::permute(RewriterBase &rewriter, Location loc, Value a,
152-
Value b, Value selector) const {
153-
// Warning: The `a` and `b` operands are ordered to align with Nvidia's `prmt`
154-
// Both use little-endian ordering, but AMD puts the MSBs of the data in the
155-
// 0-th operand.
156-
return LLVM::AMD::permute(loc, rewriter, b, a, selector);
157-
}
158-
159151
Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
160152
ModuleOp moduleOp, ProgramIDDim axis) const {
161153
return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis);

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
4646
Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
4747
Value i) const override;
4848

49-
Value permute(RewriterBase &rewriter, Location loc, Value a, Value b,
50-
Value selector) const override;
51-
5249
Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp,
5350
ProgramIDDim axis) const override;
5451

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -284,23 +284,6 @@ Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i,
284284
b.i32_val(0x1f));
285285
}
286286

287-
Value permute(Location loc, RewriterBase &rewriter, Value x, Value y,
288-
Value selector) {
289-
auto b = TritonLLVMOpBuilder(loc, rewriter);
290-
Value prmt_mask = selector;
291-
// convert from nybble mask to byte mask:
292-
prmt_mask =
293-
b.or_(b.and_(prmt_mask, b.i32_val(0x000000ff)),
294-
b.shl(b.and_(prmt_mask, b.i32_val(0x0000ff00)), b.i32_val(8)));
295-
prmt_mask =
296-
b.or_(b.and_(prmt_mask, b.i32_val(0x000f000f)),
297-
b.shl(b.and_(prmt_mask, b.i32_val(0x00f000f0)), b.i32_val(4)));
298-
Value args[] = {x, y, prmt_mask};
299-
auto op = createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.perm", i32_ty,
300-
args);
301-
return op.getResult(0);
302-
}
303-
304287
Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
305288
ProgramIDDim axis) {
306289
Value blockId =

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@ Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i,
3737
mlir::triton::AMD::ISAFamily isaFamily =
3838
mlir::triton::AMD::ISAFamily::Unknown);
3939

40-
Value permute(Location loc, RewriterBase &rewriter, Value a, Value b,
41-
Value selector);
42-
4340
Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
4441
ProgramIDDim axis);
4542

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,6 @@ Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
9191
return LLVM::intel::shuffleIdx(loc, rewriter, val, i);
9292
}
9393

94-
Value TargetInfo::permute(RewriterBase &rewriter, Location loc, Value a,
95-
Value b, Value selector) const {
96-
// Warning: The `a` and `b` operands are ordered to align with Nvidia's `prmt`
97-
// Both use little-endian ordering, but AMD puts the MSBs of the data in the
98-
// 0-th operand.
99-
return LLVM::intel::permute(loc, rewriter, b, a, selector);
100-
}
101-
10294
Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
10395
ModuleOp moduleOp, ProgramIDDim axis) const {
10496
Value blockId =

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
4444
Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
4545
Value i) const override;
4646

47-
Value permute(RewriterBase &rewriter, Location loc, Value a, Value b,
48-
Value selector) const override;
49-
5047
Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp,
5148
ProgramIDDim axis) const override;
5249

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -445,11 +445,6 @@ Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
445445
return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i);
446446
}
447447

448-
Value TargetInfo::permute(RewriterBase &rewriter, Location loc, Value a,
449-
Value b, Value selector) const {
450-
return LLVM::NVIDIA::permute(loc, rewriter, a, b, selector);
451-
}
452-
453448
Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
454449
ModuleOp moduleOp, ProgramIDDim axis) const {
455450
return LLVM::NVIDIA::llGetPid(loc, rewriter, moduleOp, axis);

0 commit comments

Comments
 (0)