Skip to content

Commit d11ad2d

Browse files
[SandboxVectorizer] New class to actually collect and manage memory seeds
There are many more tests to add, but I would like to get this reviewed before it grows too big.
1 parent ec24e23 commit d11ad2d

File tree

4 files changed

+295
-0
lines changed

4 files changed

+295
-0
lines changed

llvm/include/llvm/SandboxIR/Utils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ class Utils {
6060
getUnderlyingObject(LSI->getPointerOperand()->Val));
6161
}
6262

63+
/// \Returns the number of elements in \p Ty, that is the number of lanes
64+
/// if a fixed vector or 1 if scalar. ScalableVectors
65+
static int getNumElements(Type *Ty) {
66+
return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getNumElements() : 1;
67+
}
68+
/// Returns \p Ty if scalar or its element type if vector.
69+
static Type *getElementType(Type *Ty) {
70+
return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
71+
}
72+
6373
/// \Returns the number of bits required to represent the operands or return
6474
/// value of \p V in \p DL.
6575
static unsigned getNumBits(Value *V, const DataLayout &DL) {

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,36 @@ class SeedContainer {
284284
#endif // NDEBUG
285285
};
286286

287+
class SeedCollector {
288+
SeedContainer StoreSeeds;
289+
SeedContainer LoadSeeds;
290+
BasicBlock *BB;
291+
Context &Ctx;
292+
293+
/// \Returns the number of SeedBundle groups for all seed types.
294+
/// This is to be used for limiting compilation time.
295+
unsigned totalNumSeedGroups() const {
296+
return StoreSeeds.size() + LoadSeeds.size();
297+
}
298+
299+
public:
300+
SeedCollector(BasicBlock *SBBB, ScalarEvolution &SE);
301+
~SeedCollector();
302+
303+
BasicBlock *getBasicBlock() { return BB; }
304+
305+
iterator_range<SeedContainer::iterator> getStoreSeeds() {
306+
return {StoreSeeds.begin(), StoreSeeds.end()};
307+
}
308+
iterator_range<SeedContainer::iterator> getLoadSeeds() {
309+
return {LoadSeeds.begin(), LoadSeeds.end()};
310+
}
311+
#ifndef NDEBUG
312+
void print(raw_ostream &OS) const;
313+
LLVM_DUMP_METHOD void dump() const;
314+
#endif
315+
};
316+
287317
} // namespace llvm::sandboxir
288318

289319
#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SEEDCOLLECTOR_H

llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,23 @@ namespace llvm::sandboxir {
2222
cl::opt<unsigned> SeedBundleSizeLimit(
2323
"sbvec-seed-bundle-size-limit", cl::init(32), cl::Hidden,
2424
cl::desc("Limit the size of the seed bundle to cap compilation time."));
25+
/// When set, store instructions are not collected as seeds.
cl::opt<bool>
    DisableStoreSeeds("sbvec-disable-store-seeds", cl::Hidden, cl::init(false),
                      cl::desc("Don't collect store seed instructions."));
/// When set, load instructions are not collected as seeds. Note that the
/// initial value is true, so loads are skipped unless the flag is cleared.
cl::opt<bool>
    DisableLoadSeeds("sbvec-disable-load-seeds", cl::Hidden, cl::init(true),
                     cl::desc("Don't collect load seed instructions."));

// The accepted values of -sbvec-force-seeds. These are macros rather than
// constants so that they can be spliced into the help string literal below.
#define LoadSeedsDef "loads"
#define StoreSeedsDef "stores"
/// When non-empty, names the single kind of seed to collect.
cl::opt<std::string>
    ForceSeed("sbvec-force-seeds", cl::Hidden, cl::init(""),
              cl::desc("Enable only this type of seeds. This can be one "
                       "of: '" LoadSeedsDef "','" StoreSeedsDef "'."));
/// Upper bound on the number of seed groups collected per basic block.
cl::opt<unsigned> SeedGroupsLimit(
    "sbvec-seed-groups-limit", cl::Hidden, cl::init(256),
    cl::desc("Limit the number of collected seeds groups in a BB to "
             "cap compilation time."));
2542

2643
MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
2744
unsigned MaxVecRegBits,
@@ -131,4 +148,74 @@ void SeedContainer::print(raw_ostream &OS) const {
131148
LLVM_DUMP_METHOD void SeedContainer::dump() const { print(dbgs()); }
132149
#endif // NDEBUG
133150

151+
template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
152+
if (LSI->isSimple())
153+
return true;
154+
auto *Ty = Utils::getExpectedType(LSI);
155+
// Omit types that are architecturally unvectorizable
156+
if (Ty->isX86_FP80Ty() || Ty->isPPC_FP128Ty())
157+
return false;
158+
// Omit vector types without compile-time-known lane counts
159+
if (isa<ScalableVectorType>(Ty))
160+
return false;
161+
if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
162+
return VectorType::isValidElementType(VTy->getElementType());
163+
return VectorType::isValidElementType(Ty);
164+
}
165+
166+
template bool isValidMemSeed(LoadInst *LSI);
167+
template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
168+
169+
/// Creates the collector for \p SBBB and immediately gathers its store and
/// load seeds, honoring the seed-related command line options. \p SE is
/// handed to both seed containers.
SeedCollector::SeedCollector(BasicBlock *SBBB, ScalarEvolution &SE)
    : StoreSeeds(SE), LoadSeeds(SE), BB(SBBB), Ctx(BB->getContext()) {
  // TODO: Register a callback for updating the Collector datastructures upon
  // instr removal

  bool CollectStores = !DisableStoreSeeds;
  bool CollectLoads = !DisableLoadSeeds;
  if (LLVM_UNLIKELY(!ForceSeed.empty())) {
    // A forced seed kind overrides both disable flags: exactly the named kind
    // is collected and nothing else.
    const bool ForceStores = ForceSeed == StoreSeedsDef;
    const bool ForceLoads = ForceSeed == LoadSeedsDef;
    if (!ForceStores && !ForceLoads) {
      // Unknown seed kind: report it and terminate.
      errs() << "Bad argument '" << ForceSeed << "' in -" << ForceSeed.ArgStr
             << "='" << ForceSeed << "'.\n";
      errs() << "Description: " << ForceSeed.HelpStr << "\n";
      exit(1);
    }
    CollectStores = ForceStores;
    CollectLoads = ForceLoads;
  }

  // Walk the block and file every eligible store/load in its container.
  for (auto &I : *BB) {
    if (auto *SI = dyn_cast<StoreInst>(&I)) {
      if (CollectStores && isValidMemSeed(SI))
        StoreSeeds.insert(SI);
    } else if (auto *LI = dyn_cast<LoadInst>(&I)) {
      if (CollectLoads && isValidMemSeed(LI))
        LoadSeeds.insert(LI);
    }
    // Stop scanning once the group count exceeds the limit, to cap
    // compilation time.
    if (totalNumSeedGroups() > SeedGroupsLimit)
      break;
  }
}
204+
205+
SeedCollector::~SeedCollector() {
206+
// TODO: Unregister the callback for updating the seed datastructures upon
207+
// instr removal
208+
}
209+
210+
#ifndef NDEBUG
211+
/// Writes the store seeds and then the load seeds to \p OS, each preceded by
/// a header line.
void SeedCollector::print(raw_ostream &OS) const {
  auto PrintSection = [&OS](const char *Header, const SeedContainer &Seeds) {
    OS << Header;
    Seeds.print(OS);
  };
  PrintSection("=== StoreSeeds ===\n", StoreSeeds);
  PrintSection("=== LoadSeeds ===\n", LoadSeeds);
}
217+
218+
void SeedCollector::dump() const { print(dbgs()); }
219+
#endif
220+
134221
} // namespace llvm::sandboxir

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SeedCollectorTest.cpp

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,171 @@ define void @foo(ptr %ptrA, float %val, ptr %ptrB) {
268268
}
269269
EXPECT_EQ(Cnt, 0u);
270270
}
271+
272+
TEST_F(SeedBundleTest, ConsecutiveStores) {
  // Where "Consecutive" means the stores address consecutive locations in
  // memory, but not in program order. Check to see that the collector puts them
  // in the proper order for vectorization.
  parseIR(C, R"IR(
define void @foo(ptr noalias %ptr, float %val) {
bb:
  %ptr0 = getelementptr float, ptr %ptr, i32 0
  %ptr1 = getelementptr float, ptr %ptr, i32 1
  %ptr2 = getelementptr float, ptr %ptr, i32 2
  %ptr3 = getelementptr float, ptr %ptr, i32 3
  store float %val, ptr %ptr0
  store float %val, ptr %ptr2
  store float %val, ptr %ptr1
  store float %val, ptr %ptr3
  ret void
}
)IR");
  Function &LLVMF = *M->getFunction("foo");
  DominatorTree DT(LLVMF);
  TargetLibraryInfoImpl TLII;
  TargetLibraryInfo TLI(TLII);
  LoopInfo LI(DT);
  AssumptionCache AC(LLVMF);
  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);

  sandboxir::Context Ctx(C);
  auto &F = *Ctx.createFunction(&LLVMF);
  auto BB = F.begin();
  sandboxir::SeedCollector SC(&*BB, SE);

  // Find the stores (they follow the four GEPs).
  auto It = std::next(BB->begin(), 4);
  // StX with X as the order by offset in memory
  auto *St0 = &*It++;
  auto *St2 = &*It++;
  auto *St1 = &*It++;
  auto *St3 = &*It++;

  auto StoreSeedsRange = SC.getStoreSeeds();
  // Fail cleanly instead of dereferencing end() if nothing was collected.
  ASSERT_FALSE(StoreSeedsRange.begin() == StoreSeedsRange.end());
  auto &SB = *StoreSeedsRange.begin();
  // Expect just one vector of store seeds
  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
  EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
}
318+
319+
TEST_F(SeedBundleTest, StoresWithGaps) {
  // The stores hit non-adjacent offsets (0, 3, 5, 7) and are not in offset
  // order in the program. Check that they still end up in a single bundle,
  // sorted by offset.
  parseIR(C, R"IR(
define void @foo(ptr noalias %ptr, float %val) {
bb:
  %ptr0 = getelementptr float, ptr %ptr, i32 0
  %ptr1 = getelementptr float, ptr %ptr, i32 3
  %ptr2 = getelementptr float, ptr %ptr, i32 5
  %ptr3 = getelementptr float, ptr %ptr, i32 7
  store float %val, ptr %ptr0
  store float %val, ptr %ptr2
  store float %val, ptr %ptr1
  store float %val, ptr %ptr3
  ret void
}
)IR");
  Function &LLVMF = *M->getFunction("foo");
  DominatorTree DT(LLVMF);
  TargetLibraryInfoImpl TLII;
  TargetLibraryInfo TLI(TLII);
  LoopInfo LI(DT);
  AssumptionCache AC(LLVMF);
  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);

  sandboxir::Context Ctx(C);
  auto &F = *Ctx.createFunction(&LLVMF);
  auto BB = F.begin();
  sandboxir::SeedCollector SC(&*BB, SE);

  // Find the stores (they follow the four GEPs).
  auto It = std::next(BB->begin(), 4);
  // StX with X as the order by offset in memory
  auto *St0 = &*It++;
  auto *St2 = &*It++;
  auto *St1 = &*It++;
  auto *St3 = &*It++;

  auto StoreSeedsRange = SC.getStoreSeeds();
  // Fail cleanly instead of dereferencing end() if nothing was collected.
  ASSERT_FALSE(StoreSeedsRange.begin() == StoreSeedsRange.end());
  auto &SB = *StoreSeedsRange.begin();
  // Expect just one vector of store seeds
  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
  EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St2, St3));
}
362+
363+
TEST_F(SeedBundleTest, VectorStores) {
  // Two <2 x float> stores at element offsets 2 and 0, in that program order.
  // Check that they form a single bundle sorted by offset.
  parseIR(C, R"IR(
define void @foo(ptr noalias %ptr, <2 x float> %val) {
bb:
  %ptr0 = getelementptr float, ptr %ptr, i32 0
  %ptr2 = getelementptr float, ptr %ptr, i32 2
  store <2 x float> %val, ptr %ptr2
  store <2 x float> %val, ptr %ptr0
  ret void
}
)IR");
  Function &LLVMF = *M->getFunction("foo");
  DominatorTree DT(LLVMF);
  TargetLibraryInfoImpl TLII;
  TargetLibraryInfo TLI(TLII);
  LoopInfo LI(DT);
  AssumptionCache AC(LLVMF);
  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);

  sandboxir::Context Ctx(C);
  auto &F = *Ctx.createFunction(&LLVMF);
  auto BB = F.begin();
  sandboxir::SeedCollector SC(&*BB, SE);

  // Find the stores (they follow the two GEPs).
  auto It = std::next(BB->begin(), 2);
  // StX with X as the order by offset in memory
  auto *St2 = &*It++;
  auto *St0 = &*It++;

  auto StoreSeedsRange = SC.getStoreSeeds();
  // Fail cleanly instead of dereferencing end() if nothing was collected.
  ASSERT_FALSE(StoreSeedsRange.begin() == StoreSeedsRange.end());
  auto &SB = *StoreSeedsRange.begin();
  // Expect just one vector of store seeds
  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
  EXPECT_THAT(SB, testing::ElementsAre(St0, St2));
}
399+
400+
TEST_F(SeedBundleTest, MixedScalarVectors) {
  // Scalar float stores at offsets 0 and 3 plus a <2 x float> store at
  // offset 1. Check that all three land in a single bundle sorted by offset.
  parseIR(C, R"IR(
define void @foo(ptr noalias %ptr, float %v, <2 x float> %val) {
bb:
  %ptr0 = getelementptr float, ptr %ptr, i32 0
  %ptr1 = getelementptr float, ptr %ptr, i32 1
  %ptr3 = getelementptr float, ptr %ptr, i32 3
  store float %v, ptr %ptr0
  store float %v, ptr %ptr3
  store <2 x float> %val, ptr %ptr1
  ret void
}
)IR");
  Function &LLVMF = *M->getFunction("foo");
  DominatorTree DT(LLVMF);
  TargetLibraryInfoImpl TLII;
  TargetLibraryInfo TLI(TLII);
  LoopInfo LI(DT);
  AssumptionCache AC(LLVMF);
  ScalarEvolution SE(LLVMF, TLI, AC, DT, LI);

  sandboxir::Context Ctx(C);
  auto &F = *Ctx.createFunction(&LLVMF);
  auto BB = F.begin();
  sandboxir::SeedCollector SC(&*BB, SE);

  // Find the stores (they follow the three GEPs).
  auto It = std::next(BB->begin(), 3);
  // StX with X as the order by offset in memory
  auto *St0 = &*It++;
  auto *St3 = &*It++;
  auto *St1 = &*It++;

  auto StoreSeedsRange = SC.getStoreSeeds();
  // Fail cleanly instead of dereferencing end() if nothing was collected.
  ASSERT_FALSE(StoreSeedsRange.begin() == StoreSeedsRange.end());
  auto &SB = *StoreSeedsRange.begin();
  // Expect just one vector of store seeds
  EXPECT_TRUE(std::next(StoreSeedsRange.begin()) == StoreSeedsRange.end());
  EXPECT_THAT(SB, testing::ElementsAre(St0, St1, St3));
}

0 commit comments

Comments
 (0)