|
14 | 14 | /// such that each piece can be handled by a hardware instruction.
|
15 | 15 | ///
|
16 | 16 | //===----------------------------------------------------------------------===//
|
17 |
| - |
18 | 17 | #include <mlir/Conversion/LLVMCommon/TypeConverter.h>
|
19 | 18 | #include <mlir/Dialect/Arith/IR/Arith.h>
|
20 | 19 | #include <mlir/Dialect/Func/IR/FuncOps.h>
|
@@ -779,66 +778,83 @@ struct InitTileOpPattern
|
779 | 778 | op, "Skipped InitTileOp because the result tile is not rank 2.\n");
|
780 | 779 |
|
781 | 780 | auto innerBlocks = tileTy.getInnerBlocks();
|
| 781 | + auto memorySpace = op.getSourceMemorySpaceAsInt(); |
782 | 782 |
|
783 | 783 | // skip it if innerBlocks has been set by user or compiler.
|
784 | 784 | if (innerBlocks)
|
785 | 785 | return mlir::failure();
|
786 | 786 |
|
787 | 787 | auto elemTy = tileTy.getElementType();
|
788 | 788 | int elementSize = elemTy.getIntOrFloatBitWidth();
|
789 |
| - if (isForPrefetch(op)) { |
790 |
| - innerBlocks = mlir::DenseI64ArrayAttr::get( |
791 |
| - getContext(), getInnerBlockSizes<Prefetch>( |
792 |
| - op.getOperation(), elemTy, tileTy.getShape()[0], |
793 |
| - tileTy.getShape()[1], this->uArchInterface)); |
794 |
| - } else if (isForLoad(op)) { |
795 |
| - |
796 |
| - // Set transpose and vnni |
797 |
| - bool vnni = false; |
798 |
| - bool transpose = false; |
799 |
| - |
800 |
| - auto order = tileTy.getOrder(); |
801 |
| - if (order[0] == 0 && order[1] == 1) |
802 |
| - transpose = true; |
803 |
| - |
804 |
| - for (auto user : getEffectiveUsers(op)) { |
805 |
| - if (auto loadTileOp = llvm::dyn_cast<xetile::LoadTileOp>(user)) { |
806 |
| - if (isForDPASB(loadTileOp) && elementSize < 32) { |
807 |
| - vnni = true; |
808 |
| - break; |
| 789 | + |
| 790 | + if (memorySpace == 3) { // for shared memory |
| 791 | + const unsigned int lscConstraints = 512; // 512 bytes constraint by lsc |
| 792 | + const unsigned int subgroupSize = 16; |
| 793 | + auto shape = tileTy.getShape(); |
| 794 | + int64_t innerBlockSizes[2]; |
| 795 | + // prefer to use gather loads with 16 simd lanes |
| 796 | + innerBlockSizes[0] = shape[0] % subgroupSize == 0 ? 16 : 1; |
| 797 | + innerBlockSizes[1] = |
| 798 | + (lscConstraints * 8) / (elementSize * innerBlockSizes[0]); |
| 799 | + innerBlockSizes[1] = |
| 800 | + std::min<int64_t>(innerBlockSizes[1], tileTy.getShape()[1]); |
| 801 | + innerBlocks = mlir::DenseI64ArrayAttr::get(getContext(), innerBlockSizes); |
| 802 | + } else { // for global memory |
| 803 | + if (isForPrefetch(op)) { |
| 804 | + innerBlocks = mlir::DenseI64ArrayAttr::get( |
| 805 | + getContext(), getInnerBlockSizes<Prefetch>( |
| 806 | + op.getOperation(), elemTy, tileTy.getShape()[0], |
| 807 | + tileTy.getShape()[1], this->uArchInterface)); |
| 808 | + } else if (isForLoad(op)) { |
| 809 | + |
| 810 | + // Set transpose and vnni |
| 811 | + bool vnni = false; |
| 812 | + bool transpose = false; |
| 813 | + |
| 814 | + auto order = tileTy.getOrder(); |
| 815 | + if (order[0] == 0 && order[1] == 1) |
| 816 | + transpose = true; |
| 817 | + |
| 818 | + for (auto user : getEffectiveUsers(op)) { |
| 819 | + if (auto loadTileOp = llvm::dyn_cast<xetile::LoadTileOp>(user)) { |
| 820 | + if (isForDPASB(loadTileOp) && elementSize < 32) { |
| 821 | + vnni = true; |
| 822 | + break; |
| 823 | + } |
809 | 824 | }
|
810 | 825 | }
|
811 |
| - } |
812 | 826 |
|
813 |
| - if (vnni && transpose && elementSize < 32) { |
814 |
| - int factor = 32 / elementSize; |
815 |
| - vnni = false; |
816 |
| - llvm::SmallVector<int64_t, 2> innerBlock = getInnerBlockSizes<Load>( |
817 |
| - op.getOperation(), mlir::FloatType::getF32(getContext()), |
818 |
| - tileTy.getShape()[1], (tileTy.getShape()[0]) / factor, |
819 |
| - this->uArchInterface, vnni, transpose); |
820 |
| - std::swap(innerBlock[0], innerBlock[1]); |
821 |
| - innerBlock[0] *= factor; |
822 |
| - innerBlocks = mlir::DenseI64ArrayAttr::get(getContext(), innerBlock); |
823 |
| - |
824 |
| - } else if (transpose && elementSize < 32) { |
825 |
| - return rewriter.notifyMatchFailure(op, "Invalid transpose."); |
826 |
| - } else { |
| 827 | + if (vnni && transpose && elementSize < 32) { |
| 828 | + int factor = 32 / elementSize; |
| 829 | + vnni = false; |
| 830 | + llvm::SmallVector<int64_t, 2> innerBlock = getInnerBlockSizes<Load>( |
| 831 | + op.getOperation(), mlir::FloatType::getF32(getContext()), |
| 832 | + tileTy.getShape()[1], (tileTy.getShape()[0]) / factor, |
| 833 | + this->uArchInterface, vnni, transpose); |
| 834 | + std::swap(innerBlock[0], innerBlock[1]); |
| 835 | + innerBlock[0] *= factor; |
| 836 | + innerBlocks = mlir::DenseI64ArrayAttr::get(getContext(), innerBlock); |
| 837 | + |
| 838 | + } else if (transpose && elementSize < 32) { |
| 839 | + return rewriter.notifyMatchFailure(op, "Invalid transpose."); |
| 840 | + } else { |
| 841 | + innerBlocks = mlir::DenseI64ArrayAttr::get( |
| 842 | + getContext(), |
| 843 | + getInnerBlockSizes<Load>( |
| 844 | + op.getOperation(), elemTy, tileTy.getShape()[0], |
| 845 | + tileTy.getShape()[1], this->uArchInterface, vnni, transpose)); |
| 846 | + } |
| 847 | + } else if (isForStore(op)) { |
827 | 848 | innerBlocks = mlir::DenseI64ArrayAttr::get(
|
828 |
| - getContext(), |
829 |
| - getInnerBlockSizes<Load>(op.getOperation(), elemTy, |
830 |
| - tileTy.getShape()[0], tileTy.getShape()[1], |
831 |
| - this->uArchInterface, vnni, transpose)); |
| 849 | + getContext(), getInnerBlockSizes<Store>( |
| 850 | + op.getOperation(), elemTy, tileTy.getShape()[0], |
| 851 | + tileTy.getShape()[1], this->uArchInterface)); |
| 852 | + } else { |
| 853 | + return rewriter.notifyMatchFailure( |
| 854 | + op, |
| 855 | + "The tile is used for multiple purpose. The init-duplicate pass " |
| 856 | + "should be run first to resolve this issue."); |
832 | 857 | }
|
833 |
| - } else if (isForStore(op)) { |
834 |
| - innerBlocks = mlir::DenseI64ArrayAttr::get( |
835 |
| - getContext(), getInnerBlockSizes<Store>( |
836 |
| - op.getOperation(), elemTy, tileTy.getShape()[0], |
837 |
| - tileTy.getShape()[1], this->uArchInterface)); |
838 |
| - } else { |
839 |
| - return rewriter.notifyMatchFailure( |
840 |
| - op, "The tile is used for multiple purpose. The init-duplicate pass " |
841 |
| - "should be run first to resolve this issue."); |
842 | 858 | }
|
843 | 859 |
|
844 | 860 | if (innerBlocks.empty()) {
|
|
0 commit comments