Skip to content

Commit 4bc9d29

Browse files
Chengjunpchengjunp
andauthored
[SROA] Use tree-structure merge to remove alloca (#152793)
This patch introduces a new optimization in SROA that handles the pattern where multiple non-overlapping vector `store`s completely fill an `alloca`.

The current approach to handle this pattern introduces many `.vecexpand` and `.vecblend` instructions, which can dramatically slow down compilation when dealing with large `alloca`s built from many small vector `store`s. For example, consider an `alloca` of type `<128 x float>` filled by 64 `store`s of `<2 x float>` each. The current implementation requires:
- 64 `shufflevector`s (`.vecexpand`)
- 64 `select`s (`.vecblend`)
- All operations use masks of size 128
- These operations form a long dependency chain

This kind of IR is both difficult to optimize and slow to compile, particularly impacting the `InstCombine` pass.

This patch introduces a tree-structured merge approach that significantly reduces the number of operations and improves compilation performance.

Key features:
- Detects when vector `store`s completely fill an `alloca` without gaps
- Ensures no loads occur in the middle of the store sequence
- Uses a tree-based approach with `shufflevector`s to merge stored values
- Reduces the number of intermediate operations compared to linear merging
- Eliminates the long dependency chains that hurt optimization

Example transformation:
```
// Before: (stores do not have to be in order)
%alloca = alloca <8 x float>
store <2 x float> %val0, ptr %alloca      ; offset 0-1
store <2 x float> %val2, ptr %alloca+16   ; offset 4-5
store <2 x float> %val1, ptr %alloca+8    ; offset 2-3
store <2 x float> %val3, ptr %alloca+24   ; offset 6-7
%result = load <8 x float>, ptr %alloca

// After (tree-structured merge):
%shuffle0 = shufflevector %val0, %val1, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
%shuffle1 = shufflevector %val2, %val3, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
%result = shufflevector %shuffle0, %shuffle1, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
```

Benefits:
- Logarithmic depth (O(log n)) instead of linear dependency chains
- Fewer total operations for large vectors
- Better optimization opportunities for subsequent passes
- Significant compilation time improvements for large vector patterns

For some large cases, the compile time can be reduced from about 60s to less than 3s.

---------

Co-authored-by: chengjunp <[email protected]>
1 parent b4a17b1 commit 4bc9d29

File tree

3 files changed

+829
-7
lines changed

3 files changed

+829
-7
lines changed

llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 306 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
#include <cstdint>
9292
#include <cstring>
9393
#include <iterator>
94+
#include <queue>
9495
#include <string>
9596
#include <tuple>
9697
#include <utility>
@@ -2667,6 +2668,90 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
26672668
return V;
26682669
}
26692670

