|
16 | 16 | #include "triton/Tools/LayoutUtils.h" |
17 | 17 | #include "triton/Tools/LinearLayout.h" |
18 | 18 | #include "triton/Tools/Sys/GetEnv.hpp" |
| 19 | +#include "llvm/ADT/SmallSet.h" |
19 | 20 |
|
20 | 21 | namespace mlir { |
21 | 22 |
|
@@ -247,9 +248,14 @@ unsigned ScanLoweringHelper::getScratchSizeInBytes() { |
247 | 248 | return elementSizeInBytes * getScratchSizeInElems(); |
248 | 249 | } |
249 | 250 |
|
| 251 | +static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
| 252 | +getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
| 253 | + std::vector<std::vector<int32_t>> ®Bases, |
| 254 | + int bitwidth); |
| 255 | + |
250 | 256 | DecomposedWarpConversion |
251 | 257 | getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
252 | | - RankedTensorType dstTy) { |
| 258 | + RankedTensorType dstTy, int bitwidth) { |
253 | 259 | // Two layouts, ll_src and ll_dst, representing the same tensor can be |
254 | 260 | // viewed as surjections of GF(2) vector spaces: |
255 | 261 | // |
@@ -278,11 +284,12 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
278 | 284 | // subsequences of consecutive lane bits from cycles involving both bit types. |
279 | 285 | // Further explanation of this method is below. |
280 | 286 | // |
281 | | - // The decomposition is performed in two stages. First, we compute the |
| 287 | + // The decomposition is performed in three stages. First, we compute the |
282 | 288 | // permutation matrix `P` by using `invertAndCompose` to generate a skeleton |
283 | 289 | // and then fill in any zero columns. Second, we walk the cycles of `P` to |
284 | 290 | // factor out mixed transpositions to build `mixedTranspositions`, `pReg`, and |
285 | | - // `pLane`. |
| 291 | + // `pLane`. Finally, we determine any selectors needed for byte permute |
| 292 | + // instructions in place of `selp` instructions when packing registers. |
286 | 293 |
|
287 | 294 | // We remove any broadcasting in the register dimensions of the layouts before |
288 | 295 | // forming the permutation `P` as the components of the decomposition directly |
@@ -336,19 +343,14 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
336 | 343 | T = padWithZeros(T); |
337 | 344 | } |
338 | 345 |
|
339 | | - // Flatten outs for ease of building `P`, and reorder outs as flattening |
340 | | - // depends on output dimension order. |
341 | | - if (outDimNames != llvm::to_vector(T.getOutDimNames())) |
342 | | - T = T.transposeOuts(outDimNames); |
343 | | - S = S.flattenOuts(); |
344 | | - T = T.flattenOuts(); |
345 | | - |
346 | 346 | // We compute T^transpose \circ S, which serves as a skeleton for `P`, then |
347 | 347 | // fill in zero columns, prioritizing producing fixed points. As we only need |
348 | 348 | // the basis vectors of `P`, we never actually produce the LinearLayout. |
349 | 349 | auto pBases = S.invertAndCompose(T).getBases(); |
350 | 350 |
|
351 | 351 | // Find the common and uncommon zeros of S and T |
| 352 | + S = S.flattenOuts(); |
| 353 | + T = T.flattenOuts(); |
352 | 354 | SmallVector<std::pair<int32_t, int32_t>> srcFreeZeros; |
353 | 355 | SmallVector<std::pair<int32_t, int32_t>> dstFreeZeros; |
354 | 356 | for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) { |
@@ -461,11 +463,234 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy, |
461 | 463 | } |
462 | 464 | assert(visited.all() && "Cycle walk incomplete"); |
463 | 465 |
|
| 466 | + auto processedTranspos = |
| 467 | + getTranspositionSelectors(mixedTranspositions, regBases, bitwidth); |
| 468 | + |
464 | 469 | auto pReg = LinearLayout(std::move(pRegBases), {{kReg, 1 << nRegBases}}, |
465 | 470 | /*requireSurjective=*/true); |
466 | 471 | auto pLane = LinearLayout(std::move(pLaneBases), {{kLane, 1 << nLaneBases}}, |
467 | 472 | /*requireSurjective=*/true); |
468 | | - return {std::move(pReg), std::move(pLane), std::move(mixedTranspositions)}; |
| 473 | + return {std::move(pReg), std::move(pLane), std::move(processedTranspos)}; |
| 474 | +} |
| 475 | + |
| 476 | +static SmallVector<DecomposedWarpConversion::TranspositionInfo> |
| 477 | +getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions, |
| 478 | + std::vector<std::vector<int32_t>> ®Bases, |
| 479 | + int bitwidth) { |
| 480 | + // When possible, we fuse permutations of 'low' register bits together |
| 481 | + // with a mixed transposition, resulting in byte permute instructions instead |
| 482 | + // of `select` instructions. After processing, no low register bits appear in |
| 483 | + // the returned list of mixed transpositions. |
| 484 | + int m = mixedTranspositions.size(); |
| 485 | + int nRegBases = regBases.size(); |
| 486 | + int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4)); |
| 487 | + int nPack = std::min(nPackPrelim, nRegBases - m); |
| 488 | + |
| 489 | + SmallVector<DecomposedWarpConversion::TranspositionInfo> ret; |
| 490 | + ret.reserve(mixedTranspositions.size()); |
| 491 | + if (nPack == 0) { |
| 492 | + for (auto &t : mixedTranspositions) |
| 493 | + ret.push_back(DecomposedWarpConversion::TranspositionInfo{t}); |
| 494 | + return ret; |
| 495 | + } |
| 496 | + // Consider for example the cycle |
| 497 | + // |
| 498 | + // (r2 r1 l0 r0 r3) = (r0 l0) * (r2 r1 r0 r3) |
| 499 | + // = (r3 r0) * (r3 l0) * (r3 r1) * (r3 r2) |
| 500 | + // |
| 501 | + // with `nPack` = 2 so that r0 and r1 are considered low bits. We want to |
| 502 | + // factor out any low bits from `pReg` and to incorporate them into the data |
| 503 | + // of the mixed transposition. After processing, the contribution to `pReg` |
| 504 | + // is reduced to (r3 r2) and the mixed transposition recorded is (r3 l0), with |
| 505 | + // the effects of (r3 r0) and (r3 r1) encoded in the returned selectors. |
| 506 | + // In general, low bits occurring immediately before l_j modify the selectors |
| 507 | + // of the `prmt` before the shuffle, while low bits occurring immediately |
| 508 | + // after l_k modify the selectors of the `prmt` after the shuffle. Unmodified |
| 509 | + // selectors correspond to `select` instructions. |
| 510 | + // Cases like (l0 r0 r1) must be handled by selecting a 'partner' bit that is |
| 511 | + // not used in another mixed transposition and conjugating out a low bit: |
| 512 | + // |
| 513 | + // (l0 r0 r1) = (r2 r1) * (l0 r0 r2) * (r2 r1) |
| 514 | + // = (r2 r1) * (r2 r0) * (r2 l0) * (r2 r1). |
| 515 | + // |
| 516 | + // Conjugation does not affect `pReg`. However, the set of fused mixed and |
| 517 | + // low-bit transpositions is noncommutative in cases where there are no |
| 518 | + // intervening high bits in between distinct sequences of lane bits as the |
| 519 | + // paired low bit is used in modifying the selectors of both factors: |
| 520 | + // |
| 521 | + // (l0 r0 r1 l1 r2) = (r3 r0)(r3 l0)(r3 r0) * (r2 l1)(r2 r1)(r2 r0). |
| 522 | + // |
| 523 | + // The `*` is standard composition of permutations. The groupings correspond |
| 524 | + // to different `TranspositionInfo` objects. For example, the permutation |
| 525 | + // `(r3 r0)(r3 l0)(r3 r0) = (r0 l0)` has mixed transposition `(r3 l0)` with |
| 526 | + // pre- and post-shuffle selectors determined by the `r0` bit. |
| 527 | + // Processing of mixed transpositions is performed by determining the `head` |
| 528 | + // and `tail` of an excision of bits in cycles of `pReg` and building lists |
| 529 | + // of low bits acting as selector modifiers. In the noncommutative cases, we |
| 530 | + // opt to restrict the number of post-shuffle modifiers to one. |
| 531 | + |
| 532 | + auto permuteSelector = [nPack](uint16_t sel, int bitIdx) { |
| 533 | + int lo = bitIdx + (2 - nPack); |
| 534 | + uint16_t maskHi = 0x4444; |
| 535 | + uint16_t maskLo = 0x1111 << lo; |
| 536 | + uint16_t fixed = sel & ~maskHi & ~maskLo; |
| 537 | + int shift = 2 - lo; |
| 538 | + return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift); |
| 539 | + }; |
| 540 | + auto generateSelectors = [&](int head, int tail, auto &&lowBits) { |
| 541 | + uint16_t topSel = 0x3210; |
| 542 | + uint16_t botSel = 0x7654; |
| 543 | + for (auto lowBit : lowBits) { |
| 544 | + topSel = permuteSelector(topSel, lowBit); |
| 545 | + botSel = permuteSelector(botSel, lowBit); |
| 546 | + if (lowBit != head && lowBit != tail) |
| 547 | + regBases[lowBit][0] = 1 << lowBit; |
| 548 | + } |
| 549 | + return std::pair{topSel, botSel}; |
| 550 | + }; |
| 551 | + |
| 552 | + llvm::SmallSet<int32_t, 6> pairedRegBits; |
| 553 | + for (auto [rBit, lBit] : mixedTranspositions) |
| 554 | + pairedRegBits.insert(rBit); |
| 555 | + |
| 556 | + // A low bit in a mixed transposition must be replaced by a high bit. The |
| 557 | + // choice of high bit can affect instruction count. If the first high bit |
| 558 | + // found when walking along `pReg` is unpaired, then that bit is the best |
| 559 | + // choice. We reorder the transpositions to guarantee this during processing. |
| 560 | + auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); }; |
| 561 | + auto nextHighFree = [&](auto p) { |
| 562 | + int curr = p.first; |
| 563 | + do { |
| 564 | + if (curr >= nPack) |
| 565 | + return curr == p.first || !pairedRegBits.contains(curr); |
| 566 | + curr = next(curr); |
| 567 | + } while (curr != p.first); |
| 568 | + return false; |
| 569 | + }; |
| 570 | + std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(), |
| 571 | + nextHighFree); |
| 572 | + // If `P` has an isolated low-bit mixed transposition, and `pReg` maps a low |
| 573 | + // bit to an open high bit, then the high bit should be used as the partner. |
| 574 | + auto prev = [&](int b) { |
| 575 | + int tail = b; |
| 576 | + int curr = next(b); |
| 577 | + while (curr != b) { |
| 578 | + tail = curr; |
| 579 | + curr = next(curr); |
| 580 | + } |
| 581 | + return tail; |
| 582 | + }; |
| 583 | + auto findPartner = [&](int lowBit, auto &preShufLoBits) { |
| 584 | + if (nPack == 2) { |
| 585 | + int otherLow = 1 - lowBit; |
| 586 | + int b = next(otherLow); |
| 587 | + if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) && |
| 588 | + !pairedRegBits.contains(otherLow)) { |
| 589 | + preShufLoBits.push_back(otherLow); |
| 590 | + regBases[prev(otherLow)][0] = 1 << b; |
| 591 | + pairedRegBits.insert(b); |
| 592 | + return b; |
| 593 | + } |
| 594 | + } |
| 595 | + int potentialPartner = nPack; |
| 596 | + while (pairedRegBits.contains(potentialPartner)) |
| 597 | + ++potentialPartner; |
| 598 | + pairedRegBits.insert(potentialPartner); |
| 599 | + return potentialPartner; |
| 600 | + }; |
| 601 | + |
| 602 | + for (auto p : mixedTranspositions) { |
| 603 | + int rBit = p.first; |
| 604 | + int lBit = p.second; |
| 605 | + SmallVector<int> cycle; |
| 606 | + int currBit = rBit; |
| 607 | + do { |
| 608 | + cycle.push_back(currBit); |
| 609 | + currBit = next(currBit); |
| 610 | + } while (currBit != rBit); |
| 611 | + |
| 612 | + // Find any low register bits adjacent to the excised lane bits which aren't |
| 613 | + // used in other mixed transpositions. |
| 614 | + auto isBoundary = [&](int bit) { |
| 615 | + return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit); |
| 616 | + }; |
| 617 | + auto forwardEnd = llvm::find_if(cycle, isBoundary); |
| 618 | + auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary); |
| 619 | + SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd); |
| 620 | + SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd); |
| 621 | + int head; |
| 622 | + int tail; |
| 623 | + int partnerBit = -1; |
| 624 | + |
| 625 | + // Case work to determine what to conjugate out. |
| 626 | + if (forwardEnd != cycle.end()) { |
| 627 | + if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) { |
| 628 | + // End at original or unpaired high bit. E.g. (l0 r0 r2) or (l0 r2) |
| 629 | + // No conjugation needed. |
| 630 | + head = partnerBit = *forwardEnd; |
| 631 | + } else { |
| 632 | + // End at different paired bit. E.g. (l0 r0 r1 l1 r2) |
| 633 | + // Non-leading factor in a noncommutative case. |
| 634 | + // Conjugate by first low bit in forward walk. |
| 635 | + head = postShufLoBits.front(); |
| 636 | + preShufLoBits.push_back(head); |
| 637 | + postShufLoBits.resize(1); |
| 638 | + pairedRegBits.erase(head); |
| 639 | + } |
| 640 | + tail = *backwardEnd; |
| 641 | + if (tail < nPack && pairedRegBits.contains(tail)) { |
| 642 | + // Non-terminal factor in a noncommutative case. |
| 643 | + preShufLoBits.insert(preShufLoBits.begin(), tail); |
| 644 | + } |
| 645 | + } else { |
| 646 | + if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) { |
| 647 | + // Symmetric noncommutative case. E.g. (l0 r0 l1 r1) |
| 648 | + preShufLoBits.erase(preShufLoBits.begin()); |
| 649 | + postShufLoBits.pop_back(); |
| 650 | + pairedRegBits.erase(postShufLoBits.front()); |
| 651 | + head = rBit; |
| 652 | + tail = next(rBit); |
| 653 | + } else { |
| 654 | + // Isolated low bits with single mixed transposition. E.g. (l0 r0 r1) |
| 655 | + if (postShufLoBits.size() == 2) |
| 656 | + postShufLoBits.pop_back(); |
| 657 | + head = tail = preShufLoBits.front(); |
| 658 | + } |
| 659 | + } |
| 660 | + |
| 661 | + if (partnerBit < 0) |
| 662 | + partnerBit = findPartner(head, preShufLoBits); |
| 663 | + auto [topPostSel, botPostSel] = |
| 664 | + generateSelectors(head, tail, llvm::reverse(postShufLoBits)); |
| 665 | + auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits); |
| 666 | + regBases[tail][0] = 1 << head; |
| 667 | + |
| 668 | + DecomposedWarpConversion::TranspositionInfo info; |
| 669 | + info.transposition = {partnerBit, lBit}; |
| 670 | + info.topPreSel = topPreSel; |
| 671 | + info.botPreSel = botPreSel; |
| 672 | + info.topPostSel = topPostSel; |
| 673 | + info.botPostSel = botPostSel; |
| 674 | + |
| 675 | + // In noncommutative cases, post-shuffle selectors of non-leading terms come |
| 676 | + // from a single low bit by design, so we can determine where to insert a |
| 677 | + // non-terminal factor by examining processed selectors. |
| 678 | + if (!preShufLoBits.empty()) { |
| 679 | + uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410; |
| 680 | + auto it = |
| 681 | + llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; }); |
| 682 | + ret.insert(it, info); |
| 683 | + } else { |
| 684 | + ret.push_back(info); |
| 685 | + } |
| 686 | + } |
| 687 | + if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) { |
| 688 | + // If (r0 r1) was originally in `P`, fold it into a mixed transposition. |
| 689 | + auto &t = ret.back(); |
| 690 | + t.topPostSel = 0x3120; |
| 691 | + t.botPostSel = 0x7564; |
| 692 | + } |
| 693 | + return ret; |
469 | 694 | } |
470 | 695 |
|
471 | 696 | SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> |
@@ -763,7 +988,7 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { |
763 | 988 | auto kLane = StringAttr::get(ctx, "lane"); |
764 | 989 | if (to_vector(layout.getOutDimNames()) == |
765 | 990 | SmallVector<StringAttr, 2>{kRegister, kLane}) { |
766 | | - auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy); |
| 991 | + auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, 32); |
767 | 992 | return (factors.mixedTranspositions.size() < 2); |
768 | 993 | } |
769 | 994 | return false; |
|
0 commit comments