Skip to content

Commit 6312bee

Browse files
authored
[SandboxVec][BottomUpVec] Use SeedCollector and slice seeds (#120826)
With this patch we switch from the temporary dummy seeds to actual seeds provided by the seed collector. The seeds get sliced and each slice is used as the starting point for vectorization.
1 parent 0aa831e commit 6312bee

File tree

13 files changed

+196
-32
lines changed

13 files changed

+196
-32
lines changed

llvm/include/llvm/SandboxIR/Pass.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace llvm {
1616

1717
class AAResults;
1818
class ScalarEvolution;
19+
class TargetTransformInfo;
1920

2021
namespace sandboxir {
2122

@@ -25,15 +26,18 @@ class Region;
2526
class Analyses {
2627
AAResults *AA = nullptr;
2728
ScalarEvolution *SE = nullptr;
29+
TargetTransformInfo *TTI = nullptr;
2830

2931
Analyses() = default;
3032

3133
public:
32-
Analyses(AAResults &AA, ScalarEvolution &SE) : AA(&AA), SE(&SE) {}
34+
Analyses(AAResults &AA, ScalarEvolution &SE, TargetTransformInfo &TTI)
35+
: AA(&AA), SE(&SE), TTI(&TTI) {}
3336

3437
public:
3538
AAResults &getAA() const { return *AA; }
3639
ScalarEvolution &getScalarEvolution() const { return *SE; }
40+
TargetTransformInfo &getTTI() const { return *TTI; }
3741
/// For use by unit tests.
3842
static Analyses emptyForTesting() { return Analyses(); }
3943
};

llvm/include/llvm/SandboxIR/Utils.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,16 @@ class Utils {
6060
getUnderlyingObject(LSI->getPointerOperand()->Val));
6161
}
6262

63+
/// \Returns the number of bits of \p Ty.
64+
static unsigned getNumBits(Type *Ty, const DataLayout &DL) {
65+
return DL.getTypeSizeInBits(Ty->LLVMTy);
66+
}
67+
6368
/// \Returns the number of bits required to represent the operands or return
6469
/// value of \p V in \p DL.
6570
static unsigned getNumBits(Value *V, const DataLayout &DL) {
6671
Type *Ty = getExpectedType(V);
67-
return DL.getTypeSizeInBits(Ty->LLVMTy);
72+
return getNumBits(Ty, DL);
6873
}
6974

7075
/// \Returns the number of bits required to represent the operands or

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ class LegalityAnalysis {
177177
// TODO: Try to remove the SkipScheduling argument by refactoring the tests.
178178
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
179179
bool SkipScheduling = false);
180+
void clear() { Sched.clear(); }
180181
};
181182

182183
} // namespace llvm::sandboxir

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,13 @@ class Scheduler {
147147
~Scheduler() {}
148148

149149
bool trySchedule(ArrayRef<Instruction *> Instrs);
150+
/// Clear the scheduler's state, including the DAG.
151+
void clear() {
152+
Bndls.clear();
153+
// TODO: clear view once it lands.
154+
DAG.clear();
155+
ScheduleTopItOpt = std::nullopt;
156+
}
150157

151158
#ifndef NDEBUG
152159
void dump(raw_ostream &OS) const;

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ class SeedBundle {
9595
/// with a total size <= \p MaxVecRegBits, or an empty slice if the
9696
/// requirements cannot be met . If \p ForcePowOf2 is true, then the returned
9797
/// slice will have a total number of bits that is a power of 2.
98-
MutableArrayRef<Instruction *>
99-
getSlice(unsigned StartIdx, unsigned MaxVecRegBits, bool ForcePowOf2);
98+
ArrayRef<Instruction *> getSlice(unsigned StartIdx, unsigned MaxVecRegBits,
99+
bool ForcePowOf2);
100100

101101
/// \Returns the number of seed elements in the bundle.
102102
std::size_t size() const { return Seeds.size(); }

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,16 @@ class VecUtils {
133133
assert(tryGetCommonScalarType(Bndl) && "Expected common scalar type!");
134134
return ScalarTy;
135135
}
136+
/// \Returns the first integer power of 2 that is <= Num.
137+
static unsigned getFloorPowerOf2(unsigned Num) {
138+
if (Num == 0)
139+
return Num;
140+
unsigned Mask = Num;
141+
Mask >>= 1;
142+
for (unsigned ShiftBy = 1; ShiftBy < sizeof(Num) * 8; ShiftBy <<= 1)
143+
Mask |= Mask >> ShiftBy;
144+
return Num & ~Mask;
145+
}
136146
};
137147

138148
} // namespace llvm::sandboxir

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 72 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,31 @@
88

