diff --git a/llvm/include/llvm/SandboxIR/Pass.h b/llvm/include/llvm/SandboxIR/Pass.h index fee6bd9e779fd..4f4eae87cd3ff 100644 --- a/llvm/include/llvm/SandboxIR/Pass.h +++ b/llvm/include/llvm/SandboxIR/Pass.h @@ -14,6 +14,7 @@ namespace llvm { +class AAResults; class ScalarEvolution; namespace sandboxir { @@ -22,14 +23,16 @@ class Function; class Region; class Analyses { + AAResults *AA = nullptr; ScalarEvolution *SE = nullptr; Analyses() = default; public: - Analyses(ScalarEvolution &SE) : SE(&SE) {} + Analyses(AAResults &AA, ScalarEvolution &SE) : AA(&AA), SE(&SE) {} public: + AAResults &getAA() const { return *AA; } ScalarEvolution &getScalarEvolution() const { return *SE; } /// For use by unit tests. static Analyses emptyForTesting() { return Analyses(); } diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h index f43e033e3cc7e..58dcb2eeadbc2 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -17,6 +17,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" namespace llvm::sandboxir { @@ -36,6 +37,7 @@ enum class ResultReason { DiffMathFlags, DiffWrapFlags, NotConsecutive, + CantSchedule, Unimplemented, Infeasible, }; @@ -66,6 +68,8 @@ struct ToStr { return "DiffWrapFlags"; case ResultReason::NotConsecutive: return "NotConsecutive"; + case ResultReason::CantSchedule: + return "CantSchedule"; case ResultReason::Unimplemented: return "Unimplemented"; case ResultReason::Infeasible: @@ -146,6 +150,7 @@ class Pack final : public LegalityResultWithReason { /// Performs the legality analysis and returns a LegalityResult object. class LegalityAnalysis { + Scheduler Sched; /// Owns the legality result objects created by createLegalityResult(). SmallVector> ResultPool; /// Checks opcodes, types and other IR-specifics and returns a ResultReason @@ -157,8 +162,8 @@ class LegalityAnalysis { const DataLayout &DL; public: - LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL) - : SE(SE), DL(DL) {} + LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL) + : Sched(AA), SE(SE), DL(DL) {} /// A LegalityResult factory. template ResultT &createLegalityResult(ArgsT... Args) { @@ -167,7 +172,10 @@ class LegalityAnalysis { } /// Checks if it's legal to vectorize the instructions in \p Bndl. /// \Returns a LegalityResult object owned by LegalityAnalysis. - const LegalityResult &canVectorize(ArrayRef Bndl); + /// \p SkipScheduling skips the scheduler check and is only meant for testing. + // TODO: Try to remove the SkipScheduling argument by refactoring the tests. + const LegalityResult &canVectorize(ArrayRef Bndl, + bool SkipScheduling = false); }; } // namespace llvm::sandboxir diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h index 03867df3d9808..46b953ff9b7f4 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h @@ -10,6 +10,7 @@ #include +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/PassManager.h" #include "llvm/SandboxIR/PassManager.h" @@ -20,6 +21,7 @@ class TargetTransformInfo; class SandboxVectorizerPass : public PassInfoMixin { TargetTransformInfo *TTI = nullptr; + AAResults *AA = nullptr; ScalarEvolution *SE = nullptr; // A pipeline of SandboxIR function passes run by the vectorizer. diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp index 1efd178778b9f..8c6deeb7df249 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp @@ -184,7 +184,8 @@ static void dumpBndl(ArrayRef Bndl) { } #endif // NDEBUG -const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef Bndl) { +const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef Bndl, + bool SkipScheduling) { // If Bndl contains values other than instructions, we need to Pack. if (any_of(Bndl, [](auto *V) { return !isa(V); })) { LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n"; @@ -197,7 +198,15 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef Bndl) { // TODO: Check for existing vectors containing values in Bndl. - // TODO: Check with scheduler. + if (!SkipScheduling) { + // TODO: Try to remove the IBndl vector. + SmallVector IBndl; + IBndl.reserve(Bndl.size()); + for (auto *V : Bndl) + IBndl.push_back(cast(V)); + if (!Sched.trySchedule(IBndl)) + return createLegalityResult(ResultReason::CantSchedule); + } return createLegalityResult(); } diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index 339330c64f0ca..005d2241430ff 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -61,8 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef Bndl) { void BottomUpVec::tryVectorize(ArrayRef Bndl) { vectorizeRec(Bndl); } bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) { - Legality = std::make_unique(A.getScalarEvolution(), - F.getParent()->getDataLayout()); + Legality = std::make_unique( + A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout()); Change = false; // TODO: Start from innermost BBs first for (auto &BB : F) { diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp index 96d825ed852fb..790bee4a4d7f3 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp @@ -51,6 +51,7 @@ SandboxVectorizerPass::~SandboxVectorizerPass() = default; PreservedAnalyses SandboxVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) { TTI = &AM.getResult(F); + AA = &AM.getResult(F); SE = &AM.getResult(F); bool Changed = runImpl(F); @@ -83,6 +84,6 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) { // Create SandboxIR for LLVMF and run BottomUpVec on it. sandboxir::Context Ctx(LLVMF.getContext()); sandboxir::Function &F = *Ctx.createFunction(&LLVMF); - sandboxir::Analyses A(*SE); + sandboxir::Analyses A(*AA, *SE); return FPM.runOnFunction(F, A); } diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp index 68557cb8b129f..51e7a14013299 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -8,6 +8,7 @@ #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -30,15 +31,20 @@ struct LegalityTest : public testing::Test { std::unique_ptr AC; std::unique_ptr LI; std::unique_ptr SE; + std::unique_ptr BAA; + std::unique_ptr AA; - ScalarEvolution &getSE(llvm::Function &LLVMF) { + void getAnalyses(llvm::Function &LLVMF) { DT = std::make_unique(LLVMF); TLII = std::make_unique(); TLI = std::make_unique(*TLII); AC = std::make_unique(LLVMF); LI = std::make_unique(*DT); SE = std::make_unique(LLVMF, *TLI, *AC, *DT, *LI); - return *SE; + BAA = std::make_unique(LLVMF.getParent()->getDataLayout(), + LLVMF, *TLI, *AC, DT.get()); + AA = std::make_unique(*TLI); + AA->addAAResult(*BAA); } void parseIR(LLVMContext &C, const char *IR) { @@ -49,7 +55,7 @@ struct LegalityTest : public testing::Test { } }; -TEST_F(LegalityTest, Legality) { +TEST_F(LegalityTest, LegalitySkipSchedule) { parseIR(C, R"IR( define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) { %gep0 = getelementptr float, ptr %ptr, i32 0 @@ -76,7 +82,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float } )IR"); llvm::Function *LLVMF = &*M->getFunction("foo"); - auto &SE = getSE(*LLVMF); + getAnalyses(*LLVMF); const auto &DL = M->getDataLayout(); sandboxir::Context Ctx(C); @@ -104,83 +110,139 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float auto *CmpSLT = cast(&*It++); auto *CmpSGT = cast(&*It++); - sandboxir::LegalityAnalysis Legality(SE, DL); - const auto &Result = Legality.canVectorize({St0, St1}); + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL); + const auto &Result = + Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); { // Check NotInstructions - auto &Result = Legality.canVectorize({F, St0}); + auto &Result = Legality.canVectorize({F, St0}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::NotInstructions); } { // Check DiffOpcodes - const auto &Result = Legality.canVectorize({St0, Ld0}); + const auto &Result = + Legality.canVectorize({St0, Ld0}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffOpcodes); } { // Check DiffTypes - EXPECT_TRUE(isa(Legality.canVectorize({St0, StVec2}))); - EXPECT_TRUE(isa(Legality.canVectorize({StVec2, StVec3}))); + EXPECT_TRUE(isa( + Legality.canVectorize({St0, StVec2}, /*SkipScheduling=*/true))); + EXPECT_TRUE(isa( + Legality.canVectorize({StVec2, StVec3}, /*SkipScheduling=*/true))); - const auto &Result = Legality.canVectorize({St0, StI8}); + const auto &Result = + Legality.canVectorize({St0, StI8}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffTypes); } { // Check DiffMathFlags - const auto &Result = Legality.canVectorize({FAdd0, FAdd1}); + const auto &Result = + Legality.canVectorize({FAdd0, FAdd1}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffMathFlags); } { // Check DiffWrapFlags - const auto &Result = Legality.canVectorize({Trunc0, Trunc1}); + const auto &Result = + Legality.canVectorize({Trunc0, Trunc1}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffWrapFlags); } { // Check DiffTypes for unary operands that have a different type. - const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8}); + const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8}, + /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffTypes); } { // Check DiffOpcodes for CMPs with different predicates. - const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT}); + const auto &Result = + Legality.canVectorize({CmpSLT, CmpSGT}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffOpcodes); } { // Check NotConsecutive Ld0,Ld0b - const auto &Result = Legality.canVectorize({Ld0, Ld0b}); + const auto &Result = + Legality.canVectorize({Ld0, Ld0b}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::NotConsecutive); } { // Check NotConsecutive Ld0,Ld3 - const auto &Result = Legality.canVectorize({Ld0, Ld3}); + const auto &Result = + Legality.canVectorize({Ld0, Ld3}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::NotConsecutive); } { // Check Widen Ld0,Ld1 - const auto &Result = Legality.canVectorize({Ld0, Ld1}); + const auto &Result = + Legality.canVectorize({Ld0, Ld1}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); } } +TEST_F(LegalityTest, LegalitySchedule) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { + %gep0 = getelementptr float, ptr %ptr, i32 0 + %gep1 = getelementptr float, ptr %ptr, i32 1 + %ld0 = load float, ptr %gep0 + store float %ld0, ptr %gep1 + %ld1 = load float, ptr %gep1 + store float %ld0, ptr %gep0 + store float %ld1, ptr %gep1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + getAnalyses(*LLVMF); + const auto &DL = M->getDataLayout(); + + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + [[maybe_unused]] auto *Gep0 = cast(&*It++); + [[maybe_unused]] auto *Gep1 = cast(&*It++); + auto *Ld0 = cast(&*It++); + [[maybe_unused]] auto *ConflictingSt = cast(&*It++); + auto *Ld1 = cast(&*It++); + auto *St0 = cast(&*It++); + auto *St1 = cast(&*It++); + + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL); + { + // Can vectorize St0,St1. + const auto &Result = Legality.canVectorize({St0, St1}); + EXPECT_TRUE(isa(Result)); + } + { + // Can't vectorize Ld0,Ld1 because of conflicting store. + auto &Result = Legality.canVectorize({Ld0, Ld1}); + EXPECT_TRUE(isa(Result)); + EXPECT_EQ(cast(Result).getReason(), + sandboxir::ResultReason::CantSchedule); + } +} + #ifndef NDEBUG TEST_F(LegalityTest, LegalityResultDump) { parseIR(C, R"IR( @@ -189,7 +251,7 @@ define void @foo() { } )IR"); llvm::Function *LLVMF = &*M->getFunction("foo"); - auto &SE = getSE(*LLVMF); + getAnalyses(*LLVMF); const auto &DL = M->getDataLayout(); auto Matches = [](const sandboxir::LegalityResult &Result, @@ -200,7 +262,7 @@ define void @foo() { return Buff == ExpectedStr; }; - sandboxir::LegalityAnalysis Legality(SE, DL); + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL); EXPECT_TRUE( Matches(Legality.createLegalityResult(), "Widen")); EXPECT_TRUE(Matches(Legality.createLegalityResult(