diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 5be05bc80c492..b498e0f189465 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -345,6 +345,11 @@ class DependencyGraph { Interval extend(ArrayRef Instrs); /// \Returns the range of instructions included in the DAG. Interval getInterval() const { return DAGInterval; } + /// Called by the scheduler when a new instruction \p I has been created. + void notifyCreateInstr(Instruction *I) { + getOrCreateNode(I); + // TODO: Update the dependencies for the new node. + } #ifndef NDEBUG void print(raw_ostream &OS) const; LLVM_DUMP_METHOD void dump() const; diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h index 58dcb2eeadbc2..63d6ef31c8645 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h @@ -162,8 +162,9 @@ class LegalityAnalysis { const DataLayout &DL; public: - LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL) - : Sched(AA), SE(SE), DL(DL) {} + LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL, + Context &Ctx) + : Sched(AA, Ctx), SE(SE), DL(DL) {} /// A LegalityResult factory. template ResultT &createLegalityResult(ArgsT... Args) { diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h index 46b953ff9b7f4..09369dbb496fc 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h @@ -13,6 +13,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/PassManager.h" +#include "llvm/SandboxIR/Context.h" #include "llvm/SandboxIR/PassManager.h" namespace llvm { @@ -24,6 +25,8 @@ class SandboxVectorizerPass : public PassInfoMixin { AAResults *AA = nullptr; ScalarEvolution *SE = nullptr; + std::unique_ptr Ctx; + // A pipeline of SandboxIR function passes run by the vectorizer. sandboxir::FunctionPassManager FPM; diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h index 08972d460b406..0e4eea3880efb 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h @@ -95,6 +95,8 @@ class Scheduler { DependencyGraph DAG; std::optional ScheduleTopItOpt; SmallVector> Bndls; + Context &Ctx; + Context::CallbackID CreateInstrCB; /// \Returns a scheduling bundle containing \p Instrs. SchedBundle *createBundle(ArrayRef Instrs); @@ -110,8 +112,11 @@ class Scheduler { Scheduler &operator=(const Scheduler &) = delete; public: - Scheduler(AAResults &AA) : DAG(AA) {} - ~Scheduler() {} + Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) { + CreateInstrCB = Ctx.registerCreateInstrCallback( + [this](Instruction *I) { DAG.notifyCreateInstr(I); }); + } + ~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); } bool trySchedule(ArrayRef Instrs); diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp index 37713e7da6432..0a930d30aeab5 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp @@ -182,8 +182,6 @@ Value *BottomUpVec::vectorizeRec(ArrayRef Bndl) { } NewVec = createVectorInstr(Bndl, VecOperands); - // TODO: Notify DAG/Scheduler about new instruction - // TODO: Collect potentially dead instructions. break; } @@ -202,7 +200,8 @@ bool BottomUpVec::tryVectorize(ArrayRef Bndl) { bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) { Legality = std::make_unique( - A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout()); + A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(), + F.getContext()); Change = false; // TODO: Start from innermost BBs first for (auto &BB : F) { diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp index 790bee4a4d7f3..c22eb01d74a1c 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp @@ -64,6 +64,9 @@ PreservedAnalyses SandboxVectorizerPass::run(Function &F, } bool SandboxVectorizerPass::runImpl(Function &LLVMF) { + if (Ctx == nullptr) + Ctx = std::make_unique(LLVMF.getContext()); + if (PrintPassPipeline) { FPM.printPipeline(outs()); return false; @@ -82,8 +85,7 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) { } // Create SandboxIR for LLVMF and run BottomUpVec on it. - sandboxir::Context Ctx(LLVMF.getContext()); - sandboxir::Function &F = *Ctx.createFunction(&LLVMF); + sandboxir::Function &F = *Ctx->createFunction(&LLVMF); sandboxir::Analyses A(*AA, *SE); return FPM.runOnFunction(F, A); } diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll index 2b9aac93b7485..45c701a18fd9b 100644 --- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll +++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll @@ -55,7 +55,46 @@ define void @store_fpext_load(ptr %ptr) { ret void } -; TODO: Test store_zext_fcmp_load once we implement scheduler callbacks and legality diamond check +define void @store_fcmp_zext_load(ptr %ptr) { +; CHECK-LABEL: define void @store_fcmp_zext_load( +; CHECK-SAME: ptr [[PTR:%.*]]) { +; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0 +; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1 +; CHECK-NEXT: [[PTRB0:%.*]] = getelementptr i32, ptr [[PTR]], i32 0 +; CHECK-NEXT: [[PTRB1:%.*]] = getelementptr i32, ptr [[PTR]], i32 1 +; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4 +; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4 +; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4 +; CHECK-NEXT: [[FCMP0:%.*]] = fcmp ogt float [[LDA0]], [[LDB0]] +; CHECK-NEXT: [[FCMP1:%.*]] = fcmp ogt float [[LDA1]], [[LDB1]] +; CHECK-NEXT: [[VCMP:%.*]] = fcmp ogt <2 x float> [[VECL]], [[VECL1]] +; CHECK-NEXT: [[ZEXT0:%.*]] = zext i1 [[FCMP0]] to i32 +; CHECK-NEXT: [[ZEXT1:%.*]] = zext i1 [[FCMP1]] to i32 +; CHECK-NEXT: [[VCAST:%.*]] = zext <2 x i1> [[VCMP]] to <2 x i32> +; CHECK-NEXT: store i32 [[ZEXT0]], ptr [[PTRB0]], align 4 +; CHECK-NEXT: store i32 [[ZEXT1]], ptr [[PTRB1]], align 4 +; CHECK-NEXT: store <2 x i32> [[VCAST]], ptr [[PTRB0]], align 4 +; CHECK-NEXT: ret void +; + %ptr0 = getelementptr float, ptr %ptr, i32 0 + %ptr1 = getelementptr float, ptr %ptr, i32 1 + %ptrb0 = getelementptr i32, ptr %ptr, i32 0 + %ptrb1 = getelementptr i32, ptr %ptr, i32 1 + %ldB0 = load float, ptr %ptr0 + %ldB1 = load float, ptr %ptr1 + %ldA0 = load float, ptr %ptr0 + %ldA1 = load float, ptr %ptr1 + %fcmp0 = fcmp ogt float %ldA0, %ldB0 + %fcmp1 = fcmp ogt float %ldA1, %ldB1 + %zext0 = zext i1 %fcmp0 to i32 + %zext1 = zext i1 %fcmp1 to i32 + store i32 %zext0, ptr %ptrb0 + store i32 %zext1, ptr %ptrb1 + ret void +} ; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp index 51e7a14013299..b5e2c302f5901 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp @@ -110,7 +110,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float auto *CmpSLT = cast(&*It++); auto *CmpSGT = cast(&*It++); - sandboxir::LegalityAnalysis Legality(*AA, *SE, DL); + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); const auto &Result = Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true); EXPECT_TRUE(isa(Result)); @@ -228,7 +228,7 @@ define void @foo(ptr %ptr) { auto *St0 = cast(&*It++); auto *St1 = cast(&*It++); - sandboxir::LegalityAnalysis Legality(*AA, *SE, DL); + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); { // Can vectorize St0,St1. const auto &Result = Legality.canVectorize({St0, St1}); @@ -262,7 +262,8 @@ define void @foo() { return Buff == ExpectedStr; }; - sandboxir::LegalityAnalysis Legality(*AA, *SE, DL); + sandboxir::Context Ctx(C); + sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx); EXPECT_TRUE( Matches(Legality.createLegalityResult(), "Widen")); EXPECT_TRUE(Matches(Legality.createLegalityResult( diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp index 92e767e55fbdd..4a8b0ba1d7c12 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp @@ -156,20 +156,20 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { { // Schedule all instructions in sequence. - sandboxir::Scheduler Sched(getAA(*LLVMF)); + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); EXPECT_TRUE(Sched.trySchedule({Ret})); EXPECT_TRUE(Sched.trySchedule({S1})); EXPECT_TRUE(Sched.trySchedule({S0})); } { // Skip instructions. - sandboxir::Scheduler Sched(getAA(*LLVMF)); + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); EXPECT_TRUE(Sched.trySchedule({Ret})); EXPECT_TRUE(Sched.trySchedule({S0})); } { // Try invalid scheduling - sandboxir::Scheduler Sched(getAA(*LLVMF)); + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); EXPECT_TRUE(Sched.trySchedule({Ret})); EXPECT_TRUE(Sched.trySchedule({S0})); EXPECT_FALSE(Sched.trySchedule({S1})); @@ -197,7 +197,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) { auto *S1 = cast(&*It++); auto *Ret = cast(&*It++); - sandboxir::Scheduler Sched(getAA(*LLVMF)); + sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx); EXPECT_TRUE(Sched.trySchedule({Ret})); EXPECT_TRUE(Sched.trySchedule({S0, S1})); EXPECT_TRUE(Sched.trySchedule({L0, L1}));