Skip to content

Commit ce0d085

Browse files
author
vporpo
authored
[SandboxVec][Legality] Query the scheduler for legality (#114616)
This patch adds the legality check of whether the candidate instructions can be scheduled together. This uses a Scheduler object.
1 parent c1cec8c commit ce0d085

File tree

7 files changed

+114
-29
lines changed

7 files changed

+114
-29
lines changed

llvm/include/llvm/SandboxIR/Pass.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
namespace llvm {
1616

17+
class AAResults;
1718
class ScalarEvolution;
1819

1920
namespace sandboxir {
@@ -22,14 +23,16 @@ class Function;
2223
class Region;
2324

2425
class Analyses {
26+
AAResults *AA = nullptr;
2527
ScalarEvolution *SE = nullptr;
2628

2729
Analyses() = default;
2830

2931
public:
30-
Analyses(ScalarEvolution &SE) : SE(&SE) {}
32+
Analyses(AAResults &AA, ScalarEvolution &SE) : AA(&AA), SE(&SE) {}
3133

3234
public:
35+
AAResults &getAA() const { return *AA; }
3336
ScalarEvolution &getScalarEvolution() const { return *SE; }
3437
/// For use by unit tests.
3538
static Analyses emptyForTesting() { return Analyses(); }

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/IR/DataLayout.h"
1818
#include "llvm/Support/Casting.h"
1919
#include "llvm/Support/raw_ostream.h"
20+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
2021

2122
namespace llvm::sandboxir {
2223

@@ -36,6 +37,7 @@ enum class ResultReason {
3637
DiffMathFlags,
3738
DiffWrapFlags,
3839
NotConsecutive,
40+
CantSchedule,
3941
Unimplemented,
4042
Infeasible,
4143
};
@@ -66,6 +68,8 @@ struct ToStr {
6668
return "DiffWrapFlags";
6769
case ResultReason::NotConsecutive:
6870
return "NotConsecutive";
71+
case ResultReason::CantSchedule:
72+
return "CantSchedule";
6973
case ResultReason::Unimplemented:
7074
return "Unimplemented";
7175
case ResultReason::Infeasible:
@@ -146,6 +150,7 @@ class Pack final : public LegalityResultWithReason {
146150

147151
/// Performs the legality analysis and returns a LegalityResult object.
148152
class LegalityAnalysis {
153+
Scheduler Sched;
149154
/// Owns the legality result objects created by createLegalityResult().
150155
SmallVector<std::unique_ptr<LegalityResult>> ResultPool;
151156
/// Checks opcodes, types and other IR-specifics and returns a ResultReason
@@ -157,8 +162,8 @@ class LegalityAnalysis {
157162
const DataLayout &DL;
158163

159164
public:
160-
LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL)
161-
: SE(SE), DL(DL) {}
165+
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL)
166+
: Sched(AA), SE(SE), DL(DL) {}
162167
/// A LegalityResult factory.
163168
template <typename ResultT, typename... ArgsT>
164169
ResultT &createLegalityResult(ArgsT... Args) {
@@ -167,7 +172,10 @@ class LegalityAnalysis {
167172
}
168173
/// Checks if it's legal to vectorize the instructions in \p Bndl.
169174
/// \Returns a LegalityResult object owned by LegalityAnalysis.
170-
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl);
175+
/// \p SkipScheduling skips the scheduler check and is only meant for testing.
176+
// TODO: Try to remove the SkipScheduling argument by refactoring the tests.
177+
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
178+
bool SkipScheduling = false);
171179
};
172180

173181
} // namespace llvm::sandboxir

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <memory>
1212

13+
#include "llvm/Analysis/AliasAnalysis.h"
1314
#include "llvm/Analysis/ScalarEvolution.h"
1415
#include "llvm/IR/PassManager.h"
1516
#include "llvm/SandboxIR/PassManager.h"
@@ -20,6 +21,7 @@ class TargetTransformInfo;
2021

2122
class SandboxVectorizerPass : public PassInfoMixin<SandboxVectorizerPass> {
2223
TargetTransformInfo *TTI = nullptr;
24+
AAResults *AA = nullptr;
2325
ScalarEvolution *SE = nullptr;
2426

2527
// A pipeline of SandboxIR function passes run by the vectorizer.

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ static void dumpBndl(ArrayRef<Value *> Bndl) {
184184
}
185185
#endif // NDEBUG
186186

187-
const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl) {
187+
const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
188+
bool SkipScheduling) {
188189
// If Bndl contains values other than instructions, we need to Pack.
189190
if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) {
190191
LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n";
@@ -197,7 +198,15 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl) {
197198

198199
// TODO: Check for existing vectors containing values in Bndl.
199200

200-
// TODO: Check with scheduler.
201+
if (!SkipScheduling) {
202+
// TODO: Try to remove the IBndl vector.
203+
SmallVector<Instruction *, 8> IBndl;
204+
IBndl.reserve(Bndl.size());
205+
for (auto *V : Bndl)
206+
IBndl.push_back(cast<Instruction>(V));
207+
if (!Sched.trySchedule(IBndl))
208+
return createLegalityResult<Pack>(ResultReason::CantSchedule);
209+
}
201210

202211
return createLegalityResult<Widen>();
203212
}

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
6161
void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }
6262

6363
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
64-
Legality = std::make_unique<LegalityAnalysis>(A.getScalarEvolution(),
65-
F.getParent()->getDataLayout());
64+
Legality = std::make_unique<LegalityAnalysis>(
65+
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout());
6666
Change = false;
6767
// TODO: Start from innermost BBs first
6868
for (auto &BB : F) {

llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ SandboxVectorizerPass::~SandboxVectorizerPass() = default;
5151
PreservedAnalyses SandboxVectorizerPass::run(Function &F,
5252
FunctionAnalysisManager &AM) {
5353
TTI = &AM.getResult<TargetIRAnalysis>(F);
54+
AA = &AM.getResult<AAManager>(F);
5455
SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
5556

5657
bool Changed = runImpl(F);
@@ -83,6 +84,6 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
8384
// Create SandboxIR for LLVMF and run BottomUpVec on it.
8485
sandboxir::Context Ctx(LLVMF.getContext());
8586
sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
86-
sandboxir::Analyses A(*SE);
87+
sandboxir::Analyses A(*AA, *SE);
8788
return FPM.runOnFunction(F, A);
8889
}

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp

Lines changed: 82 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
1010
#include "llvm/Analysis/AssumptionCache.h"
11+
#include "llvm/Analysis/BasicAliasAnalysis.h"
1112
#include "llvm/Analysis/LoopInfo.h"
1213
#include "llvm/Analysis/ScalarEvolution.h"
1314
#include "llvm/Analysis/TargetLibraryInfo.h"
@@ -30,15 +31,20 @@ struct LegalityTest : public testing::Test {
3031
std::unique_ptr<AssumptionCache> AC;
3132
std::unique_ptr<LoopInfo> LI;
3233
std::unique_ptr<ScalarEvolution> SE;
34+
std::unique_ptr<BasicAAResult> BAA;
35+
std::unique_ptr<AAResults> AA;
3336

34-
ScalarEvolution &getSE(llvm::Function &LLVMF) {
37+
void getAnalyses(llvm::Function &LLVMF) {
3538
DT = std::make_unique<DominatorTree>(LLVMF);
3639
TLII = std::make_unique<TargetLibraryInfoImpl>();
3740
TLI = std::make_unique<TargetLibraryInfo>(*TLII);
3841
AC = std::make_unique<AssumptionCache>(LLVMF);
3942
LI = std::make_unique<LoopInfo>(*DT);
4043
SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
41-
return *SE;
44+
BAA = std::make_unique<BasicAAResult>(LLVMF.getParent()->getDataLayout(),
45+
LLVMF, *TLI, *AC, DT.get());
46+
AA = std::make_unique<AAResults>(*TLI);
47+
AA->addAAResult(*BAA);
4248
}
4349

4450
void parseIR(LLVMContext &C, const char *IR) {
@@ -49,7 +55,7 @@ struct LegalityTest : public testing::Test {
4955
}
5056
};
5157

52-
TEST_F(LegalityTest, Legality) {
58+
TEST_F(LegalityTest, LegalitySkipSchedule) {
5359
parseIR(C, R"IR(
5460
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
5561
%gep0 = getelementptr float, ptr %ptr, i32 0
@@ -76,7 +82,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
7682
}
7783
)IR");
7884
llvm::Function *LLVMF = &*M->getFunction("foo");
79-
auto &SE = getSE(*LLVMF);
85+
getAnalyses(*LLVMF);
8086
const auto &DL = M->getDataLayout();
8187

8288
sandboxir::Context Ctx(C);
@@ -104,83 +110,139 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
104110
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
105111
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
106112

107-
sandboxir::LegalityAnalysis Legality(SE, DL);
108-
const auto &Result = Legality.canVectorize({St0, St1});
113+
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
114+
const auto &Result =
115+
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
109116
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
110117

111118
{
112119
// Check NotInstructions
113-
auto &Result = Legality.canVectorize({F, St0});
120+
auto &Result = Legality.canVectorize({F, St0}, /*SkipScheduling=*/true);
114121
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
115122
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
116123
sandboxir::ResultReason::NotInstructions);
117124
}
118125
{
119126
// Check DiffOpcodes
120-
const auto &Result = Legality.canVectorize({St0, Ld0});
127+
const auto &Result =
128+
Legality.canVectorize({St0, Ld0}, /*SkipScheduling=*/true);
121129
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
122130
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
123131
sandboxir::ResultReason::DiffOpcodes);
124132
}
125133
{
126134
// Check DiffTypes
127-
EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({St0, StVec2})));
128-
EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({StVec2, StVec3})));
135+
EXPECT_TRUE(isa<sandboxir::Widen>(
136+
Legality.canVectorize({St0, StVec2}, /*SkipScheduling=*/true)));
137+
EXPECT_TRUE(isa<sandboxir::Widen>(
138+
Legality.canVectorize({StVec2, StVec3}, /*SkipScheduling=*/true)));
129139

130-
const auto &Result = Legality.canVectorize({St0, StI8});
140+
const auto &Result =
141+
Legality.canVectorize({St0, StI8}, /*SkipScheduling=*/true);
131142
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
132143
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
133144
sandboxir::ResultReason::DiffTypes);
134145
}
135146
{
136147
// Check DiffMathFlags
137-
const auto &Result = Legality.canVectorize({FAdd0, FAdd1});
148+
const auto &Result =
149+
Legality.canVectorize({FAdd0, FAdd1}, /*SkipScheduling=*/true);
138150
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
139151
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
140152
sandboxir::ResultReason::DiffMathFlags);
141153
}
142154
{
143155
// Check DiffWrapFlags
144-
const auto &Result = Legality.canVectorize({Trunc0, Trunc1});
156+
const auto &Result =
157+
Legality.canVectorize({Trunc0, Trunc1}, /*SkipScheduling=*/true);
145158
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
146159
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
147160
sandboxir::ResultReason::DiffWrapFlags);
148161
}
149162
{
150163
// Check DiffTypes for unary operands that have a different type.
151-
const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8});
164+
const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8},
165+
/*SkipScheduling=*/true);
152166
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
153167
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
154168
sandboxir::ResultReason::DiffTypes);
155169
}
156170
{
157171
// Check DiffOpcodes for CMPs with different predicates.
158-
const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT});
172+
const auto &Result =
173+
Legality.canVectorize({CmpSLT, CmpSGT}, /*SkipScheduling=*/true);
159174
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
160175
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
161176
sandboxir::ResultReason::DiffOpcodes);
162177
}
163178
{
164179
// Check NotConsecutive Ld0,Ld0b
165-
const auto &Result = Legality.canVectorize({Ld0, Ld0b});
180+
const auto &Result =
181+
Legality.canVectorize({Ld0, Ld0b}, /*SkipScheduling=*/true);
166182
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
167183
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
168184
sandboxir::ResultReason::NotConsecutive);
169185
}
170186
{
171187
// Check NotConsecutive Ld0,Ld3
172-
const auto &Result = Legality.canVectorize({Ld0, Ld3});
188+
const auto &Result =
189+
Legality.canVectorize({Ld0, Ld3}, /*SkipScheduling=*/true);
173190
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
174191
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
175192
sandboxir::ResultReason::NotConsecutive);
176193
}
177194
{
178195
// Check Widen Ld0,Ld1
179-
const auto &Result = Legality.canVectorize({Ld0, Ld1});
196+
const auto &Result =
197+
Legality.canVectorize({Ld0, Ld1}, /*SkipScheduling=*/true);
180198
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
181199
}
182200
}
183201

