|
91 | 91 | #include <cstdint> |
92 | 92 | #include <cstring> |
93 | 93 | #include <iterator> |
| 94 | +#include <queue> |
94 | 95 | #include <string> |
95 | 96 | #include <tuple> |
96 | 97 | #include <utility> |
@@ -2678,6 +2679,53 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, |
2678 | 2679 | return V; |
2679 | 2680 | } |
2680 | 2681 |
|
| 2682 | +static Value *mergeTwoVectors(Value *V0, Value *V1, IRBuilder<> &Builder) { |
| 2683 | + assert(V0->getType()->isVectorTy() && V1->getType()->isVectorTy() && |
| 2684 | + "Can not merge two non-vector values"); |
| 2685 | + |
| 2686 | + // V0 and V1 are vectors |
| 2687 | + // Create a new vector type with combined elements |
| 2688 | + // Use ShuffleVector to concatenate the vectors |
| 2689 | + auto *VecType0 = cast<FixedVectorType>(V0->getType()); |
| 2690 | + auto *VecType1 = cast<FixedVectorType>(V1->getType()); |
| 2691 | + |
| 2692 | + assert(VecType0->getElementType() == VecType1->getElementType() && |
| 2693 | + "Can not merge two vectors with different element types"); |
| 2694 | + unsigned NumElts0 = VecType0->getNumElements(); |
| 2695 | + unsigned NumElts1 = VecType1->getNumElements(); |
| 2696 | + |
| 2697 | + SmallVector<int, 16> ShuffleMask; |
| 2698 | + |
| 2699 | + if (NumElts0 == NumElts1) { |
| 2700 | + for (unsigned i = 0; i < NumElts0 + NumElts1; ++i) |
| 2701 | + ShuffleMask.push_back(i); |
| 2702 | + } else { |
| 2703 | + // If two vectors have different sizes, we need to extend |
| 2704 | + // the smaller vector to the size of the larger vector. |
| 2705 | + unsigned SmallSize = std::min(NumElts0, NumElts1); |
| 2706 | + unsigned LargeSize = std::max(NumElts0, NumElts1); |
| 2707 | + bool IsV0Smaller = NumElts0 < NumElts1; |
| 2708 | + Value *SmallVec = IsV0Smaller ? V0 : V1; |
| 2709 | + |
| 2710 | + SmallVector<int, 16> ExtendMask; |
| 2711 | + for (unsigned i = 0; i < SmallSize; ++i) |
| 2712 | + ExtendMask.push_back(i); |
| 2713 | + for (unsigned i = SmallSize; i < LargeSize; ++i) |
| 2714 | + ExtendMask.push_back(PoisonMaskElem); |
| 2715 | + Value *ExtendedVec = Builder.CreateShuffleVector( |
| 2716 | + SmallVec, PoisonValue::get(SmallVec->getType()), ExtendMask); |
| 2717 | + LLVM_DEBUG(dbgs() << " shufflevector: " << *ExtendedVec << "\n"); |
| 2718 | + V0 = IsV0Smaller ? ExtendedVec : V0; |
| 2719 | + V1 = IsV0Smaller ? V1 : ExtendedVec; |
| 2720 | + for (unsigned i = 0; i < NumElts0; ++i) |
| 2721 | + ShuffleMask.push_back(i); |
| 2722 | + for (unsigned i = 0; i < NumElts1; ++i) |
| 2723 | + ShuffleMask.push_back(LargeSize + i); |
| 2724 | + } |
| 2725 | + |
| 2726 | + return Builder.CreateShuffleVector(V0, V1, ShuffleMask); |
| 2727 | +} |
| 2728 | + |
2681 | 2729 | namespace { |
2682 | 2730 |
|
2683 | 2731 | /// Visitor to rewrite instructions using p particular slice of an alloca |
@@ -2822,6 +2870,230 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> { |
2822 | 2870 | return CanSROA; |
2823 | 2871 | } |
2824 | 2872 |
|
| 2873 | + /// Attempts to rewrite a partition using tree-structured merge optimization. |
| 2874 | + /// |
| 2875 | + /// This function analyzes a partition to determine if it can be optimized |
| 2876 | + /// using a tree-structured merge pattern, where multiple non-overlapping |
| 2877 | + /// stores completely fill an alloca. And there is no load from the alloca in |
| 2878 | + /// the middle of the stores. Such patterns can be optimized by eliminating |
| 2879 | + /// the intermediate stores and directly constructing the final vector by |
| 2880 | + /// using shufflevectors. |
| 2881 | + /// |
| 2882 | + /// Example transformation: |
| 2883 | + /// Before: (stores do not have to be in order) |
| 2884 | + /// %alloca = alloca <8 x float> |
| 2885 | + /// store <2 x float> %val0, ptr %alloca ; offset 0-1 |
| 2886 | + /// store <2 x float> %val2, ptr %alloca+16 ; offset 4-5 |
| 2887 | + /// store <2 x float> %val1, ptr %alloca+8 ; offset 2-3 |
| 2888 | + /// store <2 x float> %val3, ptr %alloca+24 ; offset 6-7 |
| 2889 | + /// |
| 2890 | + /// After: |
| 2891 | + /// %alloca = alloca <8 x float> |
| 2892 | + /// %shuffle0 = shufflevector %val0, %val1, <4 x i32> <i32 0, i32 1, i32 2, |
| 2893 | + /// i32 3> |
| 2894 | + /// %shuffle1 = shufflevector %val2, %val3, <4 x i32> <i32 0, i32 1, i32 2, |
| 2895 | + /// i32 3> |
| 2896 | + /// %shuffle2 = shufflevector %shuffle0, %shuffle1, <8 x i32> <i32 0, i32 1, |
| 2897 | + /// i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> |
| 2898 | + /// store %shuffle2, ptr %alloca |
| 2899 | + /// |
| 2900 | + /// The optimization looks for partitions that: |
| 2901 | + /// 1. Have no overlapping split slice tails |
| 2902 | + /// 2. Contain non-overlapping stores that cover the entire alloca |
| 2903 | + /// 3. Have exactly one load that reads the complete alloca structure and not |
| 2904 | + /// in the middle of the stores (TODO: maybe we can relax the constraint |
| 2905 | + /// about reading the entire alloca structure) |
| 2906 | + /// |
| 2907 | + /// \param P The partition to analyze and potentially rewrite |
| 2908 | + /// \return An optional vector of values that were deleted during the rewrite |
| 2909 | + /// process, or std::nullopt if the partition cannot be optimized |
| 2910 | + /// using tree-structured merge |
| 2911 | + std::optional<SmallVector<Value *, 4>> |
| 2912 | + rewriteTreeStructuredMerge(Partition &P) { |
| 2913 | + // No tail slices that overlap with the partition |
| 2914 | + if (P.splitSliceTails().size() > 0) |
| 2915 | + return std::nullopt; |
| 2916 | + |
| 2917 | + SmallVector<Value *, 4> DeletedValues; |
| 2918 | + LoadInst *TheLoad = nullptr; |
| 2919 | + |
| 2920 | + // Structure to hold store information |
| 2921 | + struct StoreInfo { |
| 2922 | + StoreInst *Store; |
| 2923 | + uint64_t BeginOffset; |
| 2924 | + uint64_t EndOffset; |
| 2925 | + Value *StoredValue; |
| 2926 | + TypeSize StoredTypeSize = TypeSize::getZero(); |
| 2927 | + |
| 2928 | + StoreInfo(StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val, |
| 2929 | + TypeSize StoredTypeSize) |
| 2930 | + : Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val), |
| 2931 | + StoredTypeSize(StoredTypeSize) {} |
| 2932 | + }; |
| 2933 | + |
| 2934 | + SmallVector<StoreInfo, 4> StoreInfos; |
| 2935 | + |
| 2936 | + // The alloca must be a fixed vector type |
| 2937 | + auto *AllocatedTy = NewAI.getAllocatedType(); |
| 2938 | + if (!isa<FixedVectorType>(AllocatedTy)) |
| 2939 | + return std::nullopt; |
| 2940 | + |
| 2941 | + Slice *LoadSlice = nullptr; |
| 2942 | + Type *LoadElementType = nullptr; |
| 2943 | + Type *StoreElementType = nullptr; |
| 2944 | + for (Slice &S : P) { |
| 2945 | + auto *User = cast<Instruction>(S.getUse()->getUser()); |
| 2946 | + if (auto *LI = dyn_cast<LoadInst>(User)) { |
| 2947 | + // Do not handle the case where there is more than one load |
| 2948 | + // TODO: maybe we can handle this case |
| 2949 | + if (TheLoad) |
| 2950 | + return std::nullopt; |
| 2951 | + // If load is not a fixed vector type, we do not handle it |
| 2952 | + // If the number of loaded bits is not the same as the new alloca type |
| 2953 | + // size, we do not handle it |
| 2954 | + auto *FixedVecTy = dyn_cast<FixedVectorType>(LI->getType()); |
| 2955 | + if (!FixedVecTy) |
| 2956 | + return std::nullopt; |
| 2957 | + if (DL.getTypeSizeInBits(FixedVecTy) != |
| 2958 | + DL.getTypeSizeInBits(NewAI.getAllocatedType())) |
| 2959 | + return std::nullopt; |
| 2960 | + LoadElementType = FixedVecTy->getElementType(); |
| 2961 | + TheLoad = LI; |
| 2962 | + LoadSlice = &S; |
| 2963 | + } else if (auto *SI = dyn_cast<StoreInst>(User)) { |
| 2964 | + // The store needs to be a fixed vector type |
| 2965 | + // All the stores should have the same element type |
| 2966 | + Type *StoredValueType = SI->getValueOperand()->getType(); |
| 2967 | + Type *CurrentElementType = nullptr; |
| 2968 | + TypeSize StoredTypeSize = TypeSize::getZero(); |
| 2969 | + if (auto *FixedVecTy = dyn_cast<FixedVectorType>(StoredValueType)) { |
| 2970 | + // Fixed vector type - use its element type |
| 2971 | + CurrentElementType = FixedVecTy->getElementType(); |
| 2972 | + StoredTypeSize = DL.getTypeSizeInBits(FixedVecTy); |
| 2973 | + } else |
| 2974 | + return std::nullopt; |
| 2975 | + // Check element type consistency across all stores |
| 2976 | + if (StoreElementType && StoreElementType != CurrentElementType) |
| 2977 | + return std::nullopt; |
| 2978 | + StoreElementType = CurrentElementType; |
| 2979 | + StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(), |
| 2980 | + SI->getValueOperand(), StoredTypeSize); |
| 2981 | + } else { |
| 2982 | + // If we have instructions other than load and store, we cannot do the |
| 2983 | + // tree structured merge |
| 2984 | + return std::nullopt; |
| 2985 | + } |
| 2986 | + } |
| 2987 | + // If we do not have any load, we cannot do the tree structured merge |
| 2988 | + if (!TheLoad) |
| 2989 | + return std::nullopt; |
| 2990 | + |
| 2991 | + // If we do not have any stores, we cannot do the tree structured merge |
| 2992 | + if (StoreInfos.empty()) |
| 2993 | + return std::nullopt; |
| 2994 | + |
| 2995 | + // The load and store element types should be the same |
| 2996 | + if (LoadElementType != StoreElementType) |
| 2997 | + return std::nullopt; |
| 2998 | + |
| 2999 | + // The load should cover the whole alloca |
| 3000 | + // TODO: maybe we can relax this constraint |
| 3001 | + if (!LoadSlice || LoadSlice->beginOffset() != NewAllocaBeginOffset || |
| 3002 | + LoadSlice->endOffset() != NewAllocaEndOffset) |
| 3003 | + return std::nullopt; |
| 3004 | + |
| 3005 | + // Stores should not overlap and should cover the whole alloca |
| 3006 | + // Sort by begin offset |
| 3007 | + llvm::sort(StoreInfos, [](const StoreInfo &A, const StoreInfo &B) { |
| 3008 | + return A.BeginOffset < B.BeginOffset; |
| 3009 | + }); |
| 3010 | + |
| 3011 | + // Check for overlaps and coverage |
| 3012 | + uint64_t ExpectedStart = NewAllocaBeginOffset; |
| 3013 | + TypeSize TotalStoreBits = TypeSize::getZero(); |
| 3014 | + Instruction *PrevStore = nullptr; |
| 3015 | + for (auto &StoreInfo : StoreInfos) { |
| 3016 | + uint64_t BeginOff = StoreInfo.BeginOffset; |
| 3017 | + uint64_t EndOff = StoreInfo.EndOffset; |
| 3018 | + |
| 3019 | + // Check for gap or overlap |
| 3020 | + if (BeginOff != ExpectedStart) |
| 3021 | + return std::nullopt; |
| 3022 | + |
| 3023 | + ExpectedStart = EndOff; |
| 3024 | + TotalStoreBits += StoreInfo.StoredTypeSize; |
| 3025 | + PrevStore = StoreInfo.Store; |
| 3026 | + } |
| 3027 | + // Check that stores cover the entire alloca |
| 3028 | + // We need check both the end offset and the total store bits |
| 3029 | + if (ExpectedStart != NewAllocaEndOffset || |
| 3030 | + TotalStoreBits != DL.getTypeSizeInBits(NewAI.getAllocatedType())) |
| 3031 | + return std::nullopt; |
| 3032 | + |
| 3033 | + // Stores should be in the same basic block |
| 3034 | + // The load should not be in the middle of the stores |
| 3035 | + BasicBlock *LoadBB = TheLoad->getParent(); |
| 3036 | + BasicBlock *StoreBB = StoreInfos[0].Store->getParent(); |
| 3037 | + |
| 3038 | + for (auto &StoreInfo : StoreInfos) { |
| 3039 | + if (StoreInfo.Store->getParent() != StoreBB) |
| 3040 | + return std::nullopt; |
| 3041 | + if (LoadBB == StoreBB && !StoreInfo.Store->comesBefore(TheLoad)) |
| 3042 | + return std::nullopt; |
| 3043 | + } |
| 3044 | + |
| 3045 | + // If we reach here, the partition can be merged with a tree structured |
| 3046 | + // merge |
| 3047 | + LLVM_DEBUG({ |
| 3048 | + dbgs() << "Tree structured merge rewrite:\n Load: " << *TheLoad |
| 3049 | + << "\n Ordered stores:\n"; |
| 3050 | + for (auto [i, Info] : enumerate(StoreInfos)) |
| 3051 | + dbgs() << " [" << i << "] Range[" << Info.BeginOffset << ", " |
| 3052 | + << Info.EndOffset << ") \tStore: " << *Info.Store |
| 3053 | + << "\tValue: " << *Info.StoredValue << "\n"; |
| 3054 | + }); |
| 3055 | + |
| 3056 | + // Instead of having these stores, we merge all the stored values into a |
| 3057 | + // vector and store the merged value into the alloca |
| 3058 | + std::queue<Value *> VecElements; |
| 3059 | + IRBuilder<> Builder(StoreInfos.back().Store); |
| 3060 | + for (const auto &Info : StoreInfos) { |
| 3061 | + DeletedValues.push_back(Info.Store); |
| 3062 | + VecElements.push(Info.StoredValue); |
| 3063 | + } |
| 3064 | + |
| 3065 | + LLVM_DEBUG(dbgs() << " Rewrite stores into shufflevectors:\n"); |
| 3066 | + while (VecElements.size() > 1) { |
| 3067 | + uint64_t NumElts = VecElements.size(); |
| 3068 | + for (uint64_t i = 0; i < NumElts / 2; i++) { |
| 3069 | + Value *V0 = VecElements.front(); |
| 3070 | + VecElements.pop(); |
| 3071 | + Value *V1 = VecElements.front(); |
| 3072 | + VecElements.pop(); |
| 3073 | + Value *Merged = mergeTwoVectors(V0, V1, Builder); |
| 3074 | + LLVM_DEBUG(dbgs() << " shufflevector: " << *Merged << "\n"); |
| 3075 | + VecElements.push(Merged); |
| 3076 | + } |
| 3077 | + if (NumElts % 2 == 1) { |
| 3078 | + Value *V = VecElements.front(); |
| 3079 | + VecElements.pop(); |
| 3080 | + VecElements.push(V); |
| 3081 | + } |
| 3082 | + } |
| 3083 | + |
| 3084 | + // Store the merged value into the alloca |
| 3085 | + Value *MergedValue = VecElements.front(); |
| 3086 | + Builder.CreateAlignedStore(MergedValue, &NewAI, getSliceAlign()); |
| 3087 | + |
| 3088 | + IRBuilder<> LoadBuilder(TheLoad); |
| 3089 | + TheLoad->replaceAllUsesWith(LoadBuilder.CreateAlignedLoad( |
| 3090 | + TheLoad->getType(), &NewAI, getSliceAlign(), TheLoad->isVolatile(), |
| 3091 | + TheLoad->getName() + ".sroa.new.load")); |
| 3092 | + DeletedValues.push_back(TheLoad); |
| 3093 | + |
| 3094 | + return DeletedValues; |
| 3095 | + } |
| 3096 | + |
2825 | 3097 | private: |
2826 | 3098 | // Make sure the other visit overloads are visible. |
2827 | 3099 | using Base::visit; |
@@ -4996,13 +5268,22 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, |
4996 | 5268 | P.endOffset(), IsIntegerPromotable, VecTy, |
4997 | 5269 | PHIUsers, SelectUsers); |
4998 | 5270 | bool Promotable = true; |
4999 | | - for (Slice *S : P.splitSliceTails()) { |
5000 | | - Promotable &= Rewriter.visit(S); |
5001 | | - ++NumUses; |
5002 | | - } |
5003 | | - for (Slice &S : P) { |
5004 | | - Promotable &= Rewriter.visit(&S); |
5005 | | - ++NumUses; |
| 5271 | + // Check whether we can have tree-structured merge. |
| 5272 | + std::optional<SmallVector<Value *, 4>> DeletedValues = |
| 5273 | + Rewriter.rewriteTreeStructuredMerge(P); |
| 5274 | + if (DeletedValues) { |
| 5275 | + NumUses += DeletedValues->size() + 1; |
| 5276 | + for (Value *V : *DeletedValues) |
| 5277 | + DeadInsts.push_back(V); |
| 5278 | + } else { |
| 5279 | + for (Slice *S : P.splitSliceTails()) { |
| 5280 | + Promotable &= Rewriter.visit(S); |
| 5281 | + ++NumUses; |
| 5282 | + } |
| 5283 | + for (Slice &S : P) { |
| 5284 | + Promotable &= Rewriter.visit(&S); |
| 5285 | + ++NumUses; |
| 5286 | + } |
5006 | 5287 | } |
5007 | 5288 |
|
5008 | 5289 | NumAllocaPartitionUses += NumUses; |
|
0 commit comments