99
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
1010
#include "llvm/ADT/SmallVector.h"
11+
#include "llvm/Analysis/TargetTransformInfo.h"
1112
#include "llvm/SandboxIR/Function.h"
1213
#include "llvm/SandboxIR/Instruction.h"
1314
#include "llvm/SandboxIR/Module.h"
1415
#include "llvm/SandboxIR/Utils.h"
1516
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
17+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
1618
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
1719

18-
namespace llvm::sandboxir {
20+
namespace llvm {
21+
22+
static cl::opt<unsigned>
23+
OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
24+
cl::desc("Override the vector register size in bits, "
25+
"which is otherwise found by querying TTI."));
26+
static cl::opt<bool>
27+
AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
28+
cl::desc("Allow non-power-of-2 vectorization."));
29+
30+
namespace sandboxir {
1931

2032
BottomUpVec::BottomUpVec(StringRef Pipeline)
2133
: FunctionPass("bottom-up-vec"),
2234
RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
2335

24-
// TODO: This is a temporary function that returns some seeds.
25-
// Replace this with SeedCollector's function when it lands.
26-
static llvm::SmallVector<Value *, 4> collectSeeds(BasicBlock &BB) {
27-
llvm::SmallVector<Value *, 4> Seeds;
28-
for (auto &I : BB)
29-
if (auto *SI = llvm::dyn_cast<StoreInst>(&I))
30-
Seeds.push_back(SI);
31-
return Seeds;
32-
}
33-
3436
static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
3537
unsigned OpIdx) {
3638
SmallVector<Value *, 4> Operands;
@@ -265,6 +267,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
265267

266268
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
267269
DeadInstrCandidates.clear();
270+
Legality->clear();
268271
vectorizeRec(Bndl, /*Depth=*/0);
269272
tryEraseDeadInstrs();
270273
return Change;
@@ -275,17 +278,67 @@ bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
275278
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
276279
F.getContext());
277280
Change = false;
281+
const auto &DL = F.getParent()->getDataLayout();
282+
unsigned VecRegBits =
283+
OverrideVecRegBits != 0
284+
? OverrideVecRegBits
285+
: A.getTTI()
286+
.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
287+
.getFixedValue();
288+
278289
// TODO: Start from innermost BBs first
279290
for (auto &BB : F) {
280-
// TODO: Replace with proper SeedCollector function.
281-
auto Seeds = collectSeeds(BB);
282-
// TODO: Slice Seeds into smaller chunks.
283-
// TODO: If vectorization succeeds, run the RegionPassManager on the
284-
// resulting region.
285-
if (Seeds.size() >= 2)
286-
Change |= tryVectorize(Seeds);
291+
SeedCollector SC(&BB, A.getScalarEvolution());
292+
for (SeedBundle &Seeds : SC.getStoreSeeds()) {
293+
unsigned ElmBits =
294+
Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
295+
Seeds[Seeds.getFirstUnusedElementIdx()])),
296+
DL);
297+
298+
auto DivideBy2 = [](unsigned Num) {
299+
auto Floor = VecUtils::getFloorPowerOf2(Num);
300+
if (Floor == Num)
301+
return Floor / 2;
302+
return Floor;
303+
};
304+
// Try to create the largest vector supported by the target. If it fails
305+
// reduce the vector size by half.
306+
for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
307+
Seeds.getNumUnusedBits() / ElmBits);
308+
SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
309+
if (Seeds.allUsed())
310+
break;
311+
// Keep trying offsets after FirstUnusedElementIdx, until we vectorize
312+
// the slice. This could be quite expensive, so we enforce a limit.
313+
for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
314+
OE = Seeds.size();
315+
Offset + 1 < OE; Offset += 1) {
316+
// Seeds are getting used as we vectorize, so skip them.
317+
if (Seeds.isUsed(Offset))
318+
continue;
319+
if (Seeds.allUsed())
320+
break;
321+
322+
auto SeedSlice =
323+
Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
324+
if (SeedSlice.empty())
325+
continue;
326+
327+
assert(SeedSlice.size() >= 2 && "Should have been rejected!");
328+
329+
// TODO: If vectorization succeeds, run the RegionPassManager on the
330+
// resulting region.
331+
332+
// TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
333+
SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
334+
SeedSlice.end());
335+
Change |= tryVectorize(SeedSliceVals);
336+
}
337+
}
338+
}
287339
}
288340
return Change;
289341
}
290342