202+
TEST_F(LegalityTest, LegalitySchedule) {
203+
parseIR(C, R"IR(
204+
define void @foo(ptr %ptr) {
205+
%gep0 = getelementptr float, ptr %ptr, i32 0
206+
%gep1 = getelementptr float, ptr %ptr, i32 1
207+
%ld0 = load float, ptr %gep0
208+
store float %ld0, ptr %gep1
209+
%ld1 = load float, ptr %gep1
210+
store float %ld0, ptr %gep0
211+
store float %ld1, ptr %gep1
212+
ret void
213+
}
214+
)IR");
215+
llvm::Function *LLVMF = &*M->getFunction("foo");
216+
getAnalyses(*LLVMF);
217+
const auto &DL = M->getDataLayout();
218+
219+
sandboxir::Context Ctx(C);
220+
auto *F = Ctx.createFunction(LLVMF);
221+
auto *BB = &*F->begin();
222+
auto It = BB->begin();
223+
[[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
224+
[[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
225+
auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
226+
[[maybe_unused]] auto *ConflictingSt = cast<sandboxir::StoreInst>(&*It++);
227+
auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
228+
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
229+
auto *St1 = cast<sandboxir::StoreInst>(&*It++);
230+
231+
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
232+
{
233+
// Can vectorize St0,St1.
234+
const auto &Result = Legality.canVectorize({St0, St1});
235+
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
236+
}
237+
{
238+
// Can't vectorize Ld0,Ld1 because of conflicting store.
239+
auto &Result = Legality.canVectorize({Ld0, Ld1});
240+
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
241+
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
242+
sandboxir::ResultReason::CantSchedule);
243+
}
244+
}
245+
184246
#ifndef NDEBUG
185247
TEST_F(LegalityTest, LegalityResultDump) {
186248
parseIR(C, R"IR(
@@ -189,7 +251,7 @@ define void @foo() {
189251
}
190252
)IR");
191253
llvm::Function *LLVMF = &*M->getFunction("foo");
192-
auto &SE = getSE(*LLVMF);
254+
getAnalyses(*LLVMF);
193255
const auto &DL = M->getDataLayout();
194256

195257
auto Matches = [](const sandboxir::LegalityResult &Result,
@@ -200,7 +262,7 @@ define void @foo() {
200262
return Buff == ExpectedStr;
201263
};
202264

203-
sandboxir::LegalityAnalysis Legality(SE, DL);
265+
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
204266
EXPECT_TRUE(
205267
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
206268
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(

0 commit comments

Comments
 (0)