diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h index 0fa40e00d23fc..6c2315af0e797 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h @@ -30,8 +30,23 @@ namespace llvm::sandboxir { class PriorityCmp { public: bool operator()(const DGNode *N1, const DGNode *N2) { - // TODO: This should be a hierarchical comparator. - return N1->getInstruction()->comesBefore(N2->getInstruction()); + // Given that the DAG does not model dependencies such that PHIs are always + // at the top, or terminators always at the bottom, we need to force the + // priority here in the comparator of the ready list container. + auto *I1 = N1->getInstruction(); + auto *I2 = N2->getInstruction(); + bool IsTerm1 = I1->isTerminator(); + bool IsTerm2 = I2->isTerminator(); + if (IsTerm1 != IsTerm2) + // Terminators have the lowest priority. + return IsTerm1 > IsTerm2; + bool IsPHI1 = isa(I1); + bool IsPHI2 = isa(I2); + if (IsPHI1 != IsPHI2) + // PHIs have the highest priority. + return IsPHI1 < IsPHI2; + // Otherwise rely on the instruction order. + return I2->comesBefore(I1); } }; diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp index 373af27ffbff0..0d5d86acaee89 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp @@ -253,6 +253,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) { %add0 = add i8 %v0, 0 %add1 = add i8 %v1, 1 br label %bb1 + bb1: store i8 %add0, ptr %ptr0 store i8 %add1, ptr %ptr1 @@ -392,3 +393,77 @@ define void @foo(ptr %ptr) { EXPECT_TRUE(ReadyList.empty()); EXPECT_THAT(Nodes, testing::UnorderedElementsAre(L0N, RetN)); } + +TEST_F(SchedulerTest, ReadyListPriorities) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { +bb0: + br label %bb1 + +bb1: + %phi0 = phi i8 [0, %bb0], [1, %bb1] + %phi1 = phi i8 [0, %bb0], [1, %bb1] + %ld0 = load i8, ptr %ptr + store i8 %ld0, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB1 = getBasicBlockByName(F, "bb1"); + auto It = BB1->begin(); + auto *Phi0 = cast(&*It++); + auto *Phi1 = cast(&*It++); + auto *L0 = cast(&*It++); + auto *S0 = cast(&*It++); + auto *Ret = cast(&*It++); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); + DAG.extend({&*BB1->begin(), BB1->getTerminator()}); + auto *Phi0N = DAG.getNode(Phi0); + auto *Phi1N = DAG.getNode(Phi1); + auto *L0N = DAG.getNode(L0); + auto *S0N = DAG.getNode(S0); + auto *RetN = DAG.getNode(Ret); + + sandboxir::ReadyListContainer ReadyList; + // Check PHI vs non-PHI. + ReadyList.insert(S0N); + ReadyList.insert(Phi0N); + EXPECT_EQ(ReadyList.pop(), Phi0N); + EXPECT_EQ(ReadyList.pop(), S0N); + ReadyList.insert(Phi0N); + ReadyList.insert(S0N); + EXPECT_EQ(ReadyList.pop(), Phi0N); + EXPECT_EQ(ReadyList.pop(), S0N); + // Check PHI vs terminator. + ReadyList.insert(RetN); + ReadyList.insert(Phi1N); + EXPECT_EQ(ReadyList.pop(), Phi1N); + EXPECT_EQ(ReadyList.pop(), RetN); + ReadyList.insert(Phi1N); + ReadyList.insert(RetN); + EXPECT_EQ(ReadyList.pop(), Phi1N); + EXPECT_EQ(ReadyList.pop(), RetN); + // Check terminator vs non-terminator. + ReadyList.insert(RetN); + ReadyList.insert(L0N); + EXPECT_EQ(ReadyList.pop(), L0N); + EXPECT_EQ(ReadyList.pop(), RetN); + ReadyList.insert(L0N); + ReadyList.insert(RetN); + EXPECT_EQ(ReadyList.pop(), L0N); + EXPECT_EQ(ReadyList.pop(), RetN); + // Check all, program order. + ReadyList.insert(RetN); + ReadyList.insert(L0N); + ReadyList.insert(Phi1N); + ReadyList.insert(S0N); + ReadyList.insert(Phi0N); + EXPECT_EQ(ReadyList.pop(), Phi0N); + EXPECT_EQ(ReadyList.pop(), Phi1N); + EXPECT_EQ(ReadyList.pop(), L0N); + EXPECT_EQ(ReadyList.pop(), S0N); + EXPECT_EQ(ReadyList.pop(), RetN); +}