|
17 | 17 | #include "triton/Tools/LayoutUtils.h" |
18 | 18 | #include "triton/Tools/LinearLayout.h" |
19 | 19 | #include "triton/Tools/Sys/GetEnv.hpp" |
20 | | -#include "llvm/ADT/SmallSet.h" |
21 | 20 |
|
22 | 21 | namespace mlir { |
23 | 22 |
|
@@ -254,14 +253,9 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() { |
254 | 253 | return elementSizeInBytes * getScratchSizeInElems(); |
255 | 254 | } |
256 | 255 |
|
257 | | -static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
258 | | -getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
259 | | - std::vector<std::vector<int32_t>> &regBases, |
260 | | - int bitwidth); |
261 | | - |
262 | 256 | DecomposedWarpConversion |
263 | 257 | getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
264 | | - RankedTensorType dstTy, int bitwidth) { |
| 258 | + RankedTensorType dstTy) { |
265 | 259 | // Two layouts, ll_src and ll_dst, representing the same tensor can be |
266 | 260 | // viewed as surjections of GF(2) vector spaces: |
267 | 261 | // |
@@ -290,12 +284,11 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
290 | 284 | // subsequences of consecutive lane bits from cycles involving both bit types. |
291 | 285 | // Further explanation of this method is below. |
292 | 286 | // |
293 | | - // The decomposition is performed in three stages. First, we compute the |
| 287 | + // The decomposition is performed in two stages. First, we compute the |
294 | 288 | // permutation matrix `P` by using `invertAndCompose` to generate a skeleton |
295 | 289 | // and then fill in any zero columns. Second, we walk the cycles of `P` to |
296 | 290 | // factor out mixed transpositions to build `mixedTranspositions`, `pReg`, and |
297 | | - // `pLane`. Finally, we determine any selectors needed for byte permute |
298 | | - // instructions in place of `selp` instructions when packing registers. |
| 291 | + // `pLane`. |
299 | 292 |
|
300 | 293 | // We remove any broadcasting in the register dimensions of the layouts before |
301 | 294 | // forming the permutation `P` as the components of the decomposition directly |
@@ -349,14 +342,19 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
349 | 342 | T = padWithZeros(T); |
350 | 343 | } |
351 | 344 |
|
| 345 | + // Flatten outs for ease of building `P`, and reorder outs as flattening |
| 346 | + // depends on output dimension order. |
| 347 | + if (outDimNames != llvm::to_vector(T.getOutDimNames())) |
| 348 | + T = T.transposeOuts(outDimNames); |
| 349 | + S = S.flattenOuts(); |
| 350 | + T = T.flattenOuts(); |
| 351 | + |
352 | 352 | // We compute T^transpose \circ S, which serves as a skeleton for `P`, then |
353 | 353 | // fill in zero columns, prioritizing producing fixed points. As we only need |
354 | 354 | // the basis vectors of `P`, we never actually produce the LinearLayout. |
355 | 355 | auto pBases = S.invertAndCompose(T).getBases(); |
356 | 356 |
|
357 | 357 | // Find the common and uncommon zeros of S and T |
358 | | - S = S.flattenOuts(); |
359 | | - T = T.flattenOuts(); |
360 | 358 | SmallVector<std::pair<int32_t, int32_t>> srcFreeZeros; |
361 | 359 | SmallVector<std::pair<int32_t, int32_t>> dstFreeZeros; |
362 | 360 | for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) { |
@@ -469,234 +467,11 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
469 | 467 | } |
470 | 468 | assert(visited.all() && "Cycle walk incomplete"); |
471 | 469 |
|
472 | | - auto processedTranspos = |
473 | | - getTranspositionSelectors(mixedTranspositions, regBases, bitwidth); |
474 | | - |
475 | 470 | auto pReg = LinearLayout(std::move(pRegBases), {{kReg, 1 << nRegBases}}, |
476 | 471 | /*requireSurjective=*/true); |
477 | 472 | auto pLane = LinearLayout(std::move(pLaneBases), {{kLane, 1 << nLaneBases}}, |
478 | 473 | /*requireSurjective=*/true); |
479 | | - return {std::move(pReg), std::move(pLane), std::move(processedTranspos)}; |
480 | | -} |
481 | | - |
482 | | -static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
483 | | -getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
484 | | - std::vector<std::vector<int32_t>> &regBases, |
485 | | - int bitwidth) { |
486 | | - // When possible, we fuse permutations of 'low' register bits together |
487 | | - // with a mixed transposition, resulting in byte permute instructions instead |
488 | | - // of `select` instructions. After processing, no low register bits appear in |
489 | | - // the returned list of mixed transpositions. |
490 | | - int m = mixedTranspositions.size(); |
491 | | - int nRegBases = regBases.size(); |
492 | | - int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4)); |
493 | | - int nPack = std::min(nPackPrelim, nRegBases - m); |
494 | | - |
495 | | - SmallVector<DecomposedWarpConversion::TranspositionInfo> ret; |
496 | | - ret.reserve(mixedTranspositions.size()); |
497 | | - if (nPack == 0) { |
498 | | - for (auto &t : mixedTranspositions) |
499 | | - ret.push_back(DecomposedWarpConversion::TranspositionInfo{t}); |
500 | | - return ret; |
501 | | - } |
502 | | - // Consider for example the cycle |
503 | | - // |
504 | | - // (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3) |
505 | | - // = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2) |
506 | | - // |
507 | | - // with `nPack` = 2 so that r0 and r1 are considered low bits. We want to |
508 | | - // factor out any low bits from `pReg` and to incorporate them into the data |
509 | | - // of the mixed transposition. After processing, the contribution to `pReg` |
510 | | - // is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with |
511 | | - // the effects of (r3 r0) and (r3 r1) encoded in the returned selectors. |
512 | | - // In general, low bits occurring immediately before l_j modify the selectors |
513 | | - // of the `prmt` before the shuffle, while low bits occurring immediately |
514 | | - // after l_k modify the selectors of the `prmt` after the shuffle. Unmodified |
515 | | - // selectors correspond to `select` instructions. |
516 | | - // Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is |
517 | | - // not used in another mixed transposition and conjugating out a low bit: |
518 | | - // |
519 | | - // (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1) |
520 | | - // = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1). |
521 | | - // |
522 | | - // Conjugation does not affect `pReg`. However, the set of fused mixed and |
523 | | - // low-bit transpositions is noncommutative in cases where there are no |
524 | | - // intervening high bits in between distinct sequences of lane bits as the |
525 | | - // paired low bit is used in modifying the selectors of both factors: |
526 | | - // |
527 | | - // (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0). |
528 | | - // |
529 | | - // The `*` is standard composition of permutations. The groupings correspond |
530 | | - // to different `TranspositionInfo` objects. For example, the permutation |
531 | | - // `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with |
532 | | - // pre- and post-shuffle selectors determined by the `r0` bit. |
533 | | - // Processing of mixed transpositions is performed by determining the `head` |
534 | | - // and `tail` of an excision of bits in cycles of `pReg` and building lists |
535 | | - // of low bits acting as selector modifiers. In the noncommutative cases, we |
536 | | - // opt to restrict the number of post-shuffle modifiers to one. |
537 | | - |
538 | | - auto permuteSelector = [nPack](uint16_t sel, int bitIdx) { |
539 | | - int lo = bitIdx + (2 - nPack); |
540 | | - uint16_t maskHi = 0x4444; |
541 | | - uint16_t maskLo = 0x1111 << lo; |
542 | | - uint16_t fixed = sel & ~maskHi & ~maskLo; |
543 | | - int shift = 2 - lo; |
544 | | - return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift); |
545 | | - }; |
546 | | - auto generateSelectors = [&](int head, int tail, auto &&lowBits) { |
547 | | - uint16_t topSel = 0x3210; |
548 | | - uint16_t botSel = 0x7654; |
549 | | - for (auto lowBit : lowBits) { |
550 | | - topSel = permuteSelector(topSel, lowBit); |
551 | | - botSel = permuteSelector(botSel, lowBit); |
552 | | - if (lowBit != head && lowBit != tail) |
553 | | - regBases[lowBit][0] = 1 << lowBit; |
554 | | - } |
555 | | - return std::pair{topSel, botSel}; |
556 | | - }; |
557 | | - |
558 | | - llvm::SmallSet<int32_t, 6> pairedRegBits; |
559 | | - for (auto [rBit, lBit] : mixedTranspositions) |
560 | | - pairedRegBits.insert(rBit); |
561 | | - |
562 | | - // A low bit in a mixed transposition must be replaced by a high bit. The |
563 | | - // choice of high bit can affect instruction count. If the first high bit |
564 | | - // found when walking along `pReg` is unpaired, then that bit is the best |
565 | | - // choice. We reorder the transpositions to guarantee this during processing. |
566 | | - auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); }; |
567 | | - auto nextHighFree = [&](auto p) { |
568 | | - int curr = p.first; |
569 | | - do { |
570 | | - if (curr >= nPack) |
571 | | - return curr == p.first || !pairedRegBits.contains(curr); |
572 | | - curr = next(curr); |
573 | | - } while (curr != p.first); |
574 | | - return false; |
575 | | - }; |
576 | | - std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(), |
577 | | - nextHighFree); |
578 | | - // If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low |
579 | | - // bit to an open high bit, then the high bit should be used as the partner. |
580 | | - auto prev = [&](int b) { |
581 | | - int tail = b; |
582 | | - int curr = next(b); |
583 | | - while (curr != b) { |
584 | | - tail = curr; |
585 | | - curr = next(curr); |
586 | | - } |
587 | | - return tail; |
588 | | - }; |
589 | | - auto findPartner = [&](int lowBit, auto &preShufLoBits) { |
590 | | - if (nPack == 2) { |
591 | | - int otherLow = 1 - lowBit; |
592 | | - int b = next(otherLow); |
593 | | - if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) && |
594 | | - !pairedRegBits.contains(otherLow)) { |
595 | | - preShufLoBits.push_back(otherLow); |
596 | | - regBases[prev(otherLow)][0] = 1 << b; |
597 | | - pairedRegBits.insert(b); |
598 | | - return b; |
599 | | - } |
600 | | - } |
601 | | - int potentialPartner = nPack; |
602 | | - while (pairedRegBits.contains(potentialPartner)) |
603 | | - ++potentialPartner; |
604 | | - pairedRegBits.insert(potentialPartner); |
605 | | - return potentialPartner; |
606 | | - }; |
607 | | - |
608 | | - for (auto p : mixedTranspositions) { |
609 | | - int rBit = p.first; |
610 | | - int lBit = p.second; |
611 | | - SmallVector<int> cycle; |
612 | | - int currBit = rBit; |
613 | | - do { |
614 | | - cycle.push_back(currBit); |
615 | | - currBit = next(currBit); |
616 | | - } while (currBit != rBit); |
617 | | - |
618 | | - // Find any low register bits adjacent to the excised lane bits which aren't |
619 | | - // used in other mixed transpositions. |
620 | | - auto isBoundary = [&](int bit) { |
621 | | - return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit); |
622 | | - }; |
623 | | - auto forwardEnd = llvm::find_if(cycle, isBoundary); |
624 | | - auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary); |
625 | | - SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd); |
626 | | - SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd); |
627 | | - int head; |
628 | | - int tail; |
629 | | - int partnerBit = -1; |
630 | | - |
631 | | - // Case work to determine what to conjugate out. |
632 | | - if (forwardEnd != cycle.end()) { |
633 | | - if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) { |
634 | | - // End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2) |
635 | | - // No conjugation needed. |
636 | | - head = partnerBit = *forwardEnd; |
637 | | - } else { |
638 | | - // End at different paired bit. E.g. (l0 r0 r1 l1 r2) |
639 | | - // Non-leading factor in a noncommutative case. |
640 | | - // Conjugate by first low bit in forward walk. |
641 | | - head = postShufLoBits.front(); |
642 | | - preShufLoBits.push_back(head); |
643 | | - postShufLoBits.resize(1); |
644 | | - pairedRegBits.erase(head); |
645 | | - } |
646 | | - tail = *backwardEnd; |
647 | | - if (tail < nPack && pairedRegBits.contains(tail)) { |
648 | | - // Non-terminal factor in a noncommutative case. |
649 | | - preShufLoBits.insert(preShufLoBits.begin(), tail); |
650 | | - } |
651 | | - } else { |
652 | | - if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) { |
653 | | - // Symmetric noncommutative case. E.g. (l0 r0 l1 r1) |
654 | | - preShufLoBits.erase(preShufLoBits.begin()); |
655 | | - postShufLoBits.pop_back(); |
656 | | - pairedRegBits.erase(postShufLoBits.front()); |
657 | | - head = rBit; |
658 | | - tail = next(rBit); |
659 | | - } else { |
660 | | - // Isolated low bits with single mixed transposition. E.g. (l0 r0 r1) |
661 | | - if (postShufLoBits.size() == 2) |
662 | | - postShufLoBits.pop_back(); |
663 | | - head = tail = preShufLoBits.front(); |
664 | | - } |
665 | | - } |
666 | | - |
667 | | - if (partnerBit < 0) |
668 | | - partnerBit = findPartner(head, preShufLoBits); |
669 | | - auto [topPostSel, botPostSel] = |
670 | | - generateSelectors(head, tail, llvm::reverse(postShufLoBits)); |
671 | | - auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits); |
672 | | - regBases[tail][0] = 1 << head; |
673 | | - |
674 | | - DecomposedWarpConversion::TranspositionInfo info; |
675 | | - info.transposition = {partnerBit, lBit}; |
676 | | - info.topPreSel = topPreSel; |
677 | | - info.botPreSel = botPreSel; |
678 | | - info.topPostSel = topPostSel; |
679 | | - info.botPostSel = botPostSel; |
680 | | - |
681 | | - // In noncommutative cases, post-shuffle selectors of non-leading terms come |
682 | | - // from a single low bit by design, so we can determine where to insert a |
683 | | - // non-terminal factor by examining processed selectors. |
684 | | - if (!preShufLoBits.empty()) { |
685 | | - uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410; |
686 | | - auto it = |
687 | | - llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; }); |
688 | | - ret.insert(it, info); |
689 | | - } else { |
690 | | - ret.push_back(info); |
691 | | - } |
692 | | - } |
693 | | - if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) { |
694 | | - // If (r0 r1) was originally in `P`, fold it into a mixed transposition. |
695 | | - auto &t = ret.back(); |
696 | | - t.topPostSel = 0x3120; |
697 | | - t.botPostSel = 0x7564; |
698 | | - } |
699 | | - return ret; |
| 474 | + return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)}; |
700 | 475 | } |
701 | 476 |
|
702 | 477 | SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> |
@@ -994,7 +769,7 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { |
994 | 769 | auto kLane = StringAttr::get(ctx, "lane"); |
995 | 770 | if (to_vector(layout.getOutDimNames()) == |
996 | 771 | SmallVector<StringAttr, 2>{kRegister, kLane}) { |
997 | | - auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, 32); |
| 772 | + auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy); |
998 | 773 | return (factors.mixedTranspositions.size() < 2); |
999 | 774 | } |
1000 | 775 | return false; |
|
0 commit comments