Skip to content

Commit 3146a3b

Browse files
author
chengjunp
committed
Fix bugs and update tests
1 parent 659dfd7 commit 3146a3b

File tree

3 files changed

+220
-373
lines changed

3 files changed

+220
-373
lines changed

llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 36 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2693,9 +2693,6 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,
26932693
/// elements from V1
26942694
static Value *mergeTwoVectors(Value *V0, Value *V1, const DataLayout &DL,
26952695
Type *NewAIEltTy, IRBuilder<> &Builder) {
2696-
assert(V0->getType()->isVectorTy() && V1->getType()->isVectorTy() &&
2697-
"Can not merge two non-vector values");
2698-
26992696
// V0 and V1 are vectors
27002697
// Create a new vector type with combined elements
27012698
// Use ShuffleVector to concatenate the vectors
@@ -2737,18 +2734,15 @@ static Value *mergeTwoVectors(Value *V0, Value *V1, const DataLayout &DL,
27372734
unsigned SmallSize = std::min(NumElts0, NumElts1);
27382735
unsigned LargeSize = std::max(NumElts0, NumElts1);
27392736
bool IsV0Smaller = NumElts0 < NumElts1;
2740-
Value *SmallVec = IsV0Smaller ? V0 : V1;
2741-
2737+
Value *&ExtendedVec = IsV0Smaller ? V0 : V1;
27422738
SmallVector<int, 16> ExtendMask;
27432739
for (unsigned i = 0; i < SmallSize; ++i)
27442740
ExtendMask.push_back(i);
27452741
for (unsigned i = SmallSize; i < LargeSize; ++i)
27462742
ExtendMask.push_back(PoisonMaskElem);
2747-
Value *ExtendedVec = Builder.CreateShuffleVector(
2748-
SmallVec, PoisonValue::get(SmallVec->getType()), ExtendMask);
2743+
ExtendedVec = Builder.CreateShuffleVector(
2744+
ExtendedVec, PoisonValue::get(ExtendedVec->getType()), ExtendMask);
27492745
LLVM_DEBUG(dbgs() << " shufflevector: " << *ExtendedVec << "\n");
2750-
V0 = IsV0Smaller ? ExtendedVec : V0;
2751-
V1 = IsV0Smaller ? V1 : ExtendedVec;
27522746
for (unsigned i = 0; i < NumElts0; ++i)
27532747
ShuffleMask.push_back(i);
27542748
for (unsigned i = 0; i < NumElts1; ++i)
@@ -2961,53 +2955,45 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
29612955

29622956
SmallVector<StoreInfo, 4> StoreInfos;
29632957

2964-
// The alloca must be a fixed vector type
2965-
Type *AllocatedEltTy = nullptr;
2966-
if (auto *FixedVecTy = dyn_cast<FixedVectorType>(NewAI.getAllocatedType()))
2967-
AllocatedEltTy = FixedVecTy->getElementType();
2968-
else
2969-
return std::nullopt;
2970-
// If the allocated element type is a pointer, we do not handle it
2971-
// TODO: handle this case by using inttoptr/ptrtoint
2972-
if (AllocatedEltTy->isPtrOrPtrVectorTy())
2973-
return std::nullopt;
2958+
// If the new alloca is a fixed vector type, we use its element type as the
2959+
// allocated element type, otherwise we use i8 as the allocated element
2960+
Type *AllocatedEltTy =
2961+
isa<FixedVectorType>(NewAI.getAllocatedType())
2962+
? cast<FixedVectorType>(NewAI.getAllocatedType())->getElementType()
2963+
: Type::getInt8Ty(NewAI.getContext());
2964+
2965+
// Helper to check if a type is
2966+
// 1. A fixed vector type
2967+
// 2. The element type is not a pointer
2968+
// 3. The element type size is byte-aligned
2969+
// We only handle the cases that the ld/st meet these conditions
2970+
auto IsTypeValidForTreeStructuredMerge = [&](Type *Ty) -> bool {
2971+
auto *FixedVecTy = dyn_cast<FixedVectorType>(Ty);
2972+
return FixedVecTy &&
2973+
DL.getTypeSizeInBits(FixedVecTy->getElementType()) % 8 == 0 &&
2974+
!FixedVecTy->getElementType()->isPointerTy();
2975+
};
29742976

29752977
for (Slice &S : P) {
29762978
auto *User = cast<Instruction>(S.getUse()->getUser());
29772979
if (auto *LI = dyn_cast<LoadInst>(User)) {
2978-
// Do not handle the case where there is more than one load
2979-
// TODO: maybe we can handle this case
2980-
if (TheLoad)
2981-
return std::nullopt;
2982-
// If load is not a fixed vector type, we do not handle it
2983-
// If the number of loaded bits is not the same as the new alloca type
2984-
// size, we do not handle it
2985-
auto *FixedVecTy = dyn_cast<FixedVectorType>(LI->getType());
2986-
if (!FixedVecTy)
2987-
return std::nullopt;
2988-
if (DL.getTypeSizeInBits(FixedVecTy) !=
2989-
DL.getTypeSizeInBits(NewAI.getAllocatedType()))
2990-
return std::nullopt;
2991-
// If the loaded value is a pointer, we do not handle it
2992-
// TODO: handle this case by using inttoptr/ptrtoint
2993-
if (FixedVecTy->getElementType()->isPtrOrPtrVectorTy())
2980+
// Do not handle the case if
2981+
// 1. There is more than one load
2982+
// 2. The load is volatile
2983+
// 3. The load does not read the entire alloca structure
2984+
// 4. The load does not meet the conditions in the helper function
2985+
if (TheLoad || !IsTypeValidForTreeStructuredMerge(LI->getType()) ||
2986+
S.beginOffset() != NewAllocaBeginOffset ||
2987+
S.endOffset() != NewAllocaEndOffset ||
2988+
LI->isVolatile())
29942989
return std::nullopt;
29952990
TheLoad = LI;
29962991
} else if (auto *SI = dyn_cast<StoreInst>(User)) {
2997-
// The stored value should be a fixed vector type
2998-
Type *StoredValueType = SI->getValueOperand()->getType();
2999-
if (!isa<FixedVectorType>(StoredValueType))
3000-
return std::nullopt;
3001-
3002-
// The total number of stored bits should be the multiple of the new
3003-
// alloca element type size
3004-
if (DL.getTypeSizeInBits(StoredValueType) %
3005-
DL.getTypeSizeInBits(AllocatedEltTy) !=
3006-
0)
3007-
return std::nullopt;
3008-
// If the stored value is a pointer, we do not handle it
3009-
// TODO: handle this case by using inttoptr/ptrtoint
3010-
if (StoredValueType->isPtrOrPtrVectorTy())
2992+
// Do not handle the case if
2993+
// 1. The store does not meet the conditions in the helper function
2994+
// 2. The store is volatile
2995+
if (!IsTypeValidForTreeStructuredMerge(SI->getValueOperand()->getType()) ||
2996+
SI->isVolatile())
30112997
return std::nullopt;
30122998
StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(),
30132999
SI->getValueOperand());
@@ -3033,7 +3019,6 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
30333019

30343020
// Check for overlaps and coverage
30353021
uint64_t ExpectedStart = NewAllocaBeginOffset;
3036-
TypeSize TotalStoreBits = TypeSize::getZero();
30373022
for (auto &StoreInfo : StoreInfos) {
30383023
uint64_t BeginOff = StoreInfo.BeginOffset;
30393024
uint64_t EndOff = StoreInfo.EndOffset;
@@ -3043,13 +3028,9 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
30433028
return std::nullopt;
30443029

30453030
ExpectedStart = EndOff;
3046-
TotalStoreBits +=
3047-
DL.getTypeSizeInBits(StoreInfo.Store->getValueOperand()->getType());
30483031
}
30493032
// Check that stores cover the entire alloca
3050-
// We need check both the end offset and the total store bits
3051-
if (ExpectedStart != NewAllocaEndOffset ||
3052-
TotalStoreBits != DL.getTypeSizeInBits(NewAI.getAllocatedType()))
3033+
if (ExpectedStart != NewAllocaEndOffset)
30533034
return std::nullopt;
30543035

30553036
// Stores should be in the same basic block

llvm/test/Transforms/SROA/vector-promotion-cannot-tree-structure-merge.ll

Lines changed: 71 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ entry:
2121

2222
define <4 x float> @test_more_than_one_load(<2 x float> %a, <2 x float> %b) {
2323
entry:
24-
%alloca = alloca <4 x float>
24+
%alloca = alloca [4 x float]
2525

26-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
26+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
2727
store <2 x float> %a, ptr %ptr0
2828

29-
%ptr1 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 2
29+
%ptr1 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 2
3030
store <2 x float> %b, ptr %ptr1
3131

3232
%result1 = load <4 x float>, ptr %alloca
@@ -38,19 +38,19 @@ entry:
3838

3939
define void @test_no_load(<4 x float> %a) {
4040
entry:
41-
%alloca = alloca <4 x float>
41+
%alloca = alloca [4 x float]
4242
store <4 x float> %a, ptr %alloca
4343
ret void
4444
}
4545

4646
define i32 @test_load_not_fixed_vector(<2 x float> %a, <2 x float> %b) {
4747
entry:
48-
%alloca = alloca <4 x float>
48+
%alloca = alloca [4 x float]
4949

50-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
50+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
5151
store <2 x float> %a, ptr %ptr0
5252

53-
%ptr1 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 2
53+
%ptr1 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 2
5454
store <2 x float> %b, ptr %ptr1
5555

5656
%result = load i32, ptr %alloca
@@ -59,12 +59,12 @@ entry:
5959

6060
define <3 x float> @test_load_not_covering_alloca(<2 x float> %a, <2 x float> %b) {
6161
entry:
62-
%alloca = alloca <4 x float>
62+
%alloca = alloca [4 x float]
6363

64-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
64+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
6565
store <2 x float> %a, ptr %ptr0
6666

67-
%ptr1 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 2
67+
%ptr1 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 2
6868
store <2 x float> %b, ptr %ptr1
6969

7070
%result = load <3 x float>, ptr %ptr0
@@ -73,9 +73,9 @@ entry:
7373

7474
define <4 x float> @test_store_not_fixed_vector(<vscale x 2 x float> %a) {
7575
entry:
76-
%alloca = alloca <4 x float>
76+
%alloca = alloca [4 x float]
7777

78-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
78+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
7979
%fixed = extractelement <vscale x 2 x float> %a, i32 0
8080
store float %fixed, ptr %ptr0
8181

@@ -86,23 +86,23 @@ entry:
8686

8787
define <4 x float> @test_no_stores() {
8888
entry:
89-
%alloca = alloca <4 x float>
89+
%alloca = alloca [4 x float]
9090

9191
%result = load <4 x float>, ptr %alloca
9292
ret <4 x float> %result
9393
}
9494

9595
define <4 x float> @test_stores_overlapping(<2 x float> %a, <2 x float> %b, <2 x float> %c) {
9696
entry:
97-
%alloca = alloca <4 x float>
97+
%alloca = alloca [4 x float]
9898

99-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
99+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
100100
store <2 x float> %a, ptr %ptr0
101101

102-
%ptr1 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 1
102+
%ptr1 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 1
103103
store <2 x float> %b, ptr %ptr1
104104

105-
%ptr2 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 2
105+
%ptr2 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 2
106106
store <2 x float> %c, ptr %ptr2
107107

108108
%result = load <4 x float>, ptr %alloca
@@ -111,9 +111,9 @@ entry:
111111

112112
define <4 x float> @test_stores_not_covering_alloca(<2 x float> %a) {
113113
entry:
114-
%alloca = alloca <4 x float>
114+
%alloca = alloca [4 x float]
115115

116-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
116+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
117117
store <2 x float> %a, ptr %ptr0
118118

119119
%result = load <4 x float>, ptr %alloca
@@ -122,15 +122,15 @@ entry:
122122

123123
define <4 x float> @test_stores_not_same_basic_block(<2 x float> %a, <2 x float> %b, i1 %cond) {
124124
entry:
125-
%alloca = alloca <4 x float>
125+
%alloca = alloca [4 x float]
126126

127-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
127+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
128128
store <2 x float> %a, ptr %ptr0
129129

130130
br i1 %cond, label %then, label %else
131131

132132
then:
133-
%ptr1 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 2
133+
%ptr1 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 2
134134
store <2 x float> %b, ptr %ptr1
135135
br label %merge
136136

@@ -144,36 +144,79 @@ merge:
144144

145145
define <4 x float> @test_load_before_stores(<2 x float> %a, <2 x float> %b) {
146146
entry:
147-
%alloca = alloca <4 x float>
147+
%alloca = alloca [4 x float]
148148

149-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
149+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
150150
store <2 x float> %a, ptr %ptr0
151151

152152
%intermediate = load <4 x float>, ptr %alloca
153153

154-
%ptr1 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 2
154+
%ptr1 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 2
155155
store <2 x float> %b, ptr %ptr1
156156

157157
ret <4 x float> %intermediate
158158
}
159159

160160
define <4 x float> @test_other_instructions(<2 x float> %a, <2 x float> %b) {
161161
entry:
162-
%alloca = alloca <4 x float>
162+
%alloca = alloca [4 x float]
163163

164164
; Store first vector
165-
%ptr0 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 0
165+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
166166
store <2 x float> %a, ptr %ptr0
167167

168168
; Other instruction (memset) that's not a simple load/store
169169
call void @llvm.memset.p0.i64(ptr %alloca, i8 0, i64 8, i1 false)
170170

171171
; Store second vector
172-
%ptr1 = getelementptr inbounds <4 x float>, ptr %alloca, i32 0, i32 2
172+
%ptr1 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 2
173173
store <2 x float> %b, ptr %ptr1
174174

175175
%result = load <4 x float>, ptr %alloca
176176
ret <4 x float> %result
177177
}
178178

179+
define <4 x float> @volatile_stores(<2 x i32> %a, <2 x i32> %b) {
180+
entry:
181+
%alloca = alloca [4 x float]
182+
183+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
184+
store volatile <2 x i32> %a, ptr %ptr0
185+
186+
%ptr1 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 2
187+
store volatile <2 x i32> %b, ptr %ptr1
188+
189+
%result = load <4 x float>, ptr %alloca
190+
ret <4 x float> %result
191+
}
192+
193+
define <4 x float> @volatile_loads(<2 x i32> %a, <2 x i32> %b) {
194+
entry:
195+
%alloca = alloca [4 x float]
196+
197+
%ptr0 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 0
198+
store <2 x i32> %a, ptr %ptr0
199+
200+
%ptr1 = getelementptr inbounds [4 x float], ptr %alloca, i32 0, i32 2
201+
store <2 x i32> %b, ptr %ptr1
202+
203+
%result = load volatile <4 x float>, ptr %alloca
204+
ret <4 x float> %result
205+
}
206+
207+
define <4 x i15> @non_byte_aligned_alloca(<2 x i15> %a, <2 x i15> %b) {
208+
entry:
209+
%alloca = alloca [4 x i15]
210+
211+
%ptr0 = getelementptr inbounds [4 x i15], ptr %alloca, i32 0, i32 0
212+
store <2 x i15> %a, ptr %ptr0
213+
214+
%ptr1 = getelementptr inbounds [4 x i15], ptr %alloca, i32 0, i32 2
215+
store <2 x i15> %b, ptr %ptr1
216+
217+
%result = load <4 x i15>, ptr %alloca
218+
ret <4 x i15> %result
219+
220+
}
221+
179222
declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg)

0 commit comments

Comments
 (0)