Skip to content

Commit a30ca09

Browse files
author
chengjunp
committed
Initial impl of tree structure merge in SROA
1 parent b698927 commit a30ca09

File tree

3 files changed

+910
-7
lines changed

3 files changed

+910
-7
lines changed

llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 288 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
#include <cstdint>
9292
#include <cstring>
9393
#include <iterator>
94+
#include <queue>
9495
#include <string>
9596
#include <tuple>
9697
#include <utility>
@@ -2678,6 +2679,53 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
26782679
return V;
26792680
}
26802681

2682+
static Value *mergeTwoVectors(Value *V0, Value *V1, IRBuilder<> &Builder) {
2683+
assert(V0->getType()->isVectorTy() && V1->getType()->isVectorTy() &&
2684+
"Can not merge two non-vector values");
2685+
2686+
// V0 and V1 are vectors
2687+
// Create a new vector type with combined elements
2688+
// Use ShuffleVector to concatenate the vectors
2689+
auto *VecType0 = cast<FixedVectorType>(V0->getType());
2690+
auto *VecType1 = cast<FixedVectorType>(V1->getType());
2691+
2692+
assert(VecType0->getElementType() == VecType1->getElementType() &&
2693+
"Can not merge two vectors with different element types");
2694+
unsigned NumElts0 = VecType0->getNumElements();
2695+
unsigned NumElts1 = VecType1->getNumElements();
2696+
2697+
SmallVector<int, 16> ShuffleMask;
2698+
2699+
if (NumElts0 == NumElts1) {
2700+
for (unsigned i = 0; i < NumElts0 + NumElts1; ++i)
2701+
ShuffleMask.push_back(i);
2702+
} else {
2703+
// If two vectors have different sizes, we need to extend
2704+
// the smaller vector to the size of the larger vector.
2705+
unsigned SmallSize = std::min(NumElts0, NumElts1);
2706+
unsigned LargeSize = std::max(NumElts0, NumElts1);
2707+
bool IsV0Smaller = NumElts0 < NumElts1;
2708+
Value *SmallVec = IsV0Smaller ? V0 : V1;
2709+
2710+
SmallVector<int, 16> ExtendMask;
2711+
for (unsigned i = 0; i < SmallSize; ++i)
2712+
ExtendMask.push_back(i);
2713+
for (unsigned i = SmallSize; i < LargeSize; ++i)
2714+
ExtendMask.push_back(PoisonMaskElem);
2715+
Value *ExtendedVec = Builder.CreateShuffleVector(
2716+
SmallVec, PoisonValue::get(SmallVec->getType()), ExtendMask);
2717+
LLVM_DEBUG(dbgs() << " shufflevector: " << *ExtendedVec << "\n");
2718+
V0 = IsV0Smaller ? ExtendedVec : V0;
2719+
V1 = IsV0Smaller ? V1 : ExtendedVec;
2720+
for (unsigned i = 0; i < NumElts0; ++i)
2721+
ShuffleMask.push_back(i);
2722+
for (unsigned i = 0; i < NumElts1; ++i)
2723+
ShuffleMask.push_back(LargeSize + i);
2724+
}
2725+
2726+
return Builder.CreateShuffleVector(V0, V1, ShuffleMask);
2727+
}
2728+
26812729
namespace {
26822730

26832731
/// Visitor to rewrite instructions using p particular slice of an alloca
@@ -2822,6 +2870,230 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
28222870
return CanSROA;
28232871
}
28242872

2873+
/// Attempts to rewrite a partition using tree-structured merge optimization.
2874+
///
2875+
/// This function analyzes a partition to determine if it can be optimized
2876+
/// using a tree-structured merge pattern, where multiple non-overlapping
2877+
/// stores completely fill an alloca. And there is no load from the alloca in
2878+
/// the middle of the stores. Such patterns can be optimized by eliminating
2879+
/// the intermediate stores and directly constructing the final vector by
2880+
/// using shufflevectors.
2881+
///
2882+
/// Example transformation:
2883+
/// Before: (stores do not have to be in order)
2884+
/// %alloca = alloca <8 x float>
2885+
/// store <2 x float> %val0, ptr %alloca ; offset 0-1
2886+
/// store <2 x float> %val2, ptr %alloca+16 ; offset 4-5
2887+
/// store <2 x float> %val1, ptr %alloca+8 ; offset 2-3
2888+
/// store <2 x float> %val3, ptr %alloca+24 ; offset 6-7
2889+
///
2890+
/// After:
2891+
/// %alloca = alloca <8 x float>
2892+
/// %shuffle0 = shufflevector %val0, %val1, <4 x i32> <i32 0, i32 1, i32 2,
2893+
/// i32 3>
2894+
/// %shuffle1 = shufflevector %val2, %val3, <4 x i32> <i32 0, i32 1, i32 2,
2895+
/// i32 3>
2896+
/// %shuffle2 = shufflevector %shuffle0, %shuffle1, <8 x i32> <i32 0, i32 1,
2897+
/// i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
2898+
/// store %shuffle2, ptr %alloca
2899+
///
2900+
/// The optimization looks for partitions that:
2901+
/// 1. Have no overlapping split slice tails
2902+
/// 2. Contain non-overlapping stores that cover the entire alloca
2903+
/// 3. Have exactly one load that reads the complete alloca structure and not
2904+
/// in the middle of the stores (TODO: maybe we can relax the constraint
2905+
/// about reading the entire alloca structure)
2906+
///
2907+
/// \param P The partition to analyze and potentially rewrite
2908+
/// \return An optional vector of values that were deleted during the rewrite
2909+
/// process, or std::nullopt if the partition cannot be optimized
2910+
/// using tree-structured merge
2911+
std::optional<SmallVector<Value *, 4>>
2912+
rewriteTreeStructuredMerge(Partition &P) {
2913+
// No tail slices that overlap with the partition
2914+
if (P.splitSliceTails().size() > 0)
2915+
return std::nullopt;
2916+
2917+
SmallVector<Value *, 4> DeletedValues;
2918+
LoadInst *TheLoad = nullptr;
2919+
2920+
// Structure to hold store information
2921+
struct StoreInfo {
2922+
StoreInst *Store;
2923+
uint64_t BeginOffset;
2924+
uint64_t EndOffset;
2925+
Value *StoredValue;
2926+
TypeSize StoredTypeSize = TypeSize::getZero();
2927+
2928+
StoreInfo(StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val,
2929+
TypeSize StoredTypeSize)
2930+
: Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val),
2931+
StoredTypeSize(StoredTypeSize) {}
2932+
};
2933+
2934+
SmallVector<StoreInfo, 4> StoreInfos;
2935+
2936+
// The alloca must be a fixed vector type
2937+
auto *AllocatedTy = NewAI.getAllocatedType();
2938+
if (!isa<FixedVectorType>(AllocatedTy))
2939+
return std::nullopt;
2940+
2941+
Slice *LoadSlice = nullptr;
2942+
Type *LoadElementType = nullptr;
2943+
Type *StoreElementType = nullptr;
2944+
for (Slice &S : P) {
2945+
auto *User = cast<Instruction>(S.getUse()->getUser());
2946+
if (auto *LI = dyn_cast<LoadInst>(User)) {
2947+
// Do not handle the case where there is more than one load
2948+
// TODO: maybe we can handle this case
2949+
if (TheLoad)
2950+
return std::nullopt;
2951+
// If load is not a fixed vector type, we do not handle it
2952+
// If the number of loaded bits is not the same as the new alloca type
2953+
// size, we do not handle it
2954+
auto *FixedVecTy = dyn_cast<FixedVectorType>(LI->getType());
2955+
if (!FixedVecTy)
2956+
return std::nullopt;
2957+
if (DL.getTypeSizeInBits(FixedVecTy) !=
2958+
DL.getTypeSizeInBits(NewAI.getAllocatedType()))
2959+
return std::nullopt;
2960+
LoadElementType = FixedVecTy->getElementType();
2961+
TheLoad = LI;
2962+
LoadSlice = &S;
2963+
} else if (auto *SI = dyn_cast<StoreInst>(User)) {
2964+
// The store needs to be a fixed vector type
2965+
// All the stores should have the same element type
2966+
Type *StoredValueType = SI->getValueOperand()->getType();
2967+
Type *CurrentElementType = nullptr;
2968+
TypeSize StoredTypeSize = TypeSize::getZero();
2969+
if (auto *FixedVecTy = dyn_cast<FixedVectorType>(StoredValueType)) {
2970+
// Fixed vector type - use its element type
2971+
CurrentElementType = FixedVecTy->getElementType();
2972+
StoredTypeSize = DL.getTypeSizeInBits(FixedVecTy);
2973+
} else
2974+
return std::nullopt;
2975+
// Check element type consistency across all stores
2976+
if (StoreElementType && StoreElementType != CurrentElementType)
2977+
return std::nullopt;
2978+
StoreElementType = CurrentElementType;
2979+
StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(),
2980+
SI->getValueOperand(), StoredTypeSize);
2981+
} else {
2982+
// If we have instructions other than load and store, we cannot do the
2983+
// tree structured merge
2984+
return std::nullopt;
2985+
}
2986+
}
2987+
// If we do not have any load, we cannot do the tree structured merge
2988+
if (!TheLoad)
2989+
return std::nullopt;
2990+
2991+
// If we do not have any stores, we cannot do the tree structured merge
2992+
if (StoreInfos.empty())
2993+
return std::nullopt;
2994+
2995+
// The load and store element types should be the same
2996+
if (LoadElementType != StoreElementType)
2997+
return std::nullopt;
2998+
2999+
// The load should cover the whole alloca
3000+
// TODO: maybe we can relax this constraint
3001+
if (!LoadSlice || LoadSlice->beginOffset() != NewAllocaBeginOffset ||
3002+
LoadSlice->endOffset() != NewAllocaEndOffset)
3003+
return std::nullopt;
3004+
3005+
// Stores should not overlap and should cover the whole alloca
3006+
// Sort by begin offset
3007+
llvm::sort(StoreInfos, [](const StoreInfo &A, const StoreInfo &B) {
3008+
return A.BeginOffset < B.BeginOffset;
3009+
});
3010+
3011+
// Check for overlaps and coverage
3012+
uint64_t ExpectedStart = NewAllocaBeginOffset;
3013+
TypeSize TotalStoreBits = TypeSize::getZero();
3014+
Instruction *PrevStore = nullptr;
3015+
for (auto &StoreInfo : StoreInfos) {
3016+
uint64_t BeginOff = StoreInfo.BeginOffset;
3017+
uint64_t EndOff = StoreInfo.EndOffset;
3018+
3019+
// Check for gap or overlap
3020+
if (BeginOff != ExpectedStart)
3021+
return std::nullopt;
3022+
3023+
ExpectedStart = EndOff;
3024+
TotalStoreBits += StoreInfo.StoredTypeSize;
3025+
PrevStore = StoreInfo.Store;
3026+
}
3027+
// Check that stores cover the entire alloca
3028+
// We need check both the end offset and the total store bits
3029+
if (ExpectedStart != NewAllocaEndOffset ||
3030+
TotalStoreBits != DL.getTypeSizeInBits(NewAI.getAllocatedType()))
3031+
return std::nullopt;
3032+
3033+
// Stores should be in the same basic block
3034+
// The load should not be in the middle of the stores
3035+
BasicBlock *LoadBB = TheLoad->getParent();
3036+
BasicBlock *StoreBB = StoreInfos[0].Store->getParent();
3037+
3038+
for (auto &StoreInfo : StoreInfos) {
3039+
if (StoreInfo.Store->getParent() != StoreBB)
3040+
return std::nullopt;
3041+
if (LoadBB == StoreBB && !StoreInfo.Store->comesBefore(TheLoad))
3042+
return std::nullopt;
3043+
}
3044+
3045+
// If we reach here, the partition can be merged with a tree structured
3046+
// merge
3047+
LLVM_DEBUG({
3048+
dbgs() << "Tree structured merge rewrite:\n Load: " << *TheLoad
3049+
<< "\n Ordered stores:\n";
3050+
for (auto [i, Info] : enumerate(StoreInfos))
3051+
dbgs() << " [" << i << "] Range[" << Info.BeginOffset << ", "
3052+
<< Info.EndOffset << ") \tStore: " << *Info.Store
3053+
<< "\tValue: " << *Info.StoredValue << "\n";
3054+
});
3055+
3056+
// Instead of having these stores, we merge all the stored values into a
3057+
// vector and store the merged value into the alloca
3058+
std::queue<Value *> VecElements;
3059+
IRBuilder<> Builder(StoreInfos.back().Store);
3060+
for (const auto &Info : StoreInfos) {
3061+
DeletedValues.push_back(Info.Store);
3062+
VecElements.push(Info.StoredValue);
3063+
}
3064+
3065+
LLVM_DEBUG(dbgs() << " Rewrite stores into shufflevectors:\n");
3066+
while (VecElements.size() > 1) {
3067+
uint64_t NumElts = VecElements.size();
3068+
for (uint64_t i = 0; i < NumElts / 2; i++) {
3069+
Value *V0 = VecElements.front();
3070+
VecElements.pop();
3071+
Value *V1 = VecElements.front();
3072+
VecElements.pop();
3073+
Value *Merged = mergeTwoVectors(V0, V1, Builder);
3074+
LLVM_DEBUG(dbgs() << " shufflevector: " << *Merged << "\n");
3075+
VecElements.push(Merged);
3076+
}
3077+
if (NumElts % 2 == 1) {
3078+
Value *V = VecElements.front();
3079+
VecElements.pop();
3080+
VecElements.push(V);
3081+
}
3082+
}
3083+
3084+
// Store the merged value into the alloca
3085+
Value *MergedValue = VecElements.front();
3086+
Builder.CreateAlignedStore(MergedValue, &NewAI, getSliceAlign());
3087+
3088+
IRBuilder<> LoadBuilder(TheLoad);
3089+
TheLoad->replaceAllUsesWith(LoadBuilder.CreateAlignedLoad(
3090+
TheLoad->getType(), &NewAI, getSliceAlign(), TheLoad->isVolatile(),
3091+
TheLoad->getName() + ".sroa.new.load"));
3092+
DeletedValues.push_back(TheLoad);
3093+
3094+
return DeletedValues;
3095+
}
3096+
28253097
private:
28263098
// Make sure the other visit overloads are visible.
28273099
using Base::visit;
@@ -4996,13 +5268,22 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
49965268
P.endOffset(), IsIntegerPromotable, VecTy,
49975269
PHIUsers, SelectUsers);
49985270
bool Promotable = true;
4999-
for (Slice *S : P.splitSliceTails()) {
5000-
Promotable &= Rewriter.visit(S);
5001-
++NumUses;
5002-
}
5003-
for (Slice &S : P) {
5004-
Promotable &= Rewriter.visit(&S);
5005-
++NumUses;
5271+
// Check whether we can have tree-structured merge.
5272+
std::optional<SmallVector<Value *, 4>> DeletedValues =
5273+
Rewriter.rewriteTreeStructuredMerge(P);
5274+
if (DeletedValues) {
5275+
NumUses += DeletedValues->size() + 1;
5276+
for (Value *V : *DeletedValues)
5277+
DeadInsts.push_back(V);
5278+
} else {
5279+
for (Slice *S : P.splitSliceTails()) {
5280+
Promotable &= Rewriter.visit(S);
5281+
++NumUses;
5282+
}
5283+
for (Slice &S : P) {
5284+
Promotable &= Rewriter.visit(&S);
5285+
++NumUses;
5286+
}
50065287
}
50075288

50085289
NumAllocaPartitionUses += NumUses;

0 commit comments

Comments
 (0)