|
17 | 17 | #include "triton/Tools/LayoutUtils.h" |
18 | 18 | #include "triton/Tools/LinearLayout.h" |
19 | 19 | #include "triton/Tools/Sys/GetEnv.hpp" |
20 | | -#include "llvm/ADT/SmallSet.h" |
21 | 20 |
|
22 | 21 | namespace mlir { |
23 | 22 |
|
@@ -254,11 +253,6 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() { |
254 | 253 | return elementSizeInBytes * getScratchSizeInElems(); |
255 | 254 | } |
256 | 255 |
|
257 | | -static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
258 | | -getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
259 | | - std::vector<std::vector<int32_t>> &regBases, |
260 | | - int bitwidth); |
261 | | - |
262 | 256 | DecomposedWarpConversion |
263 | 257 | getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
264 | 258 | RankedTensorType dstTy) { |
@@ -480,226 +474,6 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
480 | 474 | return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)}; |
481 | 475 | } |
482 | 476 |
|
483 | | -static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
484 | | -getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
485 | | - std::vector<std::vector<int32_t>> &regBases, |
486 | | - int bitwidth) { |
487 | | - // When possible, we fuse permutations of 'low' register bits together |
488 | | - // with a mixed transposition, resulting in byte permute instructions instead |
489 | | - // of `select` instructions. After processing, no low register bits appear in |
490 | | - // the returned list of mixed transpositions. |
491 | | - int m = mixedTranspositions.size(); |
492 | | - int nRegBases = regBases.size(); |
493 | | - int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4)); |
494 | | - int nPack = std::min(nPackPrelim, nRegBases - m); |
495 | | - |
496 | | - SmallVector<DecomposedWarpConversion::TranspositionInfo> ret; |
497 | | - ret.reserve(mixedTranspositions.size()); |
498 | | - if (nPack == 0) { |
499 | | - for (auto &t : mixedTranspositions) |
500 | | - ret.push_back(DecomposedWarpConversion::TranspositionInfo{t}); |
501 | | - return ret; |
502 | | - } |
503 | | - // Consider for example the cycle |
504 | | - // |
505 | | - // (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3) |
506 | | - // = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2) |
507 | | - // |
508 | | - // with `nPack` = 2 so that r0 and r1 are considered low bits. We want to |
509 | | - // factor out any low bits from `pReg` and to incorporate them into the data |
510 | | - // of the mixed transposition. After processing, the contribution to `pReg` |
511 | | - // is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with |
512 | | - // the effects of (r3 r0) and (r3 r1) encoded in the returned selectors. |
513 | | - // In general, low bits occurring immediately before l_j modify the selectors |
514 | | - // of the `prmt` before the shuffle, while low bits occurring immediately |
515 | | - // after l_k modify the selectors of the `prmt` after the shuffle. Unmodified |
516 | | - // selectors correspond to `select` instructions. |
517 | | - // Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is |
518 | | - // not used in another mixed transposition and conjugating out a low bit: |
519 | | - // |
520 | | - // (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1) |
521 | | - // = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1). |
522 | | - // |
523 | | - // Conjugation does not affect `pReg`. However, the set of fused mixed and |
524 | | - // low-bit transpositions is noncommutative in cases where there are no |
525 | | - // intervening high bits in between distinct sequences of lane bits as the |
526 | | - // paired low bit is used in modifying the selectors of both factors: |
527 | | - // |
528 | | - // (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0). |
529 | | - // |
530 | | - // The `*` is standard composition of permutations. The groupings correspond |
531 | | - // to different `TranspositionInfo` objects. For example, the permutation |
532 | | - // `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with |
533 | | - // pre- and post-shuffle selectors determined by the `r0` bit. |
534 | | - // Processing of mixed transpositions is performed by determining the `head` |
535 | | - // and `tail` of an excision of bits in cycles of `pReg` and building lists |
536 | | - // of low bits acting as selector modifiers. In the noncommutative cases, we |
537 | | - // opt to restrict the number of post-shuffle modifiers to one. |
538 | | - |
539 | | - auto permuteSelector = [nPack](uint16_t sel, int bitIdx) { |
540 | | - int lo = bitIdx + (2 - nPack); |
541 | | - uint16_t maskHi = 0x4444; |
542 | | - uint16_t maskLo = 0x1111 << lo; |
543 | | - uint16_t fixed = sel & ~maskHi & ~maskLo; |
544 | | - int shift = 2 - lo; |
545 | | - return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift); |
546 | | - }; |
547 | | - auto generateSelectors = [&](int head, int tail, auto &&lowBits) { |
548 | | - uint16_t topSel = 0x3210; |
549 | | - uint16_t botSel = 0x7654; |
550 | | - for (auto lowBit : lowBits) { |
551 | | - topSel = permuteSelector(topSel, lowBit); |
552 | | - botSel = permuteSelector(botSel, lowBit); |
553 | | - if (lowBit != head && lowBit != tail) |
554 | | - regBases[lowBit][0] = 1 << lowBit; |
555 | | - } |
556 | | - return std::pair{topSel, botSel}; |
557 | | - }; |
558 | | - |
559 | | - llvm::SmallSet<int32_t, 6> pairedRegBits; |
560 | | - for (auto [rBit, lBit] : mixedTranspositions) |
561 | | - pairedRegBits.insert(rBit); |
562 | | - |
563 | | - // A low bit in a mixed transposition must be replaced by a high bit. The |
564 | | - // choice of high bit can affect instruction count. If the first high bit |
565 | | - // found when walking along `pReg` is unpaired, then that bit is the best |
566 | | - // choice. We reorder the transpositions to guarantee this during processing. |
567 | | - auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); }; |
568 | | - auto nextHighFree = [&](auto p) { |
569 | | - int curr = p.first; |
570 | | - do { |
571 | | - if (curr >= nPack) |
572 | | - return curr == p.first || !pairedRegBits.contains(curr); |
573 | | - curr = next(curr); |
574 | | - } while (curr != p.first); |
575 | | - return false; |
576 | | - }; |
577 | | - std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(), |
578 | | - nextHighFree); |
579 | | - // If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low |
580 | | - // bit to an open high bit, then the high bit should be used as the partner. |
581 | | - auto prev = [&](int b) { |
582 | | - int tail = b; |
583 | | - int curr = next(b); |
584 | | - while (curr != b) { |
585 | | - tail = curr; |
586 | | - curr = next(curr); |
587 | | - } |
588 | | - return tail; |
589 | | - }; |
590 | | - auto findPartner = [&](int lowBit, auto &preShufLoBits) { |
591 | | - if (nPack == 2) { |
592 | | - int otherLow = 1 - lowBit; |
593 | | - int b = next(otherLow); |
594 | | - if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) && |
595 | | - !pairedRegBits.contains(otherLow)) { |
596 | | - preShufLoBits.push_back(otherLow); |
597 | | - regBases[prev(otherLow)][0] = 1 << b; |
598 | | - pairedRegBits.insert(b); |
599 | | - return b; |
600 | | - } |
601 | | - } |
602 | | - int potentialPartner = nPack; |
603 | | - while (pairedRegBits.contains(potentialPartner)) |
604 | | - ++potentialPartner; |
605 | | - pairedRegBits.insert(potentialPartner); |
606 | | - return potentialPartner; |
607 | | - }; |
608 | | - |
609 | | - for (auto p : mixedTranspositions) { |
610 | | - int rBit = p.first; |
611 | | - int lBit = p.second; |
612 | | - SmallVector<int> cycle; |
613 | | - int currBit = rBit; |
614 | | - do { |
615 | | - cycle.push_back(currBit); |
616 | | - currBit = next(currBit); |
617 | | - } while (currBit != rBit); |
618 | | - |
619 | | - // Find any low register bits adjacent to the excised lane bits which aren't |
620 | | - // used in other mixed transpositions. |
621 | | - auto isBoundary = [&](int bit) { |
622 | | - return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit); |
623 | | - }; |
624 | | - auto forwardEnd = llvm::find_if(cycle, isBoundary); |
625 | | - auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary); |
626 | | - SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd); |
627 | | - SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd); |
628 | | - int head; |
629 | | - int tail; |
630 | | - int partnerBit = -1; |
631 | | - |
632 | | - // Case work to determine what to conjugate out. |
633 | | - if (forwardEnd != cycle.end()) { |
634 | | - if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) { |
635 | | - // End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2) |
636 | | - // No conjugation needed. |
637 | | - head = partnerBit = *forwardEnd; |
638 | | - } else { |
639 | | - // End at different paired bit. E.g. (l0 r0 r1 l1 r2) |
640 | | - // Non-leading factor in a noncommutative case. |
641 | | - // Conjugate by first low bit in forward walk. |
642 | | - head = postShufLoBits.front(); |
643 | | - preShufLoBits.push_back(head); |
644 | | - postShufLoBits.resize(1); |
645 | | - pairedRegBits.erase(head); |
646 | | - } |
647 | | - tail = *backwardEnd; |
648 | | - if (tail < nPack && pairedRegBits.contains(tail)) { |
649 | | - // Non-terminal factor in a noncommutative case. |
650 | | - preShufLoBits.insert(preShufLoBits.begin(), tail); |
651 | | - } |
652 | | - } else { |
653 | | - if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) { |
654 | | - // Symmetric noncommutative case. E.g. (l0 r0 l1 r1) |
655 | | - preShufLoBits.erase(preShufLoBits.begin()); |
656 | | - postShufLoBits.pop_back(); |
657 | | - pairedRegBits.erase(postShufLoBits.front()); |
658 | | - head = rBit; |
659 | | - tail = next(rBit); |
660 | | - } else { |
661 | | - // Isolated low bits with single mixed transposition. E.g. (l0 r0 r1) |
662 | | - if (postShufLoBits.size() == 2) |
663 | | - postShufLoBits.pop_back(); |
664 | | - head = tail = preShufLoBits.front(); |
665 | | - } |
666 | | - } |
667 | | - |
668 | | - if (partnerBit < 0) |
669 | | - partnerBit = findPartner(head, preShufLoBits); |
670 | | - auto [topPostSel, botPostSel] = |
671 | | - generateSelectors(head, tail, llvm::reverse(postShufLoBits)); |
672 | | - auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits); |
673 | | - regBases[tail][0] = 1 << head; |
674 | | - |
675 | | - DecomposedWarpConversion::TranspositionInfo info; |
676 | | - info.transposition = {partnerBit, lBit}; |
677 | | - info.topPreSel = topPreSel; |
678 | | - info.botPreSel = botPreSel; |
679 | | - info.topPostSel = topPostSel; |
680 | | - info.botPostSel = botPostSel; |
681 | | - |
682 | | - // In noncommutative cases, post-shuffle selectors of non-leading terms come |
683 | | - // from a single low bit by design, so we can determine where to insert a |
684 | | - // non-terminal factor by examining processed selectors. |
685 | | - if (!preShufLoBits.empty()) { |
686 | | - uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410; |
687 | | - auto it = |
688 | | - llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; }); |
689 | | - ret.insert(it, info); |
690 | | - } else { |
691 | | - ret.push_back(info); |
692 | | - } |
693 | | - } |
694 | | - if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) { |
695 | | - // If (r0 r1) was originally in `P`, fold it into a mixed transposition. |
696 | | - auto &t = ret.back(); |
697 | | - t.topPostSel = 0x3120; |
698 | | - t.botPostSel = 0x7564; |
699 | | - } |
700 | | - return ret; |
701 | | -} |
702 | | - |
703 | 477 | SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> |
704 | 478 | getReshapeDecomposition(ArrayRef<int64_t> srcShape, |
705 | 479 | ArrayRef<int64_t> dstShape) { |
|
0 commit comments