291-
} // namespace llvm::sandboxir
343+
} // namespace sandboxir
344+
} // namespace llvm

llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,6 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
8686

8787
// Create SandboxIR for LLVMF and run BottomUpVec on it.
8888
sandboxir::Function &F = *Ctx->createFunction(&LLVMF);
89-
sandboxir::Analyses A(*AA, *SE);
89+
sandboxir::Analyses A(*AA, *SE, *TTI);
9090
return FPM.runOnFunction(F, A);
9191
}

llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ cl::opt<unsigned> SeedGroupsLimit(
3131
cl::desc("Limit the number of collected seeds groups in a BB to "
3232
"cap compilation time."));
3333

34-
MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
35-
unsigned MaxVecRegBits,
36-
bool ForcePowerOf2) {
34+
ArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
35+
unsigned MaxVecRegBits,
36+
bool ForcePowerOf2) {
3737
// Use uint32_t here for compatibility with IsPowerOf2_32
3838

3939
// BitCount tracks the size of the working slice. From that we can tell
@@ -47,10 +47,13 @@ MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
4747
// Can't start a slice with a used instruction.
4848
assert(!isUsed(StartIdx) && "Expected unused at StartIdx");
4949
for (auto S : make_range(Seeds.begin() + StartIdx, Seeds.end())) {
50+
// Stop if this instruction is used. This needs to be done before
51+
// getNumBits() because a "used" instruction may have been erased.
52+
if (isUsed(StartIdx + NumElements))
53+
break;
5054
uint32_t InstBits = Utils::getNumBits(S);
51-
// Stop if this instruction is used, or if adding it puts the slice over
52-
// the limit.
53-
if (isUsed(StartIdx + NumElements) || BitCount + InstBits > MaxVecRegBits)
55+
// Stop if adding it puts the slice over the limit.
56+
if (BitCount + InstBits > MaxVecRegBits)
5457
break;
5558
NumElements++;
5659
BitCount += InstBits;
@@ -68,7 +71,7 @@ MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
6871
"Must be a power of two");
6972
// Return any non-empty slice
7073
if (NumElements > 1)
71-
return MutableArrayRef<Instruction *>(&Seeds[StartIdx], NumElements);
74+
return ArrayRef<Instruction *>(&Seeds[StartIdx], NumElements);
7275
else
7376
return {};
7477
}

llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2-
; RUN: opt -passes=sandbox-vectorizer -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s
2+
; RUN: opt -passes=sandbox-vectorizer -sbvec-vec-reg-bits=1024 -sbvec-allow-non-pow2 -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s
33

44
define void @store_load(ptr %ptr) {
55
; CHECK-LABEL: define void @store_load(

0 commit comments

Comments
 (0)