Skip to content

Commit 68eea1e

Browse files
author
chengjunp
committed
Handle the cases where ld/st has different elt types
1 parent be0aa77 commit 68eea1e

File tree

3 files changed

+135
-78
lines changed

3 files changed

+135
-78
lines changed

llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 65 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2679,7 +2679,32 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
26792679
return V;
26802680
}
26812681

2682-
static Value *mergeTwoVectors(Value *V0, Value *V1, IRBuilder<> &Builder) {
2682+
/// This function takes two vector values and combines them into a single vector
2683+
/// by concatenating their elements. The function handles:
2684+
///
2685+
/// 1. Element type mismatch: If either vector's element type differs from
2686+
/// NewAIEltTy, the function bitcasts the vector to use NewAIEltTy while
2687+
/// preserving the total bit width (adjusting the number of elements
2688+
/// accordingly).
2689+
///
2690+
/// 2. Size mismatch: After transforming the vectors to have the desired element
2691+
/// type, if the two vectors have different numbers of elements, the smaller
2692+
/// vector is extended with poison values to match the size of the larger
2693+
/// vector before concatenation.
2694+
///
2695+
/// 3. Concatenation: The vectors are merged using a shuffle operation that
2696+
/// places all elements of V0 first, followed by all elements of V1.
2697+
///
2698+
/// \param V0 The first vector to merge (must be a vector type)
2699+
/// \param V1 The second vector to merge (must be a vector type)
2700+
/// \param DL The data layout for size calculations
2701+
/// \param NewAIEltTy The desired element type for the result vector
2702+
/// \param Builder IRBuilder for creating new instructions
2703+
/// \return A new vector containing all elements from V0 followed by all
2704+
/// elements from V1
2705+
static Value *mergeTwoVectors(Value *V0, Value *V1, const DataLayout &DL,
2706+
Type *NewAIEltTy,
2707+
IRBuilder<> &Builder) {
26832708
assert(V0->getType()->isVectorTy() && V1->getType()->isVectorTy() &&
26842709
"Can not merge two non-vector values");
26852710

@@ -2689,8 +2714,28 @@ static Value *mergeTwoVectors(Value *V0, Value *V1, IRBuilder<> &Builder) {
26892714
auto *VecType0 = cast<FixedVectorType>(V0->getType());
26902715
auto *VecType1 = cast<FixedVectorType>(V1->getType());
26912716

2692-
assert(VecType0->getElementType() == VecType1->getElementType() &&
2693-
"Can not merge two vectors with different element types");
2717+
// If the V0/V1 element types are different from NewAIEltTy,
2718+
// we need to introduce bitcasts before merging them
2719+
auto BitcastIfNeeded = [&](Value *&V, FixedVectorType *&VecType,
2720+
const char *DebugName) {
2721+
Type *EltType = VecType->getElementType();
2722+
if (EltType != NewAIEltTy) {
2723+
// Calculate new number of elements to maintain same bit width
2724+
unsigned TotalBits =
2725+
VecType->getNumElements() * DL.getTypeSizeInBits(EltType);
2726+
unsigned NewNumElts =
2727+
TotalBits / DL.getTypeSizeInBits(NewAIEltTy);
2728+
2729+
auto *NewVecType = FixedVectorType::get(NewAIEltTy, NewNumElts);
2730+
V = Builder.CreateBitCast(V, NewVecType);
2731+
VecType = NewVecType;
2732+
LLVM_DEBUG(dbgs() << " bitcast " << DebugName << ": " << *V << "\n");
2733+
}
2734+
};
2735+
2736+
BitcastIfNeeded(V0, VecType0, "V0");
2737+
BitcastIfNeeded(V1, VecType1, "V1");
2738+
26942739
unsigned NumElts0 = VecType0->getNumElements();
26952740
unsigned NumElts1 = VecType1->getNumElements();
26962741

@@ -2923,24 +2968,19 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
29232968
uint64_t BeginOffset;
29242969
uint64_t EndOffset;
29252970
Value *StoredValue;
2926-
TypeSize StoredTypeSize = TypeSize::getZero();
2927-
2928-
StoreInfo(StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val,
2929-
TypeSize StoredTypeSize)
2930-
: Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val),
2931-
StoredTypeSize(StoredTypeSize) {}
2971+
StoreInfo(StoreInst *SI, uint64_t Begin, uint64_t End, Value *Val)
2972+
: Store(SI), BeginOffset(Begin), EndOffset(End), StoredValue(Val) {}
29322973
};
29332974

29342975
SmallVector<StoreInfo, 4> StoreInfos;
29352976

29362977
// The alloca must be a fixed vector type
2937-
auto *AllocatedTy = NewAI.getAllocatedType();
2938-
if (!isa<FixedVectorType>(AllocatedTy))
2978+
Type *AllocatedEltTy = nullptr;
2979+
if (auto *FixedVecTy = dyn_cast<FixedVectorType>(NewAI.getAllocatedType()))
2980+
AllocatedEltTy = FixedVecTy->getElementType();
2981+
else
29392982
return std::nullopt;
29402983

2941-
Slice *LoadSlice = nullptr;
2942-
Type *LoadElementType = nullptr;
2943-
Type *StoreElementType = nullptr;
29442984
for (Slice &S : P) {
29452985
auto *User = cast<Instruction>(S.getUse()->getUser());
29462986
if (auto *LI = dyn_cast<LoadInst>(User)) {
@@ -2957,27 +2997,20 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
29572997
if (DL.getTypeSizeInBits(FixedVecTy) !=
29582998
DL.getTypeSizeInBits(NewAI.getAllocatedType()))
29592999
return std::nullopt;
2960-
LoadElementType = FixedVecTy->getElementType();
29613000
TheLoad = LI;
2962-
LoadSlice = &S;
29633001
} else if (auto *SI = dyn_cast<StoreInst>(User)) {
2964-
// The store needs to be a fixed vector type
2965-
// All the stores should have the same element type
3002+
// The stored value should be a fixed vector type
29663003
Type *StoredValueType = SI->getValueOperand()->getType();
2967-
Type *CurrentElementType = nullptr;
2968-
TypeSize StoredTypeSize = TypeSize::getZero();
2969-
if (auto *FixedVecTy = dyn_cast<FixedVectorType>(StoredValueType)) {
2970-
// Fixed vector type - use its element type
2971-
CurrentElementType = FixedVecTy->getElementType();
2972-
StoredTypeSize = DL.getTypeSizeInBits(FixedVecTy);
2973-
} else
3004+
if (!isa<FixedVectorType>(StoredValueType))
29743005
return std::nullopt;
2975-
// Check element type consistency across all stores
2976-
if (StoreElementType && StoreElementType != CurrentElementType)
3006+
3007+
// The total number of stored bits should be a multiple of the new
3008+
// alloca element type size
3009+
if (DL.getTypeSizeInBits(StoredValueType) %
3010+
DL.getTypeSizeInBits(AllocatedEltTy) != 0)
29773011
return std::nullopt;
2978-
StoreElementType = CurrentElementType;
29793012
StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(),
2980-
SI->getValueOperand(), StoredTypeSize);
3013+
SI->getValueOperand());
29813014
} else {
29823015
// If we have instructions other than load and store, we cannot do the
29833016
// tree structured merge
@@ -2992,16 +3025,6 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
29923025
if (StoreInfos.size() < 2)
29933026
return std::nullopt;
29943027

2995-
// The load and store element types should be the same
2996-
if (LoadElementType != StoreElementType)
2997-
return std::nullopt;
2998-
2999-
// The load should cover the whole alloca
3000-
// TODO: maybe we can relax this constraint
3001-
if (!LoadSlice || LoadSlice->beginOffset() != NewAllocaBeginOffset ||
3002-
LoadSlice->endOffset() != NewAllocaEndOffset)
3003-
return std::nullopt;
3004-
30053028
// Stores should not overlap and should cover the whole alloca
30063029
// Sort by begin offset
30073030
llvm::sort(StoreInfos, [](const StoreInfo &A, const StoreInfo &B) {
@@ -3011,7 +3034,6 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
30113034
// Check for overlaps and coverage
30123035
uint64_t ExpectedStart = NewAllocaBeginOffset;
30133036
TypeSize TotalStoreBits = TypeSize::getZero();
3014-
Instruction *PrevStore = nullptr;
30153037
for (auto &StoreInfo : StoreInfos) {
30163038
uint64_t BeginOff = StoreInfo.BeginOffset;
30173039
uint64_t EndOff = StoreInfo.EndOffset;
@@ -3021,8 +3043,8 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
30213043
return std::nullopt;
30223044

30233045
ExpectedStart = EndOff;
3024-
TotalStoreBits += StoreInfo.StoredTypeSize;
3025-
PrevStore = StoreInfo.Store;
3046+
TotalStoreBits +=
3047+
DL.getTypeSizeInBits(StoreInfo.Store->getValueOperand()->getType());
30263048
}
30273049
// Check that stores cover the entire alloca
30283050
// We need to check both the end offset and the total store bits
@@ -3070,7 +3092,7 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
30703092
VecElements.pop();
30713093
Value *V1 = VecElements.front();
30723094
VecElements.pop();
3073-
Value *Merged = mergeTwoVectors(V0, V1, Builder);
3095+
Value *Merged = mergeTwoVectors(V0, V1, DL, AllocatedEltTy, Builder);
30743096
LLVM_DEBUG(dbgs() << " shufflevector: " << *Merged << "\n");
30753097
VecElements.push(Merged);
30763098
}

llvm/test/Transforms/SROA/vector-promotion-cannot-tree-structure-merge.ll

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -83,41 +83,6 @@ entry:
8383
ret <4 x float> %result
8484
}
8585

86-
define <4 x float> @test_store_not_same_element_type() {
87-
entry:
88-
%alloca = alloca <4 x float>
89-
90-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
91-
%float_vec = insertelement <2 x float> poison, float 1.0, i32 0
92-
%float_vec2 = insertelement <2 x float> %float_vec, float 2.0, i32 1
93-
store <2 x float> %float_vec2, ptr %ptr0
94-
95-
%ptr1 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 2
96-
%int_vec = insertelement <2 x i32> poison, i32 3, i32 0
97-
%int_vec2 = insertelement <2 x i32> %int_vec, i32 4, i32 1
98-
store <2 x i32> %int_vec2, ptr %ptr1
99-
100-
%result = load <4 x float>, ptr %alloca
101-
ret <4 x float> %result
102-
}
103-
104-
define <4 x i32> @test_load_store_different_element_type() {
105-
entry:
106-
%alloca = alloca <4 x float>
107-
108-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
109-
%float_vec = insertelement <2 x float> poison, float 1.0, i32 0
110-
%float_vec2 = insertelement <2 x float> %float_vec, float 2.0, i32 1
111-
store <2 x float> %float_vec2, ptr %ptr0
112-
113-
%ptr1 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 2
114-
%float_vec3 = insertelement <2 x float> poison, float 3.0, i32 0
115-
%float_vec4 = insertelement <2 x float> %float_vec3, float 4.0, i32 1
116-
store <2 x float> %float_vec4, ptr %ptr1
117-
118-
%result = load <4 x i32>, ptr %alloca
119-
ret <4 x i32> %result
120-
}
12186

12287
define <4 x float> @test_no_stores() {
12388
entry:

llvm/test/Transforms/SROA/vector-promotion-via-tree-structure-merge.ll

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,60 @@ entry:
287287
ret <7 x float> %result
288288
}
289289

290+
; Load and store with different element type
291+
define <4 x double> @load_store_different_element_type(<2 x i32> %a, <2 x float> %b, <2 x float> %c, <2 x i32> %d) {
292+
; CHECK-LABEL: define <4 x double> @load_store_different_element_type(
293+
; CHECK-SAME: <2 x i32> [[A:%.*]], <2 x float> [[B:%.*]], <2 x float> [[C:%.*]], <2 x i32> [[D:%.*]]) {
294+
; CHECK-NEXT: [[ENTRY:.*:]]
295+
; CHECK-NEXT: [[TMP0:%.*]] = bitcast <2 x i32> [[A]] to <1 x double>
296+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x float> [[B]] to <1 x double>
297+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <1 x double> [[TMP0]], <1 x double> [[TMP1]], <2 x i32> <i32 0, i32 1>
298+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <2 x float> [[C]] to <1 x double>
299+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <2 x i32> [[D]] to <1 x double>
300+
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <1 x double> [[TMP3]], <1 x double> [[TMP4]], <2 x i32> <i32 0, i32 1>
301+
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x double> [[TMP2]], <2 x double> [[TMP5]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
302+
; CHECK-NEXT: ret <4 x double> [[TMP6]]
303+
;
304+
; DEBUG-LABEL: define <4 x double> @load_store_different_element_type(
305+
; DEBUG-SAME: <2 x i32> [[A:%.*]], <2 x float> [[B:%.*]], <2 x float> [[C:%.*]], <2 x i32> [[D:%.*]]) !dbg [[DBG117:![0-9]+]] {
306+
; DEBUG-NEXT: [[ENTRY:.*:]]
307+
; DEBUG-NEXT: #dbg_value(ptr poison, [[META119:![0-9]+]], !DIExpression(), [[META125:![0-9]+]])
308+
; DEBUG-NEXT: #dbg_value(ptr undef, [[META119]], !DIExpression(), [[META125]])
309+
; DEBUG-NEXT: #dbg_value(ptr undef, [[META120:![0-9]+]], !DIExpression(), [[META126:![0-9]+]])
310+
; DEBUG-NEXT: #dbg_value(ptr undef, [[META121:![0-9]+]], !DIExpression(), [[META127:![0-9]+]])
311+
; DEBUG-NEXT: #dbg_value(ptr undef, [[META122:![0-9]+]], !DIExpression(), [[META128:![0-9]+]])
312+
; DEBUG-NEXT: #dbg_value(ptr undef, [[META123:![0-9]+]], !DIExpression(), [[META129:![0-9]+]])
313+
; DEBUG-NEXT: [[TMP0:%.*]] = bitcast <2 x i32> [[A]] to <1 x double>, !dbg [[DBG130:![0-9]+]]
314+
; DEBUG-NEXT: [[TMP1:%.*]] = bitcast <2 x float> [[B]] to <1 x double>, !dbg [[DBG130]]
315+
; DEBUG-NEXT: [[TMP2:%.*]] = shufflevector <1 x double> [[TMP0]], <1 x double> [[TMP1]], <2 x i32> <i32 0, i32 1>, !dbg [[DBG130]]
316+
; DEBUG-NEXT: [[TMP3:%.*]] = bitcast <2 x float> [[C]] to <1 x double>, !dbg [[DBG130]]
317+
; DEBUG-NEXT: [[TMP4:%.*]] = bitcast <2 x i32> [[D]] to <1 x double>, !dbg [[DBG130]]
318+
; DEBUG-NEXT: [[TMP5:%.*]] = shufflevector <1 x double> [[TMP3]], <1 x double> [[TMP4]], <2 x i32> <i32 0, i32 1>, !dbg [[DBG130]]
319+
; DEBUG-NEXT: [[TMP6:%.*]] = shufflevector <2 x double> [[TMP2]], <2 x double> [[TMP5]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>, !dbg [[DBG130]]
320+
; DEBUG-NEXT: #dbg_value(<4 x double> [[TMP6]], [[META124:![0-9]+]], !DIExpression(), [[META131:![0-9]+]])
321+
; DEBUG-NEXT: ret <4 x double> [[TMP6]], !dbg [[DBG132:![0-9]+]]
322+
;
323+
entry:
324+
%alloca = alloca <8 x float>
325+
326+
; Store the vectors at different offsets
327+
%ptr0 = getelementptr inbounds <8 x float>, ptr %alloca, i32 0, i32 0
328+
store <2 x i32> %a, ptr %ptr0
329+
330+
%ptr1 = getelementptr inbounds <8 x float>, ptr %alloca, i32 0, i32 2
331+
store <2 x float> %b, ptr %ptr1
332+
333+
%ptr2 = getelementptr inbounds <8 x float>, ptr %alloca, i32 0, i32 4
334+
store <2 x float> %c, ptr %ptr2
335+
336+
%ptr3 = getelementptr inbounds <8 x float>, ptr %alloca, i32 0, i32 6
337+
store <2 x i32> %d, ptr %ptr3
338+
339+
; Load the complete vector
340+
%result = load <4 x double>, ptr %alloca
341+
ret <4 x double> %result
342+
}
343+
290344
;.
291345
; DEBUG: [[META0:![0-9]+]] = distinct !DICompileUnit(language: DW_LANG_C, file: [[META1:![0-9]+]], producer: "debugify", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
292346
; DEBUG: [[META1]] = !DIFile(filename: "{{.*}}<stdin>", directory: {{.*}})
@@ -402,6 +456,22 @@ entry:
402456
; DEBUG: [[DBG114]] = !DILocation(line: 72, column: 1, scope: [[DBG103]])
403457
; DEBUG: [[META115]] = !DILocation(line: 73, column: 1, scope: [[DBG103]])
404458
; DEBUG: [[DBG116]] = !DILocation(line: 74, column: 1, scope: [[DBG103]])
459+
; DEBUG: [[DBG117]] = distinct !DISubprogram(name: "load_store_different_element_type", linkageName: "load_store_different_element_type", scope: null, file: [[META1]], line: 75, type: [[META6]], scopeLine: 75, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: [[META0]], retainedNodes: [[META118:![0-9]+]])
460+
; DEBUG: [[META118]] = !{[[META119]], [[META120]], [[META121]], [[META122]], [[META123]], [[META124]]}
461+
; DEBUG: [[META119]] = !DILocalVariable(name: "41", scope: [[DBG117]], file: [[META1]], line: 75, type: [[META10]])
462+
; DEBUG: [[META120]] = !DILocalVariable(name: "42", scope: [[DBG117]], file: [[META1]], line: 76, type: [[META10]])
463+
; DEBUG: [[META121]] = !DILocalVariable(name: "43", scope: [[DBG117]], file: [[META1]], line: 78, type: [[META10]])
464+
; DEBUG: [[META122]] = !DILocalVariable(name: "44", scope: [[DBG117]], file: [[META1]], line: 80, type: [[META10]])
465+
; DEBUG: [[META123]] = !DILocalVariable(name: "45", scope: [[DBG117]], file: [[META1]], line: 82, type: [[META10]])
466+
; DEBUG: [[META124]] = !DILocalVariable(name: "46", scope: [[DBG117]], file: [[META1]], line: 84, type: [[META16]])
467+
; DEBUG: [[META125]] = !DILocation(line: 75, column: 1, scope: [[DBG117]])
468+
; DEBUG: [[META126]] = !DILocation(line: 76, column: 1, scope: [[DBG117]])
469+
; DEBUG: [[META127]] = !DILocation(line: 78, column: 1, scope: [[DBG117]])
470+
; DEBUG: [[META128]] = !DILocation(line: 80, column: 1, scope: [[DBG117]])
471+
; DEBUG: [[META129]] = !DILocation(line: 82, column: 1, scope: [[DBG117]])
472+
; DEBUG: [[DBG130]] = !DILocation(line: 83, column: 1, scope: [[DBG117]])
473+
; DEBUG: [[META131]] = !DILocation(line: 84, column: 1, scope: [[DBG117]])
474+
; DEBUG: [[DBG132]] = !DILocation(line: 85, column: 1, scope: [[DBG117]])
405475
;.
406476
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
407477
; CHECK-MODIFY-CFG: {{.*}}

0 commit comments

Comments
 (0)