2671+
/// This function takes two vector values and combines them into a single vector
2672+
/// by concatenating their elements. The function handles:
2673+
///
2674+
/// 1. Element type mismatch: If either vector's element type differs from
2675+
/// NewAIEltType, the function bitcasts the vector to use NewAIEltType while
2676+
/// preserving the total bit width (adjusting the number of elements
2677+
/// accordingly).
2678+
///
2679+
/// 2. Size mismatch: After transforming the vectors to have the desired element
2680+
/// type, if the two vectors have different numbers of elements, the smaller
2681+
/// vector is extended with poison values to match the size of the larger
2682+
/// vector before concatenation.
2683+
///
2684+
/// 3. Concatenation: The vectors are merged using a shuffle operation that
2685+
/// places all elements of V0 first, followed by all elements of V1.
2686+
///
2687+
/// \param V0 The first vector to merge (must be a vector type)
2688+
/// \param V1 The second vector to merge (must be a vector type)
2689+
/// \param DL The data layout for size calculations
2690+
/// \param NewAIEltTy The desired element type for the result vector
2691+
/// \param Builder IRBuilder for creating new instructions
2692+
/// \return A new vector containing all elements from V0 followed by all
2693+
/// elements from V1
2694+
static Value *mergeTwoVectors(Value *V0, Value *V1, const DataLayout &DL,
2695+
Type *NewAIEltTy, IRBuilder<> &Builder) {
2696+
// V0 and V1 are vectors
2697+
// Create a new vector type with combined elements
2698+
// Use ShuffleVector to concatenate the vectors
2699+
auto *VecType0 = cast<FixedVectorType>(V0->getType());
2700+
auto *VecType1 = cast<FixedVectorType>(V1->getType());
2701+
2702+
// If V0/V1 element types are different from NewAllocaElementType,
2703+
// we need to introduce bitcasts before merging them
2704+
auto BitcastIfNeeded = [&](Value *&V, FixedVectorType *&VecType,
2705+
const char *DebugName) {
2706+
Type *EltType = VecType->getElementType();
2707+
if (EltType != NewAIEltTy) {
2708+
// Calculate new number of elements to maintain same bit width
2709+
unsigned TotalBits =
2710+
VecType->getNumElements() * DL.getTypeSizeInBits(EltType);
2711+
unsigned NewNumElts = TotalBits / DL.getTypeSizeInBits(NewAIEltTy);
2712+
2713+
auto *NewVecType = FixedVectorType::get(NewAIEltTy, NewNumElts);
2714+
V = Builder.CreateBitCast(V, NewVecType);
2715+
VecType = NewVecType;
2716+
LLVM_DEBUG(dbgs() << " bitcast " << DebugName << ": " << *V << "\n");
2717+
}
2718+
};
2719+
2720+
BitcastIfNeeded(V0, VecType0, "V0");
2721+
BitcastIfNeeded(V1, VecType1, "V1");
2722+
2723+
unsigned NumElts0 = VecType0->getNumElements();
2724+
unsigned NumElts1 = VecType1->getNumElements();
2725+
2726+
SmallVector<int, 16> ShuffleMask;
2727+
2728+
if (NumElts0 == NumElts1) {
2729+
for (unsigned i = 0; i < NumElts0 + NumElts1; ++i)
2730+
ShuffleMask.push_back(i);
2731+
} else {
2732+
// If two vectors have different sizes, we need to extend
2733+
// the smaller vector to the size of the larger vector.
2734+
unsigned SmallSize = std::min(NumElts0, NumElts1);
2735+
unsigned LargeSize = std::max(NumElts0, NumElts1);
2736+
bool IsV0Smaller = NumElts0 < NumElts1;
2737+
Value *&ExtendedVec = IsV0Smaller ? V0 : V1;
2738+
SmallVector<int, 16> ExtendMask;
2739+
for (unsigned i = 0; i < SmallSize; ++i)
2740+
ExtendMask.push_back(i);
2741+
for (unsigned i = SmallSize; i < LargeSize; ++i)
2742+
ExtendMask.push_back(PoisonMaskElem);
2743+
ExtendedVec = Builder.CreateShuffleVector(
2744+
ExtendedVec, PoisonValue::get(ExtendedVec->getType()), ExtendMask);
2745+
LLVM_DEBUG(dbgs() << " shufflevector: " << *ExtendedVec << "\n");
2746+
for (unsigned i = 0; i < NumElts0; ++i)
2747+
ShuffleMask.push_back(i);
2748+
for (unsigned i = 0; i < NumElts1; ++i)
2749+
ShuffleMask.push_back(LargeSize + i);
2750+
}
2751+
2752+
return Builder.CreateShuffleVector(V0, V1, ShuffleMask);
2753+
}
2754+
26702755
namespace {
26712756

26722757
/// Visitor to rewrite instructions using a particular slice of an alloca
@@ -2811,6 +2896,213 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
28112896
return CanSROA;
28122897
}
28132898

2899+
/// Attempts to rewrite a partition using tree-structured merge optimization.
2900+
///
2901+
/// This function analyzes a partition to determine if it can be optimized
2902+
/// using a tree-structured merge pattern, where multiple non-overlapping
2903+
/// stores completely fill an alloca. And there is no load from the alloca in
2904+
/// the middle of the stores. Such patterns can be optimized by eliminating
2905+
/// the intermediate stores and directly constructing the final vector by
2906+
/// using shufflevectors.
2907+
///
2908+
/// Example transformation:
2909+
/// Before: (stores do not have to be in order)
2910+
/// %alloca = alloca <8 x float>
2911+
/// store <2 x float> %val0, ptr %alloca ; offset 0-1
2912+
/// store <2 x float> %val2, ptr %alloca+16 ; offset 4-5
2913+
/// store <2 x float> %val1, ptr %alloca+8 ; offset 2-3
2914+
/// store <2 x float> %val3, ptr %alloca+24 ; offset 6-7
2915+
///
2916+
/// After:
2917+
/// %alloca = alloca <8 x float>
2918+
/// %shuffle0 = shufflevector %val0, %val1, <4 x i32> <i32 0, i32 1, i32 2,
2919+
/// i32 3>
2920+
/// %shuffle1 = shufflevector %val2, %val3, <4 x i32> <i32 0, i32 1, i32 2,
2921+
/// i32 3>
2922+
/// %shuffle2 = shufflevector %shuffle0, %shuffle1, <8 x i32> <i32 0, i32 1,
2923+
/// i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
2924+
/// store %shuffle2, ptr %alloca
2925+
///
2926+
/// The optimization looks for partitions that:
2927+
/// 1. Have no overlapping split slice tails
2928+
/// 2. Contain non-overlapping stores that cover the entire alloca
2929+
/// 3. Have exactly one load that reads the complete alloca structure and not
2930+
/// in the middle of the stores (TODO: maybe we can relax the constraint
2931+
/// about reading the entire alloca structure)
2932+
///
2933+
/// \param P The partition to analyze and potentially rewrite
2934+
/// \return An optional vector of values that were deleted during the rewrite
2935+
/// process, or std::nullopt if the partition cannot be optimized
2936+
/// using tree-structured merge
2937+
std::optional<SmallVector<Value *, 4>>
2938+
rewriteTreeStructuredMerge(Partition &P) {
2939+
// No tail slices that overlap with the partition
2940+
if (P.splitSliceTails().size() > 0)
2941+
return std::nullopt;
2942+
2943+
SmallVector<Value *, 4> DeletedValues;
2944+
LoadInst *TheLoad = nullptr;
2945+
2946+
// Structure to hold store information
2947+
struct StoreInfo {
2948+
StoreInst *Store;
2949+
uint64_t BeginOffset;
2950+
uint64_t EndOffset;
2951+
Value *StoredValue;
2952+
StoreInfo(StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val)
2953+
: Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val) {}
2954+
};
2955+
2956+
SmallVector<StoreInfo, 4> StoreInfos;
2957+
2958+
// If the new alloca is a fixed vector type, we use its element type as the
2959+
// allocated element type, otherwise we use i8 as the allocated element
2960+
Type *AllocatedEltTy =
2961+
isa<FixedVectorType>(NewAI.getAllocatedType())
2962+
? cast<FixedVectorType>(NewAI.getAllocatedType())->getElementType()
2963+
: Type::getInt8Ty(NewAI.getContext());
2964+
2965+
// Helper to check if a type is
2966+
// 1. A fixed vector type
2967+
// 2. The element type is not a pointer
2968+
// 3. The element type size is byte-aligned
2969+
// We only handle the cases that the ld/st meet these conditions
2970+
auto IsTypeValidForTreeStructuredMerge = [&](Type *Ty) -> bool {
2971+
auto *FixedVecTy = dyn_cast<FixedVectorType>(Ty);
2972+
return FixedVecTy &&
2973+
DL.getTypeSizeInBits(FixedVecTy->getElementType()) % 8 == 0 &&
2974+
!FixedVecTy->getElementType()->isPointerTy();
2975+
};
2976+
2977+
for (Slice &S : P) {
2978+
auto *User = cast<Instruction>(S.getUse()->getUser());
2979+
if (auto *LI = dyn_cast<LoadInst>(User)) {
2980+
// Do not handle the case if
2981+
// 1. There is more than one load
2982+
// 2. The load is volatile
2983+
// 3. The load does not read the entire alloca structure
2984+
// 4. The load does not meet the conditions in the helper function
2985+
if (TheLoad || !IsTypeValidForTreeStructuredMerge(LI->getType()) ||
2986+
S.beginOffset() != NewAllocaBeginOffset ||
2987+
S.endOffset() != NewAllocaEndOffset || LI->isVolatile())
2988+
return std::nullopt;
2989+
TheLoad = LI;
2990+
} else if (auto *SI = dyn_cast<StoreInst>(User)) {
2991+
// Do not handle the case if
2992+
// 1. The store does not meet the conditions in the helper function
2993+
// 2. The store is volatile
2994+
if (!IsTypeValidForTreeStructuredMerge(
2995+
SI->getValueOperand()->getType()) ||
2996+
SI->isVolatile())
2997+
return std::nullopt;
2998+
StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(),
2999+
SI->getValueOperand());
3000+
} else {
3001+
// If we have instructions other than load and store, we cannot do the
3002+
// tree structured merge
3003+
return std::nullopt;
3004+
}
3005+
}
3006+
// If we do not have any load, we cannot do the tree structured merge
3007+
if (!TheLoad)
3008+
return std::nullopt;
3009+
3010+
// If we do not have multiple stores, we cannot do the tree structured merge
3011+
if (StoreInfos.size() < 2)
3012+
return std::nullopt;
3013+
3014+
// Stores should not overlap and should cover the whole alloca
3015+
// Sort by begin offset
3016+
llvm::sort(StoreInfos, [](const StoreInfo &A, const StoreInfo &B) {
3017+
return A.BeginOffset < B.BeginOffset;
3018+
});
3019+
3020+
// Check for overlaps and coverage
3021+
uint64_t ExpectedStart = NewAllocaBeginOffset;
3022+
for (auto &StoreInfo : StoreInfos) {
3023+
uint64_t BeginOff = StoreInfo.BeginOffset;
3024+
uint64_t EndOff = StoreInfo.EndOffset;
3025+
3026+
// Check for gap or overlap
3027+
if (BeginOff != ExpectedStart)
3028+
return std::nullopt;
3029+
3030+
ExpectedStart = EndOff;
3031+
}
3032+
// Check that stores cover the entire alloca
3033+
if (ExpectedStart != NewAllocaEndOffset)
3034+
return std::nullopt;
3035+
3036+
// Stores should be in the same basic block
3037+
// The load should not be in the middle of the stores
3038+
// Note:
3039+
// If the load is in a different basic block with the stores, we can still
3040+
// do the tree structured merge. This is because we do not have the
3041+
// store->load forwarding here. The merged vector will be stored back to
3042+
// NewAI and the new load will load from NewAI. The forwarding will be
3043+
// handled later when we try to promote NewAI.
3044+
BasicBlock *LoadBB = TheLoad->getParent();
3045+
BasicBlock *StoreBB = StoreInfos[0].Store->getParent();
3046+
3047+
for (auto &StoreInfo : StoreInfos) {
3048+
if (StoreInfo.Store->getParent() != StoreBB)
3049+
return std::nullopt;
3050+
if (LoadBB == StoreBB && !StoreInfo.Store->comesBefore(TheLoad))
3051+
return std::nullopt;
3052+
}
3053+
3054+
// If we reach here, the partition can be merged with a tree structured
3055+
// merge
3056+
LLVM_DEBUG({
3057+
dbgs() << "Tree structured merge rewrite:\n Load: " << *TheLoad
3058+
<< "\n Ordered stores:\n";
3059+
for (auto [i, Info] : enumerate(StoreInfos))
3060+
dbgs() << " [" << i << "] Range[" << Info.BeginOffset << ", "
3061+
<< Info.EndOffset << ") \tStore: " << *Info.Store
3062+
<< "\tValue: " << *Info.StoredValue << "\n";
3063+
});
3064+
3065+
// Instead of having these stores, we merge all the stored values into a
3066+
// vector and store the merged value into the alloca
3067+
std::queue<Value *> VecElements;
3068+
IRBuilder<> Builder(StoreInfos.back().Store);
3069+
for (const auto &Info : StoreInfos) {
3070+
DeletedValues.push_back(Info.Store);
3071+
VecElements.push(Info.StoredValue);
3072+
}
3073+
3074+
LLVM_DEBUG(dbgs() << " Rewrite stores into shufflevectors:\n");
3075+
while (VecElements.size() > 1) {
3076+
const auto NumElts = VecElements.size();
3077+
for ([[maybe_unused]] const auto _ : llvm::seq(NumElts / 2)) {
3078+
Value *V0 = VecElements.front();
3079+
VecElements.pop();
3080+
Value *V1 = VecElements.front();
3081+
VecElements.pop();
3082+
Value *Merged = mergeTwoVectors(V0, V1, DL, AllocatedEltTy, Builder);
3083+
LLVM_DEBUG(dbgs() << " shufflevector: " << *Merged << "\n");
3084+
VecElements.push(Merged);
3085+
}
3086+
if (NumElts % 2 == 1) {
3087+
Value *V = VecElements.front();
3088+
VecElements.pop();
3089+
VecElements.push(V);
3090+
}
3091+
}
3092+
3093+
// Store the merged value into the alloca
3094+
Value *MergedValue = VecElements.front();
3095+
Builder.CreateAlignedStore(MergedValue, &NewAI, getSliceAlign());
3096+
3097+
IRBuilder<> LoadBuilder(TheLoad);
3098+
TheLoad->replaceAllUsesWith(LoadBuilder.CreateAlignedLoad(
3099+
TheLoad->getType(), &NewAI, getSliceAlign(), TheLoad->isVolatile(),
3100+
TheLoad->getName() + ".sroa.new.load"));
3101+
DeletedValues.push_back(TheLoad);
3102+
3103+
return DeletedValues;
3104+
}
3105+
28143106
private:
28153107
// Make sure the other visit overloads are visible.
28163108
using Base::visit;
@@ -4980,13 +5272,20 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
49805272
P.endOffset(), IsIntegerPromotable, VecTy,
49815273
PHIUsers, SelectUsers);
49825274
bool Promotable = true;
4983-
for (Slice *S : P.splitSliceTails()) {
4984-
Promotable &= Rewriter.visit(S);
4985-
++NumUses;
4986-
}
4987-
for (Slice &S : P) {
4988-
Promotable &= Rewriter.visit(&S);
4989-
++NumUses;
5275+
// Check whether we can have tree-structured merge.
5276+
if (auto DeletedValues = Rewriter.rewriteTreeStructuredMerge(P)) {
5277+
NumUses += DeletedValues->size() + 1;
5278+
for (Value *V : *DeletedValues)
5279+
DeadInsts.push_back(V);
5280+
} else {
5281+
for (Slice *S : P.splitSliceTails()) {
5282+
Promotable &= Rewriter.visit(S);
5283+
++NumUses;
5284+
}
5285+
for (Slice &S : P) {
5286+
Promotable &= Rewriter.visit(&S);
5287+
++NumUses;
5288+
}
49905289
}
49915290

49925291
NumAllocaPartitionUses += NumUses;

0 commit comments

Comments
 (0)