Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 306 additions & 7 deletions llvm/lib/Transforms/Scalar/SROA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
#include <cstdint>
#include <cstring>
#include <iterator>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
Expand Down Expand Up @@ -2667,6 +2668,90 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
return V;
}

/// Concatenate two fixed-width vectors into a single vector whose element type
/// is \p NewAIEltTy.
///
/// The merge is done in three steps:
///
/// 1. Element type normalization: an input whose element type is not
///    \p NewAIEltTy is bitcast to a vector of \p NewAIEltTy. The element count
///    of the new type is the input's total bit width divided by the bit width
///    of \p NewAIEltTy.
///
/// 2. Length equalization: if the two inputs then have different element
///    counts, the shorter one is widened to the length of the longer one by a
///    shufflevector that appends poison lanes.
///
/// 3. Concatenation: a final shufflevector selects every original lane of V0
///    followed by every original lane of V1. Padding lanes introduced by
///    step 2 are never selected.
///
/// \param V0 The first vector to merge (must have a fixed vector type)
/// \param V1 The second vector to merge (must have a fixed vector type)
/// \param DL The data layout used for the bit width calculations
/// \param NewAIEltTy The element type of the resulting vector
/// \param Builder IRBuilder used to create the bitcasts and shuffles
/// \return A new vector holding all elements of V0 followed by all elements
///         of V1
static Value *mergeTwoVectors(Value *V0, Value *V1, const DataLayout &DL,
                              Type *NewAIEltTy, IRBuilder<> &Builder) {
  auto *Ty0 = cast<FixedVectorType>(V0->getType());
  auto *Ty1 = cast<FixedVectorType>(V1->getType());

  // Reinterpret a vector as a vector of NewAIEltTy. Both the value and its
  // cached type are updated in place; nothing happens if the element type
  // already matches.
  auto RetypeElements = [&](Value *&Vec, FixedVectorType *&Ty,
                            const char *Label) {
    Type *CurEltTy = Ty->getElementType();
    if (CurEltTy == NewAIEltTy)
      return;

    // Derive the new element count from the total bit width of the input.
    unsigned WidthInBits =
        Ty->getNumElements() * DL.getTypeSizeInBits(CurEltTy);
    unsigned RetypedCount = WidthInBits / DL.getTypeSizeInBits(NewAIEltTy);

    Ty = FixedVectorType::get(NewAIEltTy, RetypedCount);
    Vec = Builder.CreateBitCast(Vec, Ty);
    LLVM_DEBUG(dbgs() << " bitcast " << Label << ": " << *Vec << "\n");
  };

  RetypeElements(V0, Ty0, "V0");
  RetypeElements(V1, Ty1, "V1");

  const unsigned Len0 = Ty0->getNumElements();
  const unsigned Len1 = Ty1->getNumElements();
  // Length of each shuffle operand once both have been made equally long.
  const unsigned PaddedLen = std::max(Len0, Len1);

  if (Len0 != Len1) {
    // Widen the shorter operand: keep its lanes, fill the rest with poison.
    Value *&Shorter = Len0 < Len1 ? V0 : V1;
    const unsigned ShortLen = std::min(Len0, Len1);
    SmallVector<int, 16> PadMask;
    for (unsigned Lane = 0; Lane != PaddedLen; ++Lane)
      PadMask.push_back(Lane < ShortLen ? static_cast<int>(Lane)
                                        : PoisonMaskElem);
    Shorter = Builder.CreateShuffleVector(
        Shorter, PoisonValue::get(Shorter->getType()), PadMask);
    LLVM_DEBUG(dbgs() << " shufflevector: " << *Shorter << "\n");
  }

  // Lanes of the second shuffle operand are numbered starting at PaddedLen.
  SmallVector<int, 16> ConcatMask;
  for (unsigned Lane = 0; Lane != Len0; ++Lane)
    ConcatMask.push_back(Lane);
  for (unsigned Lane = 0; Lane != Len1; ++Lane)
    ConcatMask.push_back(PaddedLen + Lane);

  return Builder.CreateShuffleVector(V0, V1, ConcatMask);
}

namespace {

/// Visitor to rewrite instructions using a particular slice of an alloca
Expand Down Expand Up @@ -2811,6 +2896,213 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
return CanSROA;
}

/// Attempts to rewrite a partition using tree-structured merge optimization.
///
/// This function analyzes a partition to determine if it can be optimized
/// using a tree-structured merge pattern, where multiple non-overlapping
/// stores completely fill an alloca and there is no load from the alloca in
/// the middle of the stores. Such patterns can be optimized by eliminating
/// the intermediate stores and directly constructing the final vector by
/// using shufflevectors.
///
/// Example transformation:
/// Before: (stores do not have to be in order)
///   %alloca = alloca <8 x float>
///   store <2 x float> %val0, ptr %alloca        ; offset 0-1
///   store <2 x float> %val2, ptr %alloca+16     ; offset 4-5
///   store <2 x float> %val1, ptr %alloca+8      ; offset 2-3
///   store <2 x float> %val3, ptr %alloca+24     ; offset 6-7
///
/// After:
///   %alloca = alloca <8 x float>
///   %shuffle0 = shufflevector %val0, %val1, <4 x i32> <i32 0, i32 1, i32 2,
///                                                       i32 3>
///   %shuffle1 = shufflevector %val2, %val3, <4 x i32> <i32 0, i32 1, i32 2,
///                                                       i32 3>
///   %shuffle2 = shufflevector %shuffle0, %shuffle1, <8 x i32> <i32 0, i32 1,
///                               i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
///   store %shuffle2, ptr %alloca
///
/// The optimization looks for partitions that:
/// 1. Have no overlapping split slice tails
/// 2. Contain only non-volatile loads and stores of fixed vectors whose
///    element type is byte-sized and not a pointer
/// 3. Contain at least two non-overlapping stores, all in the same basic
///    block, that together cover the entire alloca, each storing a whole
///    number of elements of the new alloca's element type
/// 4. Have exactly one load that reads the complete alloca structure and not
///    in the middle of the stores (TODO: maybe we can relax the constraint
///    about reading the entire alloca structure)
///
/// No IR is modified unless all of the conditions above hold.
///
/// \param P The partition to analyze and potentially rewrite
/// \return An optional vector of the instructions made dead by the rewrite
///         (the original stores and the original load), or std::nullopt if
///         the partition cannot be optimized using tree-structured merge
std::optional<SmallVector<Value *, 4>>
rewriteTreeStructuredMerge(Partition &P) {
  // No tail slices that overlap with the partition
  if (P.splitSliceTails().size() > 0)
    return std::nullopt;

  SmallVector<Value *, 4> DeletedValues;
  LoadInst *TheLoad = nullptr;

  // Structure to hold store information
  struct StoreInfo {
    StoreInst *Store;
    uint64_t BeginOffset;
    uint64_t EndOffset;
    Value *StoredValue;
    StoreInfo(StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val)
        : Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val) {}
  };

  SmallVector<StoreInfo, 4> StoreInfos;

  // If the new alloca is a fixed vector type, we use its element type as the
  // allocated element type, otherwise we use i8 as the allocated element
  Type *AllocatedEltTy =
      isa<FixedVectorType>(NewAI.getAllocatedType())
          ? cast<FixedVectorType>(NewAI.getAllocatedType())->getElementType()
          : Type::getInt8Ty(NewAI.getContext());
  const uint64_t AllocatedEltTyBits =
      DL.getTypeSizeInBits(AllocatedEltTy).getFixedValue();

  // Helper to check if a type is
  //  1. A fixed vector type
  //  2. The element type is not a pointer
  //  3. The element type size is byte-aligned
  // We only handle the cases that the ld/st meet these conditions
  auto IsTypeValidForTreeStructuredMerge = [&](Type *Ty) -> bool {
    auto *FixedVecTy = dyn_cast<FixedVectorType>(Ty);
    return FixedVecTy &&
           DL.getTypeSizeInBits(FixedVecTy->getElementType()) % 8 == 0 &&
           !FixedVecTy->getElementType()->isPointerTy();
  };

  for (Slice &S : P) {
    auto *User = cast<Instruction>(S.getUse()->getUser());
    if (auto *LI = dyn_cast<LoadInst>(User)) {
      // Do not handle the case if
      //  1. There is more than one load
      //  2. The load is volatile
      //  3. The load does not read the entire alloca structure
      //  4. The load does not meet the conditions in the helper function
      if (TheLoad || !IsTypeValidForTreeStructuredMerge(LI->getType()) ||
          S.beginOffset() != NewAllocaBeginOffset ||
          S.endOffset() != NewAllocaEndOffset || LI->isVolatile())
        return std::nullopt;
      TheLoad = LI;
    } else if (auto *SI = dyn_cast<StoreInst>(User)) {
      // Do not handle the case if
      //  1. The store does not meet the conditions in the helper function
      //  2. The store is volatile
      //  3. The total size of the stored vector is not a multiple of the
      //     allocated element type size
      Type *StoredTy = SI->getValueOperand()->getType();
      if (!IsTypeValidForTreeStructuredMerge(StoredTy) || SI->isVolatile())
        return std::nullopt;
      // mergeTwoVectors re-types each stored vector to AllocatedEltTy by
      // dividing its bit width by the element size. If that division is not
      // exact, the resulting bitcast would change the bit width, which is
      // invalid IR, so reject such stores here.
      auto *StoredVecTy = cast<FixedVectorType>(StoredTy);
      const uint64_t StoredBits =
          StoredVecTy->getNumElements() *
          DL.getTypeSizeInBits(StoredVecTy->getElementType()).getFixedValue();
      if (StoredBits % AllocatedEltTyBits != 0)
        return std::nullopt;
      StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(),
                              SI->getValueOperand());
    } else {
      // If we have instructions other than load and store, we cannot do the
      // tree structured merge
      return std::nullopt;
    }
  }
  // If we do not have any load, we cannot do the tree structured merge
  if (!TheLoad)
    return std::nullopt;

  // If we do not have multiple stores, we cannot do the tree structured merge
  if (StoreInfos.size() < 2)
    return std::nullopt;

  // Stores should not overlap and should cover the whole alloca
  // Sort by begin offset
  llvm::sort(StoreInfos, [](const StoreInfo &A, const StoreInfo &B) {
    return A.BeginOffset < B.BeginOffset;
  });

  // Check for overlaps and coverage: each store must begin exactly where the
  // previous one ended, starting at the beginning of the alloca.
  uint64_t ExpectedStart = NewAllocaBeginOffset;
  for (const auto &Info : StoreInfos) {
    // Check for gap or overlap
    if (Info.BeginOffset != ExpectedStart)
      return std::nullopt;

    ExpectedStart = Info.EndOffset;
  }
  // Check that stores cover the entire alloca
  if (ExpectedStart != NewAllocaEndOffset)
    return std::nullopt;

  // Stores should be in the same basic block
  // The load should not be in the middle of the stores
  // Note:
  // If the load is in a different basic block with the stores, we can still
  // do the tree structured merge. This is because we do not have the
  // store->load forwarding here. The merged vector will be stored back to
  // NewAI and the new load will load from NewAI. The forwarding will be
  // handled later when we try to promote NewAI.
  BasicBlock *LoadBB = TheLoad->getParent();
  BasicBlock *StoreBB = StoreInfos[0].Store->getParent();

  for (const auto &Info : StoreInfos) {
    if (Info.Store->getParent() != StoreBB)
      return std::nullopt;
    if (LoadBB == StoreBB && !Info.Store->comesBefore(TheLoad))
      return std::nullopt;
  }

  // If we reach here, the partition can be merged with a tree structured
  // merge
  LLVM_DEBUG({
    dbgs() << "Tree structured merge rewrite:\n Load: " << *TheLoad
           << "\n Ordered stores:\n";
    for (auto [i, Info] : enumerate(StoreInfos))
      dbgs() << " [" << i << "] Range[" << Info.BeginOffset << ", "
             << Info.EndOffset << ") \tStore: " << *Info.Store
             << "\tValue: " << *Info.StoredValue << "\n";
  });

  // StoreInfos is sorted by offset, which is not necessarily program order.
  // The merged value uses every stored value, so it must be built at a point
  // where all of them are already defined. Each stored value is available at
  // its own store, hence the last store in program order is a safe anchor.
  // All stores are in StoreBB (checked above), so comesBefore is well
  // defined for any pair of them.
  StoreInst *LastStore = StoreInfos.front().Store;
  for (const auto &Info : StoreInfos)
    if (LastStore->comesBefore(Info.Store))
      LastStore = Info.Store;

  // Instead of having these stores, we merge all the stored values into a
  // vector and store the merged value into the alloca
  std::queue<Value *> VecElements;
  IRBuilder<> Builder(LastStore);
  for (const auto &Info : StoreInfos) {
    DeletedValues.push_back(Info.Store);
    VecElements.push(Info.StoredValue);
  }

  LLVM_DEBUG(dbgs() << " Rewrite stores into shufflevectors:\n");
  // Each round merges adjacent pairs, roughly halving the queue. The queue is
  // kept in offset order throughout.
  while (VecElements.size() > 1) {
    const auto NumElts = VecElements.size();
    for ([[maybe_unused]] const auto _ : llvm::seq(NumElts / 2)) {
      Value *V0 = VecElements.front();
      VecElements.pop();
      Value *V1 = VecElements.front();
      VecElements.pop();
      Value *Merged = mergeTwoVectors(V0, V1, DL, AllocatedEltTy, Builder);
      LLVM_DEBUG(dbgs() << " shufflevector: " << *Merged << "\n");
      VecElements.push(Merged);
    }
    // With an odd count the unpaired element is now at the front, ahead of
    // the merged values of this round. Rotate it to the back so that the
    // queue stays in offset order for the next round.
    if (NumElts % 2 == 1) {
      Value *V = VecElements.front();
      VecElements.pop();
      VecElements.push(V);
    }
  }

  // Store the merged value into the alloca
  Value *MergedValue = VecElements.front();
  Builder.CreateAlignedStore(MergedValue, &NewAI, getSliceAlign());

  // Replace the original load with a load of the whole new alloca, emitted at
  // the position of the original load.
  IRBuilder<> LoadBuilder(TheLoad);
  TheLoad->replaceAllUsesWith(LoadBuilder.CreateAlignedLoad(
      TheLoad->getType(), &NewAI, getSliceAlign(), TheLoad->isVolatile(),
      TheLoad->getName() + ".sroa.new.load"));
  DeletedValues.push_back(TheLoad);

  return DeletedValues;
}

private:
// Make sure the other visit overloads are visible.
using Base::visit;
Expand Down Expand Up @@ -4981,13 +5273,20 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
P.endOffset(), IsIntegerPromotable, VecTy,
PHIUsers, SelectUsers);
bool Promotable = true;
for (Slice *S : P.splitSliceTails()) {
Promotable &= Rewriter.visit(S);
++NumUses;
}
for (Slice &S : P) {
Promotable &= Rewriter.visit(&S);
++NumUses;
// Check whether we can have tree-structured merge.
if (auto DeletedValues = Rewriter.rewriteTreeStructuredMerge(P)) {
NumUses += DeletedValues->size() + 1;
for (Value *V : *DeletedValues)
DeadInsts.push_back(V);
} else {
for (Slice *S : P.splitSliceTails()) {
Promotable &= Rewriter.visit(S);
++NumUses;
}
for (Slice &S : P) {
Promotable &= Rewriter.visit(&S);
++NumUses;
}
}

NumAllocaPartitionUses += NumUses;
Expand Down
Loading