Skip to content

Commit 3191cdf

Browse files
committed
[SROA] Canonicalize homogeneous structs into fixed vectors to eliminate allocas
Motivation: SROA would keep temporary allocas (e.g. copies and zero-inits) for homogeneous, 16-byte structs. On targets like AMDGPU these map to scratch memory and can severely hurt performance. The following example could not eliminate the allocas before this change:

```
struct alignas(16) myint4 { int x, y, z, w; };
void foo(myint4* x, myint4 y, int cond) {
  myint4 temp = y;
  myint4 zero{0, 0, 0, 0};
  myint4 data = cond ? temp : zero;
  *x = data;
}
```

Method: During rewritePartition, when the slice type is a struct of 2 or 4 identical element types, and DataLayout proves it is tightly packed (no padding; element offsets are i*EltSize; StructSize == N*EltSize), and the element type is a valid fixed-size vector element, and the total size is at or below a configurable threshold, rewrite the slice type to a fixed vector <N x EltTy>. This runs before the alloca-reuse fast path.

Why it works: For tightly packed homogeneous structs, the in-memory representation is bitwise-identical to the corresponding fixed vector, so the transformation is semantics-preserving. The vector form enables SROA/InstCombine/GVN to replace memcpy/memset and conditional copies with vector selects and a single vector store, allowing the allocas to be eliminated. Tests (flat and nested struct) show that the allocas and mem* intrinsics disappear and a single <4 x i32> store remains.

Control: Introduces -sroa-max-struct-to-vector-bytes=N (default 0 = disabled) to guard the transform by struct size. Enable via:
- opt: -passes='sroa,gvn,instcombine,simplifycfg' -sroa-max-struct-to-vector-bytes=16
- clang/llc: -mllvm -sroa-max-struct-to-vector-bytes=16

Set N to 0 to turn the optimization off if regressions are observed.
1 parent c6b4ef1 commit 3191cdf

File tree

2 files changed

+368
-0
lines changed

2 files changed

+368
-0
lines changed

llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ namespace llvm {
122122
/// Disable running mem2reg during SROA in order to test or debug SROA.
123123
static cl::opt<bool> SROASkipMem2Reg("sroa-skip-mem2reg", cl::init(false),
124124
cl::Hidden);
125+
/// Maximum struct size in bytes to canonicalize homogeneous structs to vectors.
/// 0 disables the transformation to avoid regressions by default.
/// When non-zero, rewritePartition may rewrite the slice type of a tightly
/// packed, homogeneous 2- or 4-field struct of at most this many bytes into
/// the equivalent fixed vector type (e.g. <4 x i32>), enabling downstream
/// passes to eliminate temporary allocas. Enable with e.g.
/// -sroa-max-struct-to-vector-bytes=16.
static cl::opt<unsigned> SROAMaxStructToVectorBytes(
    "sroa-max-struct-to-vector-bytes", cl::init(0), cl::Hidden,
    cl::desc("Max struct size in bytes to canonicalize homogeneous structs to "
             "fixed vectors (0=disable)"));
125131
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
126132
} // namespace llvm
127133

@@ -5267,6 +5273,57 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
52675273
if (VecTy)
52685274
SliceTy = VecTy;
52695275

5276+
// Canonicalize homogeneous, tightly-packed 2- or 4-field structs to
5277+
// a fixed-width vector type when the DataLayout proves bitwise identity.
5278+
// Do this BEFORE the alloca reuse fast-path so that we don't miss
5279+
// opportunities to vectorize memcpy on allocas whose SliceTy initially
5280+
// equals the allocated type.
5281+
if (SROAMaxStructToVectorBytes) {
5282+
if (auto *STy = dyn_cast<StructType>(SliceTy)) {
5283+
unsigned NumElts = STy->getNumElements();
5284+
if (NumElts == 2 || NumElts == 4) {
5285+
Type *EltTy =
5286+
STy->getNumElements() > 0 ? STy->getElementType(0) : nullptr;
5287+
bool IsAllowedElt = false;
5288+
if (EltTy && VectorType::isValidElementType(EltTy)) {
5289+
if (auto *IT = dyn_cast<IntegerType>(EltTy))
5290+
IsAllowedElt = IT->getBitWidth() >= 8;
5291+
else if (EltTy->isFloatingPointTy())
5292+
IsAllowedElt = true;
5293+
}
5294+
bool AllSame = IsAllowedElt;
5295+
for (unsigned I = 1; AllSame && I < NumElts; ++I)
5296+
if (STy->getElementType(I) != EltTy)
5297+
AllSame = false;
5298+
if (AllSame) {
5299+
const StructLayout *SL = DL.getStructLayout(STy);
5300+
TypeSize EltTS = DL.getTypeAllocSize(EltTy);
5301+
if (EltTS.isFixed()) {
5302+
const uint64_t EltSize = EltTS.getFixedValue();
5303+
if (EltSize >= 1) {
5304+
const uint64_t StructSize = SL->getSizeInBytes();
5305+
if (StructSize != 0 && StructSize <= SROAMaxStructToVectorBytes) {
5306+
bool TightlyPacked = (StructSize == NumElts * EltSize);
5307+
if (TightlyPacked) {
5308+
for (unsigned I = 0; I < NumElts; ++I) {
5309+
if (SL->getElementOffset(I) != I * EltSize) {
5310+
TightlyPacked = false;
5311+
break;
5312+
}
5313+
}
5314+
}
5315+
if (TightlyPacked) {
5316+
Type *NewSliceTy = FixedVectorType::get(EltTy, NumElts);
5317+
SliceTy = NewSliceTy;
5318+
}
5319+
}
5320+
}
5321+
}
5322+
}
5323+
}
5324+
}
5325+
}
5326+
52705327
// Check for the case where we're going to rewrite to a new alloca of the
52715328
// exact same type as the original, and with the same access offsets. In that
52725329
// case, re-use the existing alloca, but still run through the rewriter to
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
; RUN: opt -passes='sroa,gvn,instcombine,simplifycfg' -S \
; RUN: -sroa-max-struct-to-vector-bytes=16 %s \
; RUN: | FileCheck %s \
; RUN: --check-prefixes=FLAT,NESTED,PADDED,NONHOMO,I1,PTR

%struct.myint4 = type { i32, i32, i32, i32 }

; Positive test: a flat, homogeneous, tightly packed 16-byte struct.
; With the threshold at 16 bytes, SROA canonicalizes the slice type to
; <4 x i32>; GVN/InstCombine/SimplifyCFG then fold the conditional copy
; into a vector select, so no alloca or mem* intrinsic should survive and
; the result is a single <4 x i32> store to the destination.
; FLAT-LABEL: define dso_local void @foo_flat(
; FLAT-NOT: alloca
; FLAT-NOT: llvm.memcpy
; FLAT-NOT: llvm.memset
; FLAT: insertelement <2 x i64>
; FLAT: bitcast <2 x i64> %{{[^ ]+}} to <4 x i32>
; FLAT: select i1 %{{[^,]+}}, <4 x i32> zeroinitializer, <4 x i32> %{{[^)]+}}
; FLAT: store <4 x i32> %{{[^,]+}}, ptr %x, align 16
; FLAT: ret void
define dso_local void @foo_flat(ptr noundef %x, i64 %y.coerce0, i64 %y.coerce1, i32 noundef %cond) {
entry:
  %y = alloca %struct.myint4, align 16
  %x.addr = alloca ptr, align 8
  %cond.addr = alloca i32, align 4
  %temp = alloca %struct.myint4, align 16
  %zero = alloca %struct.myint4, align 16
  %data = alloca %struct.myint4, align 16
  ; Reassemble the by-value argument from its two i64 coercion halves.
  %0 = getelementptr inbounds nuw { i64, i64 }, ptr %y, i32 0, i32 0
  store i64 %y.coerce0, ptr %0, align 16
  %1 = getelementptr inbounds nuw { i64, i64 }, ptr %y, i32 0, i32 1
  store i64 %y.coerce1, ptr %1, align 8
  store ptr %x, ptr %x.addr, align 8
  store i32 %cond, ptr %cond.addr, align 4
  ; temp = y (copy); zero = {0,0,0,0} (memset).
  call void @llvm.lifetime.start.p0(ptr %temp)
  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %temp, ptr align 16 %y, i64 16, i1 false)
  call void @llvm.lifetime.start.p0(ptr %zero)
  call void @llvm.memset.p0.i64(ptr align 16 %zero, i8 0, i64 16, i1 false)
  call void @llvm.lifetime.start.p0(ptr %data)
  %2 = load i32, ptr %cond.addr, align 4
  %tobool = icmp ne i32 %2, 0
  br i1 %tobool, label %cond.true, label %cond.false

cond.true:
  br label %cond.end

cond.false:
  br label %cond.end

cond.end:
  ; data = cond ? temp : zero, then *x = data.
  %cond1 = phi ptr [ %temp, %cond.true ], [ %zero, %cond.false ]
  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %data, ptr align 16 %cond1, i64 16, i1 false)
  %3 = load ptr, ptr %x.addr, align 8
  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %3, ptr align 16 %data, i64 16, i1 false)
  call void @llvm.lifetime.end.p0(ptr %data)
  call void @llvm.lifetime.end.p0(ptr %zero)
  call void @llvm.lifetime.end.p0(ptr %temp)
  ret void
}
55+
%struct.myint4_base_n = type { i32, i32, i32, i32 }
%struct.myint4_nested = type { %struct.myint4_base_n }

; Positive test: same pattern as @foo_flat, but the homogeneous 4 x i32
; layout is wrapped in a single-field outer struct. The canonicalization
; still applies — presumably because SROA's partitioning exposes the inner
; homogeneous struct as the slice type (TODO confirm against the pass).
; NESTED-LABEL: define dso_local void @foo_nested(
; NESTED-NOT: alloca
; NESTED-NOT: llvm.memcpy
; NESTED-NOT: llvm.memset
; NESTED: insertelement <2 x i64>
; NESTED: bitcast <2 x i64> %{{[^ ]+}} to <4 x i32>
; NESTED: select i1 %{{[^,]+}}, <4 x i32> zeroinitializer, <4 x i32> %{{[^)]+}}
; NESTED: store <4 x i32> %{{[^,]+}}, ptr %x, align 16
; NESTED: ret void
define dso_local void @foo_nested(ptr noundef %x, i64 %y.coerce0, i64 %y.coerce1, i32 noundef %cond) {
entry:
  %y = alloca %struct.myint4_nested, align 16
  %x.addr = alloca ptr, align 8
  %cond.addr = alloca i32, align 4
  %temp = alloca %struct.myint4_nested, align 16
  %zero = alloca %struct.myint4_nested, align 16
  %data = alloca %struct.myint4_nested, align 16
  ; Reassemble the by-value argument from its two i64 coercion halves.
  %0 = getelementptr inbounds nuw { i64, i64 }, ptr %y, i32 0, i32 0
  store i64 %y.coerce0, ptr %0, align 16
  %1 = getelementptr inbounds nuw { i64, i64 }, ptr %y, i32 0, i32 1
  store i64 %y.coerce1, ptr %1, align 8
  store ptr %x, ptr %x.addr, align 8
  store i32 %cond, ptr %cond.addr, align 4
  ; temp = y (copy); zero = zero-init (memset).
  call void @llvm.lifetime.start.p0(ptr %temp)
  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %temp, ptr align 16 %y, i64 16, i1 false)
  call void @llvm.lifetime.start.p0(ptr %zero)
  call void @llvm.memset.p0.i64(ptr align 16 %zero, i8 0, i64 16, i1 false)
  call void @llvm.lifetime.start.p0(ptr %data)
  %2 = load i32, ptr %cond.addr, align 4
  %tobool = icmp ne i32 %2, 0
  br i1 %tobool, label %cond.true, label %cond.false

cond.true:
  br label %cond.end

cond.false:
  br label %cond.end

cond.end:
  ; data = cond ? temp : zero, then *x = data.
  %cond1 = phi ptr [ %temp, %cond.true ], [ %zero, %cond.false ]
  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %data, ptr align 16 %cond1, i64 16, i1 false)
  %3 = load ptr, ptr %x.addr, align 8
  call void @llvm.memcpy.p0.p0.i64(ptr align 16 %3, ptr align 16 %data, i64 16, i1 false)
  call void @llvm.lifetime.end.p0(ptr %data)
  call void @llvm.lifetime.end.p0(ptr %zero)
  call void @llvm.lifetime.end.p0(ptr %temp)
  ret void
}
106+
107+
; Negative test: { i32, i8, i32, i8 } is neither homogeneous nor free of
; padding, so the struct-to-vector canonicalization must not fire; a
; memcpy survives and no vector store is introduced.
; PADDED-LABEL: define dso_local void @foo_padded(
; PADDED: llvm.memcpy
; PADDED-NOT: store <
; PADDED: ret void
%struct.padded = type { i32, i8, i32, i8 }
define dso_local void @foo_padded(ptr noundef %x, i32 %a0, i8 %a1, i32 %a2, i8 %a3, i32 noundef %cond) {
entry:
  %y = alloca %struct.padded, align 4
  %x.addr = alloca ptr, align 8
  %cond.addr = alloca i32, align 4
  %temp = alloca %struct.padded, align 4
  %zero = alloca %struct.padded, align 4
  %data = alloca %struct.padded, align 4
  ; Initialize the source struct field by field.
  %y_i32_0 = getelementptr inbounds %struct.padded, ptr %y, i32 0, i32 0
  store i32 %a0, ptr %y_i32_0, align 4
  %y_i8_1 = getelementptr inbounds %struct.padded, ptr %y, i32 0, i32 1
  store i8 %a1, ptr %y_i8_1, align 1
  %y_i32_2 = getelementptr inbounds %struct.padded, ptr %y, i32 0, i32 2
  store i32 %a2, ptr %y_i32_2, align 4
  %y_i8_3 = getelementptr inbounds %struct.padded, ptr %y, i32 0, i32 3
  store i8 %a3, ptr %y_i8_3, align 1
  store ptr %x, ptr %x.addr, align 8
  store i32 %cond, ptr %cond.addr, align 4
  ; temp = y (copy); zero = zero-init (memset); 16 bytes each.
  call void @llvm.lifetime.start.p0(ptr %temp)
  call void @llvm.memcpy.p0.p0.i64(ptr align 4 %temp, ptr align 4 %y, i64 16, i1 false)
  call void @llvm.lifetime.start.p0(ptr %zero)
  call void @llvm.memset.p0.i64(ptr align 4 %zero, i8 0, i64 16, i1 false)
  call void @llvm.lifetime.start.p0(ptr %data)
  %c.pad = load i32, ptr %cond.addr, align 4
  %tobool.pad = icmp ne i32 %c.pad, 0
  br i1 %tobool.pad, label %cond.true.pad, label %cond.false.pad

cond.true.pad:
  br label %cond.end.pad

cond.false.pad:
  br label %cond.end.pad

cond.end.pad:
  ; data = cond ? temp : zero, then *x = data.
  %cond1.pad = phi ptr [ %temp, %cond.true.pad ], [ %zero, %cond.false.pad ]
  call void @llvm.memcpy.p0.p0.i64(ptr align 4 %data, ptr align 4 %cond1.pad, i64 16, i1 false)
  %xv.pad = load ptr, ptr %x.addr, align 8
  call void @llvm.memcpy.p0.p0.i64(ptr align 4 %xv.pad, ptr align 4 %data, i64 16, i1 false)
  call void @llvm.lifetime.end.p0(ptr %data)
  call void @llvm.lifetime.end.p0(ptr %zero)
  call void @llvm.lifetime.end.p0(ptr %temp)
  ret void
}
160+
161+
; Negative test: { i32, i64, i32, i64 } mixes element types, so the
; homogeneity requirement fails and no vector canonicalization happens;
; a memcpy survives and no vector store is introduced.
; NONHOMO-LABEL: define dso_local void @foo_nonhomo(
; NONHOMO: llvm.memcpy
; NONHOMO-NOT: store <
; NONHOMO: ret void
%struct.nonhomo = type { i32, i64, i32, i64 }
define dso_local void @foo_nonhomo(ptr noundef %x, i32 %a0, i64 %a1, i32 %a2, i64 %a3, i32 noundef %cond) {
entry:
  %y = alloca %struct.nonhomo, align 8
  %x.addr = alloca ptr, align 8
  %cond.addr = alloca i32, align 4
  %temp = alloca %struct.nonhomo, align 8
  %zero = alloca %struct.nonhomo, align 8
  %data = alloca %struct.nonhomo, align 8
  ; Initialize the source struct field by field.
  %y_i32_0n = getelementptr inbounds %struct.nonhomo, ptr %y, i32 0, i32 0
  store i32 %a0, ptr %y_i32_0n, align 4
  %y_i64_1n = getelementptr inbounds %struct.nonhomo, ptr %y, i32 0, i32 1
  store i64 %a1, ptr %y_i64_1n, align 8
  %y_i32_2n = getelementptr inbounds %struct.nonhomo, ptr %y, i32 0, i32 2
  store i32 %a2, ptr %y_i32_2n, align 4
  %y_i64_3n = getelementptr inbounds %struct.nonhomo, ptr %y, i32 0, i32 3
  store i64 %a3, ptr %y_i64_3n, align 8
  store ptr %x, ptr %x.addr, align 8
  store i32 %cond, ptr %cond.addr, align 4
  ; temp = y (copy); zero = zero-init (memset); 32 bytes each.
  call void @llvm.lifetime.start.p0(ptr %temp)
  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %temp, ptr align 8 %y, i64 32, i1 false)
  call void @llvm.lifetime.start.p0(ptr %zero)
  call void @llvm.memset.p0.i64(ptr align 8 %zero, i8 0, i64 32, i1 false)
  call void @llvm.lifetime.start.p0(ptr %data)
  %c.nh = load i32, ptr %cond.addr, align 4
  %tobool.nh = icmp ne i32 %c.nh, 0
  br i1 %tobool.nh, label %cond.true.nh, label %cond.false.nh

cond.true.nh:
  br label %cond.end.nh

cond.false.nh:
  br label %cond.end.nh

cond.end.nh:
  ; data = cond ? temp : zero, then *x = data.
  %cond1.nh = phi ptr [ %temp, %cond.true.nh ], [ %zero, %cond.false.nh ]
  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %data, ptr align 8 %cond1.nh, i64 32, i1 false)
  %xv.nh = load ptr, ptr %x.addr, align 8
  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %xv.nh, ptr align 8 %data, i64 32, i1 false)
  call void @llvm.lifetime.end.p0(ptr %data)
  call void @llvm.lifetime.end.p0(ptr %zero)
  call void @llvm.lifetime.end.p0(ptr %temp)
  ret void
}
214+
215+
; Negative test: i1 fields are narrower than a byte, so the struct cannot
; be proven bitwise-identical to a <4 x i1> vector and must be left alone.
; I1-LABEL: define dso_local void @foo_i1(
; I1-NOT: <4 x i1>
; I1: ret void
%struct.i1x4 = type { i1, i1, i1, i1 }
define dso_local void @foo_i1(ptr noundef %x, i64 %dummy0, i64 %dummy1, i32 noundef %cond) {
entry:
  %src = alloca %struct.i1x4, align 1
  %x.slot = alloca ptr, align 8
  %cond.slot = alloca i32, align 4
  %copy = alloca %struct.i1x4, align 1
  %zeroed = alloca %struct.i1x4, align 1
  %out = alloca %struct.i1x4, align 1
  store ptr %x, ptr %x.slot, align 8
  store i32 %cond, ptr %cond.slot, align 4
  ; copy = src (4-byte memcpy); zeroed = zero-init (memset).
  call void @llvm.lifetime.start.p0(ptr %copy)
  call void @llvm.memcpy.p0.p0.i64(ptr align 1 %copy, ptr align 1 %src, i64 4, i1 false)
  call void @llvm.lifetime.start.p0(ptr %zeroed)
  call void @llvm.memset.p0.i64(ptr align 1 %zeroed, i8 0, i64 4, i1 false)
  call void @llvm.lifetime.start.p0(ptr %out)
  %cv = load i32, ptr %cond.slot, align 4
  %nz = icmp ne i32 %cv, 0
  br i1 %nz, label %bb.copy, label %bb.zero

bb.copy:
  br label %bb.join

bb.zero:
  br label %bb.join

bb.join:
  ; out = cond ? copy : zeroed, then store through the saved destination.
  %sel = phi ptr [ %copy, %bb.copy ], [ %zeroed, %bb.zero ]
  call void @llvm.memcpy.p0.p0.i64(ptr align 1 %out, ptr align 1 %sel, i64 4, i1 false)
  %dst = load ptr, ptr %x.slot, align 8
  call void @llvm.memcpy.p0.p0.i64(ptr align 1 %dst, ptr align 1 %out, i64 4, i1 false)
  call void @llvm.lifetime.end.p0(ptr %out)
  call void @llvm.lifetime.end.p0(ptr %zeroed)
  call void @llvm.lifetime.end.p0(ptr %copy)
  ret void
}
258+
259+
; Negative test: pointer elements are outside the integer/float allow-list,
; so no <4 x ptr> canonicalization is performed and a memcpy survives.
; PTR-LABEL: define dso_local void @foo_ptr(
; PTR: llvm.memcpy
; PTR-NOT: <4 x ptr>
; PTR: ret void
%struct.ptr4 = type { ptr, ptr, ptr, ptr }
define dso_local void @foo_ptr(ptr noundef %x, ptr %p0, ptr %p1, ptr %p2, ptr %p3, i32 noundef %cond) {
entry:
  %agg = alloca %struct.ptr4, align 8
  %x.slot = alloca ptr, align 8
  %cond.slot = alloca i32, align 4
  %dup = alloca %struct.ptr4, align 8
  %nil = alloca %struct.ptr4, align 8
  %res = alloca %struct.ptr4, align 8
  ; Populate the source aggregate with the four incoming pointers.
  %f0 = getelementptr inbounds %struct.ptr4, ptr %agg, i32 0, i32 0
  store ptr %p0, ptr %f0, align 8
  %f1 = getelementptr inbounds %struct.ptr4, ptr %agg, i32 0, i32 1
  store ptr %p1, ptr %f1, align 8
  %f2 = getelementptr inbounds %struct.ptr4, ptr %agg, i32 0, i32 2
  store ptr %p2, ptr %f2, align 8
  %f3 = getelementptr inbounds %struct.ptr4, ptr %agg, i32 0, i32 3
  store ptr %p3, ptr %f3, align 8
  store ptr %x, ptr %x.slot, align 8
  store i32 %cond, ptr %cond.slot, align 4
  ; dup = agg (32-byte memcpy); nil = zero-init (memset).
  call void @llvm.lifetime.start.p0(ptr %dup)
  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %dup, ptr align 8 %agg, i64 32, i1 false)
  call void @llvm.lifetime.start.p0(ptr %nil)
  call void @llvm.memset.p0.i64(ptr align 8 %nil, i8 0, i64 32, i1 false)
  call void @llvm.lifetime.start.p0(ptr %res)
  %cval = load i32, ptr %cond.slot, align 4
  %flag = icmp ne i32 %cval, 0
  br i1 %flag, label %pick.dup, label %pick.nil

pick.dup:
  br label %pick.done

pick.nil:
  br label %pick.done

pick.done:
  ; res = cond ? dup : nil, then store through the saved destination.
  %chosen = phi ptr [ %dup, %pick.dup ], [ %nil, %pick.nil ]
  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %res, ptr align 8 %chosen, i64 32, i1 false)
  %outp = load ptr, ptr %x.slot, align 8
  call void @llvm.memcpy.p0.p0.i64(ptr align 8 %outp, ptr align 8 %res, i64 32, i1 false)
  call void @llvm.lifetime.end.p0(ptr %res)
  call void @llvm.lifetime.end.p0(ptr %nil)
  call void @llvm.lifetime.end.p0(ptr %dup)
  ret void
}

0 commit comments

Comments
 (0)