@@ -2679,7 +2679,32 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
26792679 return V;
26802680}
26812681
2682- static Value *mergeTwoVectors (Value *V0, Value *V1, IRBuilder<> &Builder) {
2682+ // / This function takes two vector values and combines them into a single vector
2683+ // / by concatenating their elements. The function handles:
2684+ // /
2685+ // / 1. Element type mismatch: If either vector's element type differs from
2686+ // / NewAIEltType, the function bitcasts the vector to use NewAIEltType while
2687+ // / preserving the total bit width (adjusting the number of elements
2688+ // / accordingly).
2689+ // /
2690+ // / 2. Size mismatch: After transforming the vectors to have the desired element
2691+ // / type, if the two vectors have different numbers of elements, the smaller
2692+ // / vector is extended with poison values to match the size of the larger
2693+ // / vector before concatenation.
2694+ // /
2695+ // / 3. Concatenation: The vectors are merged using a shuffle operation that
2696+ // / places all elements of V0 first, followed by all elements of V1.
2697+ // /
2698+ // / \param V0 The first vector to merge (must be a vector type)
2699+ // / \param V1 The second vector to merge (must be a vector type)
2700+ // / \param DL The data layout for size calculations
2701+ // / \param NewAIEltTy The desired element type for the result vector
2702+ // / \param Builder IRBuilder for creating new instructions
2703+ // / \return A new vector containing all elements from V0 followed by all
2704+ // / elements from V1
2705+ static Value *mergeTwoVectors (Value *V0, Value *V1, const DataLayout &DL,
2706+ Type *NewAIEltTy,
2707+ IRBuilder<> &Builder) {
26832708 assert (V0->getType ()->isVectorTy () && V1->getType ()->isVectorTy () &&
26842709 " Can not merge two non-vector values" );
26852710
@@ -2689,8 +2714,28 @@ static Value *mergeTwoVectors(Value *V0, Value *V1, IRBuilder<> &Builder) {
26892714 auto *VecType0 = cast<FixedVectorType>(V0->getType ());
26902715 auto *VecType1 = cast<FixedVectorType>(V1->getType ());
26912716
2692- assert (VecType0->getElementType () == VecType1->getElementType () &&
2693- " Can not merge two vectors with different element types" );
2717+ // If V0/V1 element types are different from NewAllocaElementType,
2718+ // we need to introduce bitcasts before merging them
2719+ auto BitcastIfNeeded = [&](Value *&V, FixedVectorType *&VecType,
2720+ const char *DebugName) {
2721+ Type *EltType = VecType->getElementType ();
2722+ if (EltType != NewAIEltTy) {
2723+ // Calculate new number of elements to maintain same bit width
2724+ unsigned TotalBits =
2725+ VecType->getNumElements () * DL.getTypeSizeInBits (EltType);
2726+ unsigned NewNumElts =
2727+ TotalBits / DL.getTypeSizeInBits (NewAIEltTy);
2728+
2729+ auto *NewVecType = FixedVectorType::get (NewAIEltTy, NewNumElts);
2730+ V = Builder.CreateBitCast (V, NewVecType);
2731+ VecType = NewVecType;
2732+ LLVM_DEBUG (dbgs () << " bitcast " << DebugName << " : " << *V << " \n " );
2733+ }
2734+ };
2735+
2736+ BitcastIfNeeded (V0, VecType0, " V0" );
2737+ BitcastIfNeeded (V1, VecType1, " V1" );
2738+
26942739 unsigned NumElts0 = VecType0->getNumElements ();
26952740 unsigned NumElts1 = VecType1->getNumElements ();
26962741
@@ -2923,24 +2968,19 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
29232968 uint64_t BeginOffset;
29242969 uint64_t EndOffset;
29252970 Value *StoredValue;
2926- TypeSize StoredTypeSize = TypeSize::getZero();
2927-
2928- StoreInfo (StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val,
2929- TypeSize StoredTypeSize)
2930- : Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val),
2931- StoredTypeSize(StoredTypeSize) {}
2971+ StoreInfo (StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val)
2972+ : Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val) {}
29322973 };
29332974
29342975 SmallVector<StoreInfo, 4 > StoreInfos;
29352976
29362977 // The alloca must be a fixed vector type
2937- auto *AllocatedTy = NewAI.getAllocatedType();
2938- if (!isa<FixedVectorType>(AllocatedTy))
2978+ Type *AllocatedEltTy = nullptr ;
2979+ if (auto *FixedVecTy = dyn_cast<FixedVectorType>(NewAI.getAllocatedType ()))
2980+ AllocatedEltTy = FixedVecTy->getElementType ();
2981+ else
29392982 return std::nullopt ;
29402983
2941- Slice *LoadSlice = nullptr ;
2942- Type *LoadElementType = nullptr ;
2943- Type *StoreElementType = nullptr ;
29442984 for (Slice &S : P) {
29452985 auto *User = cast<Instruction>(S.getUse ()->getUser ());
29462986 if (auto *LI = dyn_cast<LoadInst>(User)) {
@@ -2957,27 +2997,20 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
29572997 if (DL.getTypeSizeInBits (FixedVecTy) !=
29582998 DL.getTypeSizeInBits (NewAI.getAllocatedType ()))
29592999 return std::nullopt ;
2960- LoadElementType = FixedVecTy->getElementType ();
29613000 TheLoad = LI;
2962- LoadSlice = &S;
29633001 } else if (auto *SI = dyn_cast<StoreInst>(User)) {
2964- // The store needs to be a fixed vector type
2965- // All the stores should have the same element type
3002+ // The stored value should be a fixed vector type
29663003 Type *StoredValueType = SI->getValueOperand ()->getType ();
2967- Type *CurrentElementType = nullptr ;
2968- TypeSize StoredTypeSize = TypeSize::getZero ();
2969- if (auto *FixedVecTy = dyn_cast<FixedVectorType>(StoredValueType)) {
2970- // Fixed vector type - use its element type
2971- CurrentElementType = FixedVecTy->getElementType ();
2972- StoredTypeSize = DL.getTypeSizeInBits (FixedVecTy);
2973- } else
3004+ if (!isa<FixedVectorType>(StoredValueType))
29743005 return std::nullopt ;
2975- // Check element type consistency across all stores
2976- if (StoreElementType && StoreElementType != CurrentElementType)
3006+
3007+ // The total number of stored bits should be the multiple of the new
3008+ // alloca element type size
3009+ if (DL.getTypeSizeInBits (StoredValueType) %
3010+ DL.getTypeSizeInBits (AllocatedEltTy) != 0 )
29773011 return std::nullopt ;
2978- StoreElementType = CurrentElementType;
29793012 StoreInfos.emplace_back (SI, S.beginOffset (), S.endOffset (),
2980- SI->getValueOperand (), StoredTypeSize );
3013+ SI->getValueOperand ());
29813014 } else {
29823015 // If we have instructions other than load and store, we cannot do the
29833016 // tree structured merge
@@ -2992,16 +3025,6 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
29923025 if (StoreInfos.size () < 2 )
29933026 return std::nullopt ;
29943027
2995- // The load and store element types should be the same
2996- if (LoadElementType != StoreElementType)
2997- return std::nullopt ;
2998-
2999- // The load should cover the whole alloca
3000- // TODO: maybe we can relax this constraint
3001- if (!LoadSlice || LoadSlice->beginOffset () != NewAllocaBeginOffset ||
3002- LoadSlice->endOffset() != NewAllocaEndOffset)
3003- return std::nullopt;
3004-
30053028 // Stores should not overlap and should cover the whole alloca
30063029 // Sort by begin offset
30073030 llvm::sort (StoreInfos, [](const StoreInfo &A, const StoreInfo &B) {
@@ -3011,7 +3034,6 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
30113034 // Check for overlaps and coverage
30123035 uint64_t ExpectedStart = NewAllocaBeginOffset;
30133036 TypeSize TotalStoreBits = TypeSize::getZero ();
3014- Instruction *PrevStore = nullptr ;
30153037 for (auto &StoreInfo : StoreInfos) {
30163038 uint64_t BeginOff = StoreInfo.BeginOffset ;
30173039 uint64_t EndOff = StoreInfo.EndOffset ;
@@ -3021,8 +3043,8 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
30213043 return std::nullopt ;
30223044
30233045 ExpectedStart = EndOff;
3024- TotalStoreBits += StoreInfo. StoredTypeSize ;
3025- PrevStore = StoreInfo.Store ;
3046+ TotalStoreBits +=
3047+ DL. getTypeSizeInBits ( StoreInfo.Store -> getValueOperand ()-> getType ()) ;
30263048 }
30273049 // Check that stores cover the entire alloca
30283050 // We need check both the end offset and the total store bits
@@ -3070,7 +3092,7 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
30703092 VecElements.pop ();
30713093 Value *V1 = VecElements.front ();
30723094 VecElements.pop ();
3073- Value *Merged = mergeTwoVectors (V0, V1, Builder);
3095+ Value *Merged = mergeTwoVectors (V0, V1, DL, AllocatedEltTy, Builder);
30743096 LLVM_DEBUG (dbgs () << " shufflevector: " << *Merged << " \n " );
30753097 VecElements.push (Merged);
30763098 }
0 commit comments