|
91 | 91 | #include <cstdint>
|
92 | 92 | #include <cstring>
|
93 | 93 | #include <iterator>
|
| 94 | +#include <queue> |
94 | 95 | #include <string>
|
95 | 96 | #include <tuple>
|
96 | 97 | #include <utility>
|
@@ -2667,6 +2668,90 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
|
2667 | 2668 | return V;
|
2668 | 2669 | }
|
2669 | 2670 |
|
| 2671 | +/// Merge two vector values into a single vector by concatenating their |
| 2672 | +/// elements. The function handles: |
| 2673 | +/// |
| 2674 | +/// 1. Element type mismatch: If either vector's element type differs from |
| 2675 | +/// NewAIEltTy, the function bitcasts that vector to use NewAIEltTy while |
| 2676 | +/// preserving the total bit width (adjusting the number of elements |
| 2677 | +/// accordingly). |
| 2678 | +/// |
| 2679 | +/// 2. Size mismatch: After transforming the vectors to have the desired element |
| 2680 | +/// type, if the two vectors have different numbers of elements, the smaller |
| 2681 | +/// vector is extended with poison values to match the size of the larger |
| 2682 | +/// vector before concatenation. |
| 2683 | +/// |
| 2684 | +/// 3. Concatenation: The vectors are merged using a shuffle operation that |
| 2685 | +/// places all elements of V0 first, followed by all elements of V1. |
| 2686 | +/// |
| 2687 | +/// \param V0 The first vector to merge (must be a fixed vector type) |
| 2688 | +/// \param V1 The second vector to merge (must be a fixed vector type) |
| 2689 | +/// \param DL The data layout for size calculations |
| 2690 | +/// \param NewAIEltTy The desired element type for the result vector |
| 2691 | +/// \param Builder IRBuilder for creating new instructions |
| 2692 | +/// \return A new vector containing all elements from V0 followed by all |
| 2693 | +/// elements from V1 |
| 2694 | +static Value *mergeTwoVectors(Value *V0, Value *V1, const DataLayout &DL, |
| 2695 | + Type *NewAIEltTy, IRBuilder<> &Builder) { |
| 2696 | + // V0 and V1 are vectors |
| 2697 | + // Create a new vector type with combined elements |
| 2698 | + // Use ShuffleVector to concatenate the vectors |
| 2699 | + auto *VecType0 = cast<FixedVectorType>(V0->getType()); |
| 2700 | + auto *VecType1 = cast<FixedVectorType>(V1->getType()); |
| 2701 | + |
| 2702 | + // If V0/V1 element types are different from NewAIEltTy, |
| 2703 | + // we need to introduce bitcasts before merging them. |
| 2704 | + auto BitcastIfNeeded = [&](Value *&V, FixedVectorType *&VecType, |
| 2705 | + const char *DebugName) { |
| 2706 | + Type *EltType = VecType->getElementType(); |
| 2707 | + if (EltType != NewAIEltTy) { |
| 2708 | + // Calculate new number of elements to maintain same bit width |
| 2709 | + unsigned TotalBits = |
| 2710 | + VecType->getNumElements() * DL.getTypeSizeInBits(EltType); |
| 2711 | + unsigned NewNumElts = TotalBits / DL.getTypeSizeInBits(NewAIEltTy); |
| 2712 | + |
| 2713 | + auto *NewVecType = FixedVectorType::get(NewAIEltTy, NewNumElts); |
| 2714 | + V = Builder.CreateBitCast(V, NewVecType); |
| 2715 | + VecType = NewVecType; |
| 2716 | + LLVM_DEBUG(dbgs() << " bitcast " << DebugName << ": " << *V << "\n"); |
| 2717 | + } |
| 2718 | + }; |
| 2719 | + |
| 2720 | + BitcastIfNeeded(V0, VecType0, "V0"); |
| 2721 | + BitcastIfNeeded(V1, VecType1, "V1"); |
| 2722 | + |
| 2723 | + unsigned NumElts0 = VecType0->getNumElements(); |
| 2724 | + unsigned NumElts1 = VecType1->getNumElements(); |
| 2725 | + |
| 2726 | + SmallVector<int, 16> ShuffleMask; |
| 2727 | + |
| 2728 | + if (NumElts0 == NumElts1) { |
| 2729 | + for (unsigned i = 0; i < NumElts0 + NumElts1; ++i) |
| 2730 | + ShuffleMask.push_back(i); |
| 2731 | + } else { |
| 2732 | + // If two vectors have different sizes, we need to extend |
| 2733 | + // the smaller vector to the size of the larger vector. |
| 2734 | + unsigned SmallSize = std::min(NumElts0, NumElts1); |
| 2735 | + unsigned LargeSize = std::max(NumElts0, NumElts1); |
| 2736 | + bool IsV0Smaller = NumElts0 < NumElts1; |
| 2737 | + Value *&ExtendedVec = IsV0Smaller ? V0 : V1; |
| 2738 | + SmallVector<int, 16> ExtendMask; |
| 2739 | + for (unsigned i = 0; i < SmallSize; ++i) |
| 2740 | + ExtendMask.push_back(i); |
| 2741 | + for (unsigned i = SmallSize; i < LargeSize; ++i) |
| 2742 | + ExtendMask.push_back(PoisonMaskElem); |
| 2743 | + ExtendedVec = Builder.CreateShuffleVector( |
| 2744 | + ExtendedVec, PoisonValue::get(ExtendedVec->getType()), ExtendMask); |
| 2745 | + LLVM_DEBUG(dbgs() << " shufflevector: " << *ExtendedVec << "\n"); |
| 2746 | + for (unsigned i = 0; i < NumElts0; ++i) |
| 2747 | + ShuffleMask.push_back(i); |
| 2748 | + for (unsigned i = 0; i < NumElts1; ++i) |
| 2749 | + ShuffleMask.push_back(LargeSize + i); |
| 2750 | + } |
| 2751 | + |
| 2752 | + return Builder.CreateShuffleVector(V0, V1, ShuffleMask); |
| 2753 | +} |
| 2754 | + |
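To make the behaviour of mergeTwoVectors concrete, here is a hedged sketch (hypothetical operands, not taken from the patch or its tests) of the IR it would emit for a <2 x i16> %a and a <4 x i8> %b when NewAIEltTy is i8: %a is first bitcast to <4 x i8> (same 32 bits), then the two vectors are concatenated with a single shufflevector. If the element counts also differed, the smaller operand would first be widened with poison lanes as described above.

```llvm
; Hypothetical operands: %a : <2 x i16>, %b : <4 x i8>, NewAIEltTy == i8.
; Element-type mismatch: reinterpret %a's 32 bits as four i8 elements.
%a.cast = bitcast <2 x i16> %a to <4 x i8>
; Concatenation: all elements of %a.cast followed by all elements of %b.
%merged = shufflevector <4 x i8> %a.cast, <4 x i8> %b,
                        <8 x i32> <i32 0, i32 1, i32 2, i32 3,
                                   i32 4, i32 5, i32 6, i32 7>
; %merged has type <8 x i8>.
```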
2670 | 2755 | namespace {
|
2671 | 2756 |
|
2672 | 2757 | /// Visitor to rewrite instructions using a particular slice of an alloca
|
@@ -2811,6 +2896,213 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
|
2811 | 2896 | return CanSROA;
|
2812 | 2897 | }
|
2813 | 2898 |
|
| 2899 | + /// Attempts to rewrite a partition using tree-structured merge optimization. |
| 2900 | + /// |
| 2901 | + /// This function analyzes a partition to determine if it can be optimized |
| 2902 | + /// using a tree-structured merge pattern, where multiple non-overlapping |
| 2903 | + /// stores completely fill the alloca and no load reads from the alloca in |
| 2904 | + /// the middle of those stores. Such patterns can be optimized by eliminating |
| 2905 | + /// the intermediate stores and directly constructing the final vector with |
| 2906 | + /// shufflevectors. |
| 2907 | + /// |
| 2908 | + /// Example transformation: |
| 2909 | + /// Before: (stores do not have to be in order) |
| 2910 | + /// %alloca = alloca <8 x float> |
| 2911 | + /// store <2 x float> %val0, ptr %alloca ; offset 0-1 |
| 2912 | + /// store <2 x float> %val2, ptr %alloca+16 ; offset 4-5 |
| 2913 | + /// store <2 x float> %val1, ptr %alloca+8 ; offset 2-3 |
| 2914 | + /// store <2 x float> %val3, ptr %alloca+24 ; offset 6-7 |
| 2915 | + /// |
| 2916 | + /// After: |
| 2917 | + /// %alloca = alloca <8 x float> |
| 2918 | + /// %shuffle0 = shufflevector %val0, %val1, <4 x i32> <i32 0, i32 1, i32 2, |
| 2919 | + /// i32 3> |
| 2920 | + /// %shuffle1 = shufflevector %val2, %val3, <4 x i32> <i32 0, i32 1, i32 2, |
| 2921 | + /// i32 3> |
| 2922 | + /// %shuffle2 = shufflevector %shuffle0, %shuffle1, <8 x i32> <i32 0, i32 1, |
| 2923 | + /// i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> |
| 2924 | + /// store %shuffle2, ptr %alloca |
| 2925 | + /// |
| 2926 | + /// The optimization looks for partitions that: |
| 2927 | + /// 1. Have no overlapping split slice tails |
| 2928 | + /// 2. Contain non-overlapping stores that cover the entire alloca |
| 2929 | + /// 3. Have exactly one load that reads the complete alloca structure and is |
| 2930 | + /// not in the middle of the stores (TODO: maybe we can relax the |
| 2931 | + /// constraint about reading the entire alloca structure) |
| 2932 | + /// |
| 2933 | + /// \param P The partition to analyze and potentially rewrite |
| 2934 | + /// \return An optional vector of values made dead by the rewrite, for the |
| 2935 | + /// caller to delete, or std::nullopt if the partition cannot be |
| 2936 | + /// optimized using a tree-structured merge |
| 2937 | + std::optional<SmallVector<Value *, 4>> |
| 2938 | + rewriteTreeStructuredMerge(Partition &P) { |
| 2939 | + // No tail slices that overlap with the partition |
| 2940 | + if (P.splitSliceTails().size() > 0) |
| 2941 | + return std::nullopt; |
| 2942 | + |
| 2943 | + SmallVector<Value *, 4> DeletedValues; |
| 2944 | + LoadInst *TheLoad = nullptr; |
| 2945 | + |
| 2946 | + // Structure to hold store information |
| 2947 | + struct StoreInfo { |
| 2948 | + StoreInst *Store; |
| 2949 | + uint64_t BeginOffset; |
| 2950 | + uint64_t EndOffset; |
| 2951 | + Value *StoredValue; |
| 2952 | + StoreInfo(StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val) |
| 2953 | + : Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val) {} |
| 2954 | + }; |
| 2955 | + |
| 2956 | + SmallVector<StoreInfo, 4> StoreInfos; |
| 2957 | + |
| 2958 | + // If the new alloca is a fixed vector type, we use its element type as the |
| 2959 | + // allocated element type; otherwise we use i8 as the allocated element type. |
| 2960 | + Type *AllocatedEltTy = |
| 2961 | + isa<FixedVectorType>(NewAI.getAllocatedType()) |
| 2962 | + ? cast<FixedVectorType>(NewAI.getAllocatedType())->getElementType() |
| 2963 | + : Type::getInt8Ty(NewAI.getContext()); |
| 2964 | + |
| 2965 | + // Helper to check that a type is |
| 2966 | + // 1. A fixed vector type |
| 2967 | + // 2. A vector whose element type is not a pointer |
| 2968 | + // 3. A vector whose element size is a whole number of bytes |
| 2969 | + // We only handle loads/stores whose types meet these conditions. |
| 2970 | + auto IsTypeValidForTreeStructuredMerge = [&](Type *Ty) -> bool { |
| 2971 | + auto *FixedVecTy = dyn_cast<FixedVectorType>(Ty); |
| 2972 | + return FixedVecTy && |
| 2973 | + DL.getTypeSizeInBits(FixedVecTy->getElementType()) % 8 == 0 && |
| 2974 | + !FixedVecTy->getElementType()->isPointerTy(); |
| 2975 | + }; |
| 2976 | + |
| 2977 | + for (Slice &S : P) { |
| 2978 | + auto *User = cast<Instruction>(S.getUse()->getUser()); |
| 2979 | + if (auto *LI = dyn_cast<LoadInst>(User)) { |
| 2980 | + // Do not handle the case if |
| 2981 | + // 1. There is more than one load |
| 2982 | + // 2. The load is volatile |
| 2983 | + // 3. The load does not read the entire alloca structure |
| 2984 | + // 4. The load does not meet the conditions in the helper function |
| 2985 | + if (TheLoad || !IsTypeValidForTreeStructuredMerge(LI->getType()) || |
| 2986 | + S.beginOffset() != NewAllocaBeginOffset || |
| 2987 | + S.endOffset() != NewAllocaEndOffset || LI->isVolatile()) |
| 2988 | + return std::nullopt; |
| 2989 | + TheLoad = LI; |
| 2990 | + } else if (auto *SI = dyn_cast<StoreInst>(User)) { |
| 2991 | + // Do not handle the case if |
| 2992 | + // 1. The store does not meet the conditions in the helper function |
| 2993 | + // 2. The store is volatile |
| 2994 | + if (!IsTypeValidForTreeStructuredMerge( |
| 2995 | + SI->getValueOperand()->getType()) || |
| 2996 | + SI->isVolatile()) |
| 2997 | + return std::nullopt; |
| 2998 | + StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(), |
| 2999 | + SI->getValueOperand()); |
| 3000 | + } else { |
| 3001 | + // If we have instructions other than loads and stores, we cannot do the |
| 3002 | + // tree-structured merge |
| 3003 | + return std::nullopt; |
| 3004 | + } |
| 3005 | + } |
| 3006 | + // If there is no load, we cannot do the tree-structured merge |
| 3007 | + if (!TheLoad) |
| 3008 | + return std::nullopt; |
| 3009 | + |
| 3010 | + // If there are fewer than two stores, there is nothing to merge |
| 3011 | + if (StoreInfos.size() < 2) |
| 3012 | + return std::nullopt; |
| 3013 | + |
| 3014 | + // Stores should not overlap and should cover the whole alloca |
| 3015 | + // Sort by begin offset |
| 3016 | + llvm::sort(StoreInfos, [](const StoreInfo &A, const StoreInfo &B) { |
| 3017 | + return A.BeginOffset < B.BeginOffset; |
| 3018 | + }); |
| 3019 | + |
| 3020 | + // Check for overlaps and coverage |
| 3021 | + uint64_t ExpectedStart = NewAllocaBeginOffset; |
| 3022 | + for (auto &StoreInfo : StoreInfos) { |
| 3023 | + uint64_t BeginOff = StoreInfo.BeginOffset; |
| 3024 | + uint64_t EndOff = StoreInfo.EndOffset; |
| 3025 | + |
| 3026 | + // Check for gap or overlap |
| 3027 | + if (BeginOff != ExpectedStart) |
| 3028 | + return std::nullopt; |
| 3029 | + |
| 3030 | + ExpectedStart = EndOff; |
| 3031 | + } |
| 3032 | + // Check that stores cover the entire alloca |
| 3033 | + if (ExpectedStart != NewAllocaEndOffset) |
| 3034 | + return std::nullopt; |
| 3035 | + |
| 3036 | + // Stores should be in the same basic block. |
| 3037 | + // The load should not be in the middle of the stores. |
| 3038 | + // Note: |
| 3039 | + // If the load is in a different basic block than the stores, we can still |
| 3040 | + // do the tree-structured merge. This is because we do not perform |
| 3041 | + // store-to-load forwarding here. The merged vector will be stored back to |
| 3042 | + // NewAI and the new load will load from NewAI. The forwarding will be |
| 3043 | + // handled later when we try to promote NewAI. |
| 3044 | + BasicBlock *LoadBB = TheLoad->getParent(); |
| 3045 | + BasicBlock *StoreBB = StoreInfos[0].Store->getParent(); |
| 3046 | + |
| 3047 | + for (auto &StoreInfo : StoreInfos) { |
| 3048 | + if (StoreInfo.Store->getParent() != StoreBB) |
| 3049 | + return std::nullopt; |
| 3050 | + if (LoadBB == StoreBB && !StoreInfo.Store->comesBefore(TheLoad)) |
| 3051 | + return std::nullopt; |
| 3052 | + } |
| 3053 | + |
| 3054 | + // If we reach here, the partition can be rewritten with a tree-structured |
| 3055 | + // merge. |
| 3056 | + LLVM_DEBUG({ |
| 3057 | + dbgs() << "Tree structured merge rewrite:\n Load: " << *TheLoad |
| 3058 | + << "\n Ordered stores:\n"; |
| 3059 | + for (auto [i, Info] : enumerate(StoreInfos)) |
| 3060 | + dbgs() << " [" << i << "] Range[" << Info.BeginOffset << ", " |
| 3061 | + << Info.EndOffset << ") \tStore: " << *Info.Store |
| 3062 | + << "\tValue: " << *Info.StoredValue << "\n"; |
| 3063 | + }); |
| 3064 | + |
| 3065 | + // Instead of keeping these stores, we merge all the stored values into a |
| 3066 | + // single vector and store the merged value into the alloca. |
| 3067 | + std::queue<Value *> VecElements; |
| 3068 | + IRBuilder<> Builder(StoreInfos.back().Store); |
| 3069 | + for (const auto &Info : StoreInfos) { |
| 3070 | + DeletedValues.push_back(Info.Store); |
| 3071 | + VecElements.push(Info.StoredValue); |
| 3072 | + } |
| 3073 | + |
| 3074 | + LLVM_DEBUG(dbgs() << " Rewrite stores into shufflevectors:\n"); |
| 3075 | + while (VecElements.size() > 1) { |
| 3076 | + const auto NumElts = VecElements.size(); |
| 3077 | + for ([[maybe_unused]] const auto _ : llvm::seq(NumElts / 2)) { |
| 3078 | + Value *V0 = VecElements.front(); |
| 3079 | + VecElements.pop(); |
| 3080 | + Value *V1 = VecElements.front(); |
| 3081 | + VecElements.pop(); |
| 3082 | + Value *Merged = mergeTwoVectors(V0, V1, DL, AllocatedEltTy, Builder); |
| 3083 | + LLVM_DEBUG(dbgs() << " shufflevector: " << *Merged << "\n"); |
| 3084 | + VecElements.push(Merged); |
| 3085 | + } |
| 3086 | + if (NumElts % 2 == 1) { |
| 3087 | + Value *V = VecElements.front(); |
| 3088 | + VecElements.pop(); |
| 3089 | + VecElements.push(V); |
| 3090 | + } |
| 3091 | + } |
| 3092 | + |
| 3093 | + // Store the merged value into the alloca |
| 3094 | + Value *MergedValue = VecElements.front(); |
| 3095 | + Builder.CreateAlignedStore(MergedValue, &NewAI, getSliceAlign()); |
| 3096 | + |
| 3097 | + IRBuilder<> LoadBuilder(TheLoad); |
| 3098 | + TheLoad->replaceAllUsesWith(LoadBuilder.CreateAlignedLoad( |
| 3099 | + TheLoad->getType(), &NewAI, getSliceAlign(), TheLoad->isVolatile(), |
| 3100 | + TheLoad->getName() + ".sroa.new.load")); |
| 3101 | + DeletedValues.push_back(TheLoad); |
| 3102 | + |
| 3103 | + return DeletedValues; |
| 3104 | + } |
| 3105 | + |
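As a hedged illustration of the queue-based loop above (hypothetical IR, not from the patch's tests): with three <2 x float> stores filling an alloca of <6 x float>, the first pass merges the first pair and carries the odd value to the next pass; the second pass then merges the remainder, taking the poison-widening path in mergeTwoVectors because the operand sizes differ.

```llvm
; Pass 1: merge %val0 and %val1; %val2 is carried over to the next pass.
%m0 = shufflevector <2 x float> %val0, <2 x float> %val1,
                    <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; Pass 2: %val2 is widened with poison lanes to match %m0, then concatenated.
%v2.ext = shufflevector <2 x float> %val2, <2 x float> poison,
                        <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
%m1 = shufflevector <4 x float> %m0, <4 x float> %v2.ext,
                    <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
; The single merged store replaces the three original stores.
store <6 x float> %m1, ptr %alloca
```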
2814 | 3106 | private:
|
2815 | 3107 | // Make sure the other visit overloads are visible.
|
2816 | 3108 | using Base::visit;
|
@@ -4980,13 +5272,20 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
|
4980 | 5272 | P.endOffset(), IsIntegerPromotable, VecTy,
|
4981 | 5273 | PHIUsers, SelectUsers);
|
4982 | 5274 | bool Promotable = true;
|
4983 | | - for (Slice *S : P.splitSliceTails()) { |
4984 | | - Promotable &= Rewriter.visit(S); |
4985 | | - ++NumUses; |
4986 | | - } |
4987 | | - for (Slice &S : P) { |
4988 | | - Promotable &= Rewriter.visit(&S); |
4989 | | - ++NumUses; |
| 5275 | + // Check whether we can do a tree-structured merge. |
| 5276 | + if (auto DeletedValues = Rewriter.rewriteTreeStructuredMerge(P)) { |
| 5277 | + NumUses += DeletedValues->size() + 1; |
| 5278 | + for (Value *V : *DeletedValues) |
| 5279 | + DeadInsts.push_back(V); |
| 5280 | + } else { |
| 5281 | + for (Slice *S : P.splitSliceTails()) { |
| 5282 | + Promotable &= Rewriter.visit(S); |
| 5283 | + ++NumUses; |
| 5284 | + } |
| 5285 | + for (Slice &S : P) { |
| 5286 | + Promotable &= Rewriter.visit(&S); |
| 5287 | + ++NumUses; |
| 5288 | + } |
4990 | 5289 | }
|
4991 | 5290 |
|
4992 | 5291 | NumAllocaPartitionUses += NumUses;
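For context on when the fallback path above still runs, here is a hedged example (hypothetical IR, not from the patch) of a partition that rewriteTreeStructuredMerge rejects: the load sits between the stores in the same basic block, so the function returns std::nullopt and the slices are visited individually as before.

```llvm
%alloca = alloca <4 x float>
store <2 x float> %lo, ptr %alloca
; This load reads the alloca before it is fully written, so the
; "no load in the middle of the stores" check fails.
%peek = load <4 x float>, ptr %alloca
%hi.ptr = getelementptr inbounds i8, ptr %alloca, i64 8
store <2 x float> %hi, ptr %hi.ptr
```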
|
|