Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
328 changes: 321 additions & 7 deletions llvm/lib/Transforms/Scalar/SROA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
#include <cstdint>
#include <cstring>
#include <iterator>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
Expand Down Expand Up @@ -2667,6 +2668,96 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
return V;
}

/// This function takes two vector values and combines them into a single vector
/// by concatenating their elements. The function handles:
///
/// 1. Element type mismatch: If either vector's element type differs from
/// NewAIEltType, the function bitcasts the vector to use NewAIEltType while
/// preserving the total bit width (adjusting the number of elements
/// accordingly).
///
/// 2. Size mismatch: After transforming the vectors to have the desired element
/// type, if the two vectors have different numbers of elements, the smaller
/// vector is extended with poison values to match the size of the larger
/// vector before concatenation.
///
/// 3. Concatenation: The vectors are merged using a shuffle operation that
/// places all elements of V0 first, followed by all elements of V1.
///
/// \param V0 The first vector to merge (must be a vector type)
/// \param V1 The second vector to merge (must be a vector type)
/// \param DL The data layout for size calculations
/// \param NewAIEltTy The desired element type for the result vector
/// \param Builder IRBuilder for creating new instructions
/// \return A new vector containing all elements from V0 followed by all
/// elements from V1
static Value *mergeTwoVectors(Value *V0, Value *V1, const DataLayout &DL,
Type *NewAIEltTy, IRBuilder<> &Builder) {
assert(V0->getType()->isVectorTy() && V1->getType()->isVectorTy() &&
"Can not merge two non-vector values");

// V0 and V1 are vectors
// Create a new vector type with combined elements
// Use ShuffleVector to concatenate the vectors
auto *VecType0 = cast<FixedVectorType>(V0->getType());
auto *VecType1 = cast<FixedVectorType>(V1->getType());

// If V0/V1 element types are different from NewAllocaElementType,
// we need to introduce bitcasts before merging them
auto BitcastIfNeeded = [&](Value *&V, FixedVectorType *&VecType,
const char *DebugName) {
Type *EltType = VecType->getElementType();
if (EltType != NewAIEltTy) {
// Calculate new number of elements to maintain same bit width
unsigned TotalBits =
VecType->getNumElements() * DL.getTypeSizeInBits(EltType);
unsigned NewNumElts = TotalBits / DL.getTypeSizeInBits(NewAIEltTy);

auto *NewVecType = FixedVectorType::get(NewAIEltTy, NewNumElts);
V = Builder.CreateBitCast(V, NewVecType);
VecType = NewVecType;
LLVM_DEBUG(dbgs() << " bitcast " << DebugName << ": " << *V << "\n");
}
};

BitcastIfNeeded(V0, VecType0, "V0");
BitcastIfNeeded(V1, VecType1, "V1");

unsigned NumElts0 = VecType0->getNumElements();
unsigned NumElts1 = VecType1->getNumElements();

SmallVector<int, 16> ShuffleMask;

if (NumElts0 == NumElts1) {
for (unsigned i = 0; i < NumElts0 + NumElts1; ++i)
ShuffleMask.push_back(i);
} else {
// If two vectors have different sizes, we need to extend
// the smaller vector to the size of the larger vector.
unsigned SmallSize = std::min(NumElts0, NumElts1);
unsigned LargeSize = std::max(NumElts0, NumElts1);
bool IsV0Smaller = NumElts0 < NumElts1;
Value *SmallVec = IsV0Smaller ? V0 : V1;

SmallVector<int, 16> ExtendMask;
for (unsigned i = 0; i < SmallSize; ++i)
ExtendMask.push_back(i);
for (unsigned i = SmallSize; i < LargeSize; ++i)
ExtendMask.push_back(PoisonMaskElem);
Value *ExtendedVec = Builder.CreateShuffleVector(
SmallVec, PoisonValue::get(SmallVec->getType()), ExtendMask);
LLVM_DEBUG(dbgs() << " shufflevector: " << *ExtendedVec << "\n");
V0 = IsV0Smaller ? ExtendedVec : V0;
V1 = IsV0Smaller ? V1 : ExtendedVec;
for (unsigned i = 0; i < NumElts0; ++i)
ShuffleMask.push_back(i);
for (unsigned i = 0; i < NumElts1; ++i)
ShuffleMask.push_back(LargeSize + i);
}

return Builder.CreateShuffleVector(V0, V1, ShuffleMask);
}

namespace {

/// Visitor to rewrite instructions using p particular slice of an alloca
Expand Down Expand Up @@ -2811,6 +2902,220 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
return CanSROA;
}

/// Attempts to rewrite a partition using tree-structured merge optimization.
///
/// This function analyzes a partition to determine if it can be optimized
/// using a tree-structured merge pattern, where multiple non-overlapping
/// stores completely fill an alloca. And there is no load from the alloca in
/// the middle of the stores. Such patterns can be optimized by eliminating
/// the intermediate stores and directly constructing the final vector by
/// using shufflevectors.
///
/// Example transformation:
/// Before: (stores do not have to be in order)
/// %alloca = alloca <8 x float>
/// store <2 x float> %val0, ptr %alloca ; offset 0-1
/// store <2 x float> %val2, ptr %alloca+16 ; offset 4-5
/// store <2 x float> %val1, ptr %alloca+8 ; offset 2-3
/// store <2 x float> %val3, ptr %alloca+24 ; offset 6-7
///
/// After:
/// %alloca = alloca <8 x float>
/// %shuffle0 = shufflevector %val0, %val1, <4 x i32> <i32 0, i32 1, i32 2,
/// i32 3>
/// %shuffle1 = shufflevector %val2, %val3, <4 x i32> <i32 0, i32 1, i32 2,
/// i32 3>
/// %shuffle2 = shufflevector %shuffle0, %shuffle1, <8 x i32> <i32 0, i32 1,
/// i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
/// store %shuffle2, ptr %alloca
///
/// The optimization looks for partitions that:
/// 1. Have no overlapping split slice tails
/// 2. Contain non-overlapping stores that cover the entire alloca
/// 3. Have exactly one load that reads the complete alloca structure and not
/// in the middle of the stores (TODO: maybe we can relax the constraint
/// about reading the entire alloca structure)
///
/// \param P The partition to analyze and potentially rewrite
/// \return An optional vector of values that were deleted during the rewrite
/// process, or std::nullopt if the partition cannot be optimized
/// using tree-structured merge
std::optional<SmallVector<Value *, 4>>
rewriteTreeStructuredMerge(Partition &P) {
// No tail slices that overlap with the partition
if (P.splitSliceTails().size() > 0)
return std::nullopt;

SmallVector<Value *, 4> DeletedValues;
LoadInst *TheLoad = nullptr;

// Structure to hold store information
struct StoreInfo {
StoreInst *Store;
uint64_t BeginOffset;
uint64_t EndOffset;
Value *StoredValue;
StoreInfo(StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val)
: Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val) {}
};

SmallVector<StoreInfo, 4> StoreInfos;

// The alloca must be a fixed vector type
Type *AllocatedEltTy = nullptr;
if (auto *FixedVecTy = dyn_cast<FixedVectorType>(NewAI.getAllocatedType()))
AllocatedEltTy = FixedVecTy->getElementType();
else
return std::nullopt;
// If the allocated element type is a pointer, we do not handle it
// TODO: handle this case by using inttoptr/ptrtoint
if (AllocatedEltTy->isPtrOrPtrVectorTy())
return std::nullopt;

for (Slice &S : P) {
auto *User = cast<Instruction>(S.getUse()->getUser());
if (auto *LI = dyn_cast<LoadInst>(User)) {
// Do not handle the case where there is more than one load
// TODO: maybe we can handle this case
if (TheLoad)
return std::nullopt;
// If load is not a fixed vector type, we do not handle it
// If the number of loaded bits is not the same as the new alloca type
// size, we do not handle it
auto *FixedVecTy = dyn_cast<FixedVectorType>(LI->getType());
if (!FixedVecTy)
return std::nullopt;
if (DL.getTypeSizeInBits(FixedVecTy) !=
DL.getTypeSizeInBits(NewAI.getAllocatedType()))
return std::nullopt;
// If the loaded value is a pointer, we do not handle it
// TODO: handle this case by using inttoptr/ptrtoint
if (FixedVecTy->getElementType()->isPtrOrPtrVectorTy())
return std::nullopt;
TheLoad = LI;
} else if (auto *SI = dyn_cast<StoreInst>(User)) {
// The stored value should be a fixed vector type
Type *StoredValueType = SI->getValueOperand()->getType();
if (!isa<FixedVectorType>(StoredValueType))
return std::nullopt;

// The total number of stored bits should be the multiple of the new
// alloca element type size
if (DL.getTypeSizeInBits(StoredValueType) %
DL.getTypeSizeInBits(AllocatedEltTy) !=
0)
return std::nullopt;
// If the stored value is a pointer, we do not handle it
// TODO: handle this case by using inttoptr/ptrtoint
if (StoredValueType->isPtrOrPtrVectorTy())
return std::nullopt;
StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(),
SI->getValueOperand());
} else {
// If we have instructions other than load and store, we cannot do the
// tree structured merge
return std::nullopt;
}
}
// If we do not have any load, we cannot do the tree structured merge
if (!TheLoad)
return std::nullopt;

// If we do not have multiple stores, we cannot do the tree structured merge
if (StoreInfos.size() < 2)
return std::nullopt;

// Stores should not overlap and should cover the whole alloca
// Sort by begin offset
llvm::sort(StoreInfos, [](const StoreInfo &A, const StoreInfo &B) {
return A.BeginOffset < B.BeginOffset;
});

// Check for overlaps and coverage
uint64_t ExpectedStart = NewAllocaBeginOffset;
TypeSize TotalStoreBits = TypeSize::getZero();
for (auto &StoreInfo : StoreInfos) {
uint64_t BeginOff = StoreInfo.BeginOffset;
uint64_t EndOff = StoreInfo.EndOffset;

// Check for gap or overlap
if (BeginOff != ExpectedStart)
return std::nullopt;

ExpectedStart = EndOff;
TotalStoreBits +=
DL.getTypeSizeInBits(StoreInfo.Store->getValueOperand()->getType());
}
// Check that stores cover the entire alloca
// We need check both the end offset and the total store bits
if (ExpectedStart != NewAllocaEndOffset ||
TotalStoreBits != DL.getTypeSizeInBits(NewAI.getAllocatedType()))
return std::nullopt;

// Stores should be in the same basic block
// The load should not be in the middle of the stores
BasicBlock *LoadBB = TheLoad->getParent();
BasicBlock *StoreBB = StoreInfos[0].Store->getParent();

for (auto &StoreInfo : StoreInfos) {
if (StoreInfo.Store->getParent() != StoreBB)
return std::nullopt;
if (LoadBB == StoreBB && !StoreInfo.Store->comesBefore(TheLoad))
return std::nullopt;
}

// If we reach here, the partition can be merged with a tree structured
// merge
LLVM_DEBUG({
dbgs() << "Tree structured merge rewrite:\n Load: " << *TheLoad
<< "\n Ordered stores:\n";
for (auto [i, Info] : enumerate(StoreInfos))
dbgs() << " [" << i << "] Range[" << Info.BeginOffset << ", "
<< Info.EndOffset << ") \tStore: " << *Info.Store
<< "\tValue: " << *Info.StoredValue << "\n";
});

// Instead of having these stores, we merge all the stored values into a
// vector and store the merged value into the alloca
std::queue<Value *> VecElements;
IRBuilder<> Builder(StoreInfos.back().Store);
for (const auto &Info : StoreInfos) {
DeletedValues.push_back(Info.Store);
VecElements.push(Info.StoredValue);
}

LLVM_DEBUG(dbgs() << " Rewrite stores into shufflevectors:\n");
while (VecElements.size() > 1) {
uint64_t NumElts = VecElements.size();
for (uint64_t i = 0; i < NumElts / 2; i++) {
Value *V0 = VecElements.front();
VecElements.pop();
Value *V1 = VecElements.front();
VecElements.pop();
Value *Merged = mergeTwoVectors(V0, V1, DL, AllocatedEltTy, Builder);
LLVM_DEBUG(dbgs() << " shufflevector: " << *Merged << "\n");
VecElements.push(Merged);
}
if (NumElts % 2 == 1) {
Value *V = VecElements.front();
VecElements.pop();
VecElements.push(V);
}
}

// Store the merged value into the alloca
Value *MergedValue = VecElements.front();
Builder.CreateAlignedStore(MergedValue, &NewAI, getSliceAlign());

IRBuilder<> LoadBuilder(TheLoad);
TheLoad->replaceAllUsesWith(LoadBuilder.CreateAlignedLoad(
TheLoad->getType(), &NewAI, getSliceAlign(), TheLoad->isVolatile(),
TheLoad->getName() + ".sroa.new.load"));
DeletedValues.push_back(TheLoad);

return DeletedValues;
}

private:
// Make sure the other visit overloads are visible.
using Base::visit;
Expand Down Expand Up @@ -4981,13 +5286,22 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
P.endOffset(), IsIntegerPromotable, VecTy,
PHIUsers, SelectUsers);
bool Promotable = true;
for (Slice *S : P.splitSliceTails()) {
Promotable &= Rewriter.visit(S);
++NumUses;
}
for (Slice &S : P) {
Promotable &= Rewriter.visit(&S);
++NumUses;
// Check whether we can have tree-structured merge.
std::optional<SmallVector<Value *, 4>> DeletedValues =
Rewriter.rewriteTreeStructuredMerge(P);
if (DeletedValues) {
NumUses += DeletedValues->size() + 1;
for (Value *V : *DeletedValues)
DeadInsts.push_back(V);
} else {
for (Slice *S : P.splitSliceTails()) {
Promotable &= Rewriter.visit(S);
++NumUses;
}
for (Slice &S : P) {
Promotable &= Rewriter.visit(&S);
++NumUses;
}
}

NumAllocaPartitionUses += NumUses;
Expand Down
Loading
Loading