|
1 | 1 | #include "triton/Tools/LayoutUtils.h"
|
2 | 2 | #include "triton/Tools/GenericSwizzling.h"
|
3 |
| -#include "llvm/ADT/SmallSet.h" |
4 | 3 |
|
5 | 4 | namespace mlir::triton {
|
6 | 5 |
|
@@ -447,137 +446,4 @@ LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef<int> order) {
|
447 | 446 | to_vector(layout.getOutDimNames()));
|
448 | 447 | }
|
449 | 448 |
|
450 |
| -LinearLayout reorder_like(const LinearLayout &x, const LinearLayout &y) { |
451 |
| - // This will check that the names are the same up to permutation, and |
452 |
| - // apply the necessary permutation: |
453 |
| - auto x2 = x.transposeOuts(llvm::to_vector(y.getOutDimNames())); |
454 |
| - auto x3 = x2.transposeIns(llvm::to_vector(y.getInDimNames())); |
455 |
| - return x3; |
456 |
| -} |
457 |
| - |
458 |
| -LinearLayout basisPermutationLayout(const LinearLayout &src, |
459 |
| - const LinearLayout &dst) { |
460 |
| - // This function computes a permutation layout `P` which satisfies the |
461 |
| - // property `src = dst \circ P`. It requires that the multiset of basis |
462 |
| - // vectors for each of `src` and `dst` agree and that the nonzero values in |
463 |
| - // each of the multisets are unique. I.e., broadcasting is allowed in either |
464 |
| - // layout so long as the degree of broadcasting (the number of zero basis |
465 |
| - // vectors) is the same between the two layouts. |
466 |
| - // |
467 |
| - // The orders of the input and output dimensions of `P` are set to be the |
468 |
| - // order of the input dimensions of `src`. |
469 |
| - // |
470 |
| - // The mapping of broadcasting basis vectors prioritizes keeping such vectors |
471 |
| - // as fixed points of the permutation. I.e., if `src[inDim][i]` and |
472 |
| - // `dst[inDim][i]` are zero vectors, then `P[inDim][i][inDimIdx] == 1 << i`, |
473 |
| - // where `inDimIdx` is the index of `inDim` in `src`. Otherwise, they are |
474 |
| - // paired according to their order of appearance in the two layouts, again |
475 |
| - // following the order of the input dimensions of `src`. |
476 |
| - // |
477 |
| - // The algorithm first performs a linear scan over the columns of `dst` and |
478 |
| - // `src` to build a map from ('flattened') basis vectors to the input |
479 |
| - // vectors of `dst` while tracking the fixed-point zero vectors and 'free' |
480 |
| - // zero vectors. It then performs a second linear scan over `src` to build |
481 |
| - // the basis of `P`. |
482 |
| - |
483 |
| - // Check that the input and output dimensions are equal up to ordering. |
484 |
| - auto srcInDims = src.getInDimNames(); |
485 |
| - assert(std::is_permutation(srcInDims.begin(), srcInDims.end(), |
486 |
| - dst.getInDimNames().begin()) && |
487 |
| - "Layouts must have same input dimensions"); |
488 |
| - for (auto inDim : srcInDims) { |
489 |
| - assert(src.getInDimSize(inDim) == dst.getInDimSize(inDim) && |
490 |
| - "Layouts must have same input dimension sizes"); |
491 |
| - } |
492 |
| - auto srcOutDims = src.getOutDims(); |
493 |
| - assert(std::is_permutation(srcOutDims.begin(), srcOutDims.end(), |
494 |
| - dst.getOutDims().begin()) && |
495 |
| - "Layouts must have same output dimensions and dimension sizes"); |
496 |
| - |
497 |
| - auto srcFlat = src.flattenOuts(); |
498 |
| - // Reorder the output dimensions of `dst` if necessary before flattening, as |
499 |
| - // flattening depends on the order. |
500 |
| - LinearLayout dstFlat; |
501 |
| - if (!llvm::equal(src.getOutDims(), dst.getOutDims())) { |
502 |
| - auto temp = dst.transposeOuts(llvm::to_vector(src.getOutDimNames())); |
503 |
| - dstFlat = temp.flattenOuts(); |
504 |
| - } else { |
505 |
| - dstFlat = dst.flattenOuts(); |
506 |
| - } |
507 |
| - |
508 |
| - // Populate the map of flattened values to dst inputs and track zero vectors. |
509 |
| - // The `commonZeros` become fixed-points of `P`, while the 'free' zeros are |
510 |
| - // later paired with one another. |
511 |
| - DenseMap<int32_t, std::pair<StringAttr, int32_t>> valToDstInput; |
512 |
| - llvm::SmallDenseMap<StringAttr, llvm::SmallSet<int32_t, 4>> commonZeros; |
513 |
| - SmallVector<std::pair<StringAttr, int32_t>> dstFreeZeros; |
514 |
| - size_t srcFreeZerosCount = 0; |
515 |
| - |
516 |
| - // We traverse the input dimensions according to their order in `src` so that |
517 |
| - // 'free' zero vectors for a given input dimension in `src` prefer to map to |
518 |
| - // 'free' zero vectors in the same dimension in `dst`. |
519 |
| - for (auto inDim : srcInDims) { |
520 |
| - int inDimSize = dstFlat.getInDimSizeLog2(inDim); |
521 |
| - for (int i = 0; i < inDimSize; ++i) { |
522 |
| - int32_t dstVal = dstFlat.getBasis(inDim, i)[0]; |
523 |
| - int32_t srcVal = srcFlat.getBasis(inDim, i)[0]; |
524 |
| - if (dstVal == 0 && srcVal == 0) { |
525 |
| - commonZeros[inDim].insert(i); |
526 |
| - } else if (dstVal == 0) { |
527 |
| - dstFreeZeros.emplace_back(inDim, i); |
528 |
| - } else { |
529 |
| - auto [it, success] = valToDstInput.try_emplace(dstVal, inDim, i); |
530 |
| - assert(success && "Found duplicate nonzero vectors in dst layout"); |
531 |
| - if (srcVal == 0) |
532 |
| - ++srcFreeZerosCount; |
533 |
| - } |
534 |
| - } |
535 |
| - } |
536 |
| - assert(srcFreeZerosCount == dstFreeZeros.size() && |
537 |
| - "src and dst layouts have differing number of zero bases"); |
538 |
| - |
539 |
| - // Build the basis vectors for the permutation layout `P`. |
540 |
| - // For each basis vector in `src`, determine its target in `dst`: |
541 |
| - // - If the vector is nonzero, find the corresponding vector in `dst`. |
542 |
| - // - If it is a zero vector common to both layouts, set it as a fixed-point. |
543 |
| - // - Otherwise, pair it with the next available free zero of `dst`. |
544 |
| - LinearLayout::BasesT pBases; |
545 |
| - size_t numDims = llvm::size(srcInDims); |
546 |
| - size_t freeZeroIdx = 0; |
547 |
| - for (auto inDim : srcInDims) { |
548 |
| - int inDimSize = srcFlat.getInDimSizeLog2(inDim); |
549 |
| - auto &inDimBases = pBases[inDim]; |
550 |
| - inDimBases.reserve(inDimSize); |
551 |
| - for (int i = 0; i < inDimSize; ++i) |
552 |
| - inDimBases.emplace_back(numDims, 0); |
553 |
| - |
554 |
| - for (int inIdx = 0; inIdx < inDimSize; ++inIdx) { |
555 |
| - int32_t val = srcFlat.getBasis(inDim, inIdx)[0]; |
556 |
| - std::pair<StringAttr, int32_t> dstTarget; |
557 |
| - |
558 |
| - if (val != 0) { |
559 |
| - auto it = valToDstInput.find(val); |
560 |
| - assert(it != valToDstInput.end() && "src basis not found in dst"); |
561 |
| - dstTarget = it->second; |
562 |
| - } else if (commonZeros.lookup(inDim).count(inIdx)) { |
563 |
| - dstTarget = {inDim, inIdx}; |
564 |
| - } else { |
565 |
| - dstTarget = dstFreeZeros[freeZeroIdx++]; |
566 |
| - } |
567 |
| - |
568 |
| - // Build the basis vector for `P` using the ordering on output dimensions |
569 |
| - // induced by the ordering on the input dimensions of `src`. |
570 |
| - auto it = llvm::find(srcInDims, dstTarget.first); |
571 |
| - int outDimIdx = std::distance(srcInDims.begin(), it); |
572 |
| - inDimBases[inIdx][outDimIdx] = 1 << dstTarget.second; |
573 |
| - } |
574 |
| - } |
575 |
| - // Declare the ordering on the `outDims` of `P` to be that of `srcInDims`. |
576 |
| - SmallVector<std::pair<StringAttr, int32_t>> outDims; |
577 |
| - for (auto outDim : srcInDims) |
578 |
| - outDims.emplace_back(outDim, srcFlat.getInDimSize(outDim)); |
579 |
| - |
580 |
| - return LinearLayout(std::move(pBases), outDims, /*requireSurjective=*/true); |
581 |
| -} |
582 |
| - |
583 | 449 | } // namespace mlir::triton
|
0 commit comments