|
17 | 17 | #include "triton/Tools/LayoutUtils.h" |
18 | 18 | #include "triton/Tools/LinearLayout.h" |
19 | 19 | #include "triton/Tools/Sys/GetEnv.hpp" |
| 20 | +#include "llvm/ADT/SmallSet.h" |
20 | 21 |
|
21 | 22 | namespace mlir { |
22 | 23 |
|
@@ -253,9 +254,14 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() { |
253 | 254 | return elementSizeInBytes * getScratchSizeInElems(); |
254 | 255 | } |
255 | 256 |
|
| 257 | +static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
| 258 | +getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
| 259 | + std::vector<std::vector<int32_t>> &regBases, |
| 260 | + int bitwidth); |
| 261 | + |
256 | 262 | DecomposedWarpConversion |
257 | 263 | getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
258 | | - RankedTensorType dstTy) { |
| 264 | + RankedTensorType dstTy, int bitwidth) { |
259 | 265 | // Two layouts, ll_src and ll_dst, representing the same tensor can be |
260 | 266 | // viewed as surjections of GF(2) vector spaces: |
261 | 267 | // |
@@ -284,11 +290,12 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
284 | 290 | // subsequences of consecutive lane bits from cycles involving both bit types. |
285 | 291 | // Further explanation of this method is below. |
286 | 292 | // |
287 | | - // The decomposition is performed in two stages. First, we compute the |
| 293 | + // The decomposition is performed in three stages. First, we compute the |
288 | 294 | // permutation matrix `P` by using `invertAndCompose` to generate a skeleton |
289 | 295 | // and then fill in any zero columns. Second, we walk the cycles of `P` to |
290 | 296 | // factor out mixed transpositions to build `mixedTranspositions`, `pReg`, and |
291 | | - // `pLane`. |
| 297 | + // `pLane`. Finally, we determine any selectors needed for byte permute |
| 298 | + // instructions in place of `selp` instructions when packing registers. |
292 | 299 |
|
293 | 300 | // We remove any broadcasting in the register dimensions of the layouts before |
294 | 301 | // forming the permutation `P` as the components of the decomposition directly |
@@ -342,19 +349,14 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
342 | 349 | T = padWithZeros(T); |
343 | 350 | } |
344 | 351 |
|
345 | | - // Flatten outs for ease of building `P`, and reorder outs as flattening |
346 | | - // depends on output dimension order. |
347 | | - if (outDimNames != llvm::to_vector(T.getOutDimNames())) |
348 | | - T = T.transposeOuts(outDimNames); |
349 | | - S = S.flattenOuts(); |
350 | | - T = T.flattenOuts(); |
351 | | - |
352 | 352 | // We compute T^transpose \circ S, which serves as a skeleton for `P`, then |
353 | 353 | // fill in zero columns, prioritizing producing fixed points. As we only need |
354 | 354 | // the basis vectors of `P`, we never actually produce the LinearLayout. |
355 | 355 | auto pBases = S.invertAndCompose(T).getBases(); |
356 | 356 |
|
357 | 357 | // Find the common and uncommon zeros of S and T |
| 358 | + S = S.flattenOuts(); |
| 359 | + T = T.flattenOuts(); |
358 | 360 | SmallVector<std::pair<int32_t, int32_t>> srcFreeZeros; |
359 | 361 | SmallVector<std::pair<int32_t, int32_t>> dstFreeZeros; |
360 | 362 | for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) { |
@@ -467,11 +469,234 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
467 | 469 | } |
468 | 470 | assert(visited.all() && "Cycle walk incomplete"); |
469 | 471 |
|
| 472 | + auto processedTranspos = |
| 473 | + getTranspositionSelectors(mixedTranspositions, regBases, bitwidth); |
| 474 | + |
470 | 475 | auto pReg = LinearLayout(std::move(pRegBases), {{kReg, 1 << nRegBases}}, |
471 | 476 | /*requireSurjective=*/true); |
472 | 477 | auto pLane = LinearLayout(std::move(pLaneBases), {{kLane, 1 << nLaneBases}}, |
473 | 478 | /*requireSurjective=*/true); |
474 | | - return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)}; |
| 479 | + return {std::move(pReg), std::move(pLane), std::move(processedTranspos)}; |
| 480 | +} |
| 481 | + |
| 482 | +static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
| 483 | +getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
| 484 | + std::vector<std::vector<int32_t>> &regBases, |
| 485 | + int bitwidth) { |
| 486 | + // When possible, we fuse permutations of 'low' register bits together |
| 487 | + // with a mixed transposition, resulting in byte permute instructions instead |
| 488 | + // of `select` instructions. After processing, no low register bits appear in |
| 489 | + // the returned list of mixed transpositions. |
| 490 | + int m = mixedTranspositions.size(); |
| 491 | + int nRegBases = regBases.size(); |
| 492 | + int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4)); |
| 493 | + int nPack = std::min(nPackPrelim, nRegBases - m); |
| 494 | + |
| 495 | + SmallVector<DecomposedWarpConversion::TranspositionInfo> ret; |
| 496 | + ret.reserve(mixedTranspositions.size()); |
| 497 | + if (nPack == 0) { |
| 498 | + for (auto &t : mixedTranspositions) |
| 499 | + ret.push_back(DecomposedWarpConversion::TranspositionInfo{t}); |
| 500 | + return ret; |
| 501 | + } |
| 502 | + // Consider for example the cycle |
| 503 | + // |
| 504 | + // (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3) |
| 505 | + // = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2) |
| 506 | + // |
| 507 | + // with `nPack` = 2 so that r0 and r1 are considered low bits. We want to |
| 508 | + // factor out any low bits from `pReg` and to incorporate them into the data |
| 509 | + // of the mixed transposition. After processing, the contribution to `pReg` |
| 510 | + // is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with |
| 511 | + // the effects of (r3 r0) and (r3 r1) encoded in the returned selectors. |
| 512 | + // In general, low bits occurring immediately before l_j modify the selectors |
| 513 | + // of the `prmt` before the shuffle, while low bits occurring immediately |
| 514 | + // after l_k modify the selectors of the `prmt` after the shuffle. Unmodified |
| 515 | + // selectors correspond to `select` instructions. |
| 516 | + // Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is |
| 517 | + // not used in another mixed transposition and conjugating out a low bit: |
| 518 | + // |
| 519 | + // (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1) |
| 520 | + // = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1). |
| 521 | + // |
| 522 | + // Conjugation does not affect `pReg`. However, the set of fused mixed and |
| 523 | + // low-bit transpositions is noncommutative in cases where there are no |
| 524 | + // intervening high bits in between distinct sequences of lane bits as the |
| 525 | + // paired low bit is used in modifying the selectors of both factors: |
| 526 | + // |
| 527 | + // (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0). |
| 528 | + // |
| 529 | + // The `*` is standard composition of permutations. The groupings correspond |
| 530 | + // to different `TranspositionInfo` objects. For example, the permutation |
| 531 | + // `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with |
| 532 | + // pre- and post-shuffle selectors determined by the `r0` bit. |
| 533 | + // Processing of mixed transpositions is performed by determining the `head` |
| 534 | + // and `tail` of an excision of bits in cycles of `pReg` and building lists |
| 535 | + // of low bits acting as selector modifiers. In the noncommutative cases, we |
| 536 | + // opt to restrict the number of post-shuffle modifiers to one. |
| 537 | + |
| 538 | + auto permuteSelector = [nPack](uint16_t sel, int bitIdx) { |
| 539 | + int lo = bitIdx + (2 - nPack); |
| 540 | + uint16_t maskHi = 0x4444; |
| 541 | + uint16_t maskLo = 0x1111 << lo; |
| 542 | + uint16_t fixed = sel & ~maskHi & ~maskLo; |
| 543 | + int shift = 2 - lo; |
| 544 | + return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift); |
| 545 | + }; |
| 546 | + auto generateSelectors = [&](int head, int tail, auto &&lowBits) { |
| 547 | + uint16_t topSel = 0x3210; |
| 548 | + uint16_t botSel = 0x7654; |
| 549 | + for (auto lowBit : lowBits) { |
| 550 | + topSel = permuteSelector(topSel, lowBit); |
| 551 | + botSel = permuteSelector(botSel, lowBit); |
| 552 | + if (lowBit != head && lowBit != tail) |
| 553 | + regBases[lowBit][0] = 1 << lowBit; |
| 554 | + } |
| 555 | + return std::pair{topSel, botSel}; |
| 556 | + }; |
| 557 | + |
| 558 | + llvm::SmallSet<int32_t, 6> pairedRegBits; |
| 559 | + for (auto [rBit, lBit] : mixedTranspositions) |
| 560 | + pairedRegBits.insert(rBit); |
| 561 | + |
| 562 | + // A low bit in a mixed transposition must be replaced by a high bit. The |
| 563 | + // choice of high bit can affect instruction count. If the first high bit |
| 564 | + // found when walking along `pReg` is unpaired, then that bit is the best |
| 565 | + // choice. We reorder the transpositions to guarantee this during processing. |
| 566 | + auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); }; |
| 567 | + auto nextHighFree = [&](auto p) { |
| 568 | + int curr = p.first; |
| 569 | + do { |
| 570 | + if (curr >= nPack) |
| 571 | + return curr == p.first || !pairedRegBits.contains(curr); |
| 572 | + curr = next(curr); |
| 573 | + } while (curr != p.first); |
| 574 | + return false; |
| 575 | + }; |
| 576 | + std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(), |
| 577 | + nextHighFree); |
| 578 | + // If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low |
| 579 | + // bit to an open high bit, then the high bit should be used as the partner. |
| 580 | + auto prev = [&](int b) { |
| 581 | + int tail = b; |
| 582 | + int curr = next(b); |
| 583 | + while (curr != b) { |
| 584 | + tail = curr; |
| 585 | + curr = next(curr); |
| 586 | + } |
| 587 | + return tail; |
| 588 | + }; |
| 589 | + auto findPartner = [&](int lowBit, auto &preShufLoBits) { |
| 590 | + if (nPack == 2) { |
| 591 | + int otherLow = 1 - lowBit; |
| 592 | + int b = next(otherLow); |
| 593 | + if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) && |
| 594 | + !pairedRegBits.contains(otherLow)) { |
| 595 | + preShufLoBits.push_back(otherLow); |
| 596 | + regBases[prev(otherLow)][0] = 1 << b; |
| 597 | + pairedRegBits.insert(b); |
| 598 | + return b; |
| 599 | + } |
| 600 | + } |
| 601 | + int potentialPartner = nPack; |
| 602 | + while (pairedRegBits.contains(potentialPartner)) |
| 603 | + ++potentialPartner; |
| 604 | + pairedRegBits.insert(potentialPartner); |
| 605 | + return potentialPartner; |
| 606 | + }; |
| 607 | + |
| 608 | + for (auto p : mixedTranspositions) { |
| 609 | + int rBit = p.first; |
| 610 | + int lBit = p.second; |
| 611 | + SmallVector<int> cycle; |
| 612 | + int currBit = rBit; |
| 613 | + do { |
| 614 | + cycle.push_back(currBit); |
| 615 | + currBit = next(currBit); |
| 616 | + } while (currBit != rBit); |
| 617 | + |
| 618 | + // Find any low register bits adjacent to the excised lane bits which aren't |
| 619 | + // used in other mixed transpositions. |
| 620 | + auto isBoundary = [&](int bit) { |
| 621 | + return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit); |
| 622 | + }; |
| 623 | + auto forwardEnd = llvm::find_if(cycle, isBoundary); |
| 624 | + auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary); |
| 625 | + SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd); |
| 626 | + SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd); |
| 627 | + int head; |
| 628 | + int tail; |
| 629 | + int partnerBit = -1; |
| 630 | + |
| 631 | + // Case work to determine what to conjugate out. |
| 632 | + if (forwardEnd != cycle.end()) { |
| 633 | + if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) { |
| 634 | + // End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2) |
| 635 | + // No conjugation needed. |
| 636 | + head = partnerBit = *forwardEnd; |
| 637 | + } else { |
| 638 | + // End at different paired bit. E.g. (l0 r0 r1 l1 r2) |
| 639 | + // Non-leading factor in a noncommutative case. |
| 640 | + // Conjugate by first low bit in forward walk. |
| 641 | + head = postShufLoBits.front(); |
| 642 | + preShufLoBits.push_back(head); |
| 643 | + postShufLoBits.resize(1); |
| 644 | + pairedRegBits.erase(head); |
| 645 | + } |
| 646 | + tail = *backwardEnd; |
| 647 | + if (tail < nPack && pairedRegBits.contains(tail)) { |
| 648 | + // Non-terminal factor in a noncommutative case. |
| 649 | + preShufLoBits.insert(preShufLoBits.begin(), tail); |
| 650 | + } |
| 651 | + } else { |
| 652 | + if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) { |
| 653 | + // Symmetric noncommutative case. E.g. (l0 r0 l1 r1) |
| 654 | + preShufLoBits.erase(preShufLoBits.begin()); |
| 655 | + postShufLoBits.pop_back(); |
| 656 | + pairedRegBits.erase(postShufLoBits.front()); |
| 657 | + head = rBit; |
| 658 | + tail = next(rBit); |
| 659 | + } else { |
| 660 | + // Isolated low bits with single mixed transposition. E.g. (l0 r0 r1) |
| 661 | + if (postShufLoBits.size() == 2) |
| 662 | + postShufLoBits.pop_back(); |
| 663 | + head = tail = preShufLoBits.front(); |
| 664 | + } |
| 665 | + } |
| 666 | + |
| 667 | + if (partnerBit < 0) |
| 668 | + partnerBit = findPartner(head, preShufLoBits); |
| 669 | + auto [topPostSel, botPostSel] = |
| 670 | + generateSelectors(head, tail, llvm::reverse(postShufLoBits)); |
| 671 | + auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits); |
| 672 | + regBases[tail][0] = 1 << head; |
| 673 | + |
| 674 | + DecomposedWarpConversion::TranspositionInfo info; |
| 675 | + info.transposition = {partnerBit, lBit}; |
| 676 | + info.topPreSel = topPreSel; |
| 677 | + info.botPreSel = botPreSel; |
| 678 | + info.topPostSel = topPostSel; |
| 679 | + info.botPostSel = botPostSel; |
| 680 | + |
| 681 | + // In noncommutative cases, post-shuffle selectors of non-leading terms come |
| 682 | + // from a single low bit by design, so we can determine where to insert a |
| 683 | + // non-terminal factor by examining processed selectors. |
| 684 | + if (!preShufLoBits.empty()) { |
| 685 | + uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410; |
| 686 | + auto it = |
| 687 | + llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; }); |
| 688 | + ret.insert(it, info); |
| 689 | + } else { |
| 690 | + ret.push_back(info); |
| 691 | + } |
| 692 | + } |
| 693 | + if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) { |
| 694 | + // If (r0 r1) was originally in `P`, fold it into a mixed transposition. |
| 695 | + auto &t = ret.back(); |
| 696 | + t.topPostSel = 0x3120; |
| 697 | + t.botPostSel = 0x7564; |
| 698 | + } |
| 699 | + return ret; |
475 | 700 | } |
476 | 701 |
|
477 | 702 | SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> |
@@ -769,7 +994,7 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { |
769 | 994 | auto kLane = StringAttr::get(ctx, "lane"); |
770 | 995 | if (to_vector(layout.getOutDimNames()) == |
771 | 996 | SmallVector<StringAttr, 2>{kRegister, kLane}) { |
772 | | - auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy); |
| 997 | + auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, 32); |
773 | 998 | return (factors.mixedTranspositions.size() < 2); |
774 | 999 | } |
775 | 1000 | return false; |
|
0 commit comments