diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 134adc4b21ab1..5036dae1f0e27 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -154,6 +154,14 @@ class MemDGNode final : public DGNode { /// Convenience builders for a MemDGNode interval. class MemDGNodeIntervalBuilder { public: + /// Scans the instruction chain in \p Intvl top-down, returning the top-most + /// MemDGNode, or nullptr. + static MemDGNode *getTopMemDGNode(const Interval &Intvl, + const DependencyGraph &DAG); + /// Scans the instruction chain in \p Intvl bottom-up, returning the + /// bottom-most MemDGNode, or nullptr. + static MemDGNode *getBotMemDGNode(const Interval &Intvl, + const DependencyGraph &DAG); /// Given \p Instrs it finds their closest mem nodes in the interval and /// returns the corresponding mem range. Note: BotN (or its neighboring mem /// node) is included in the range. diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index 82f253d4c6323..c02eba167390d 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -32,23 +32,43 @@ void DGNode::dump() const { } #endif // NDEBUG +MemDGNode * +MemDGNodeIntervalBuilder::getTopMemDGNode(const Interval &Intvl, + const DependencyGraph &DAG) { + Instruction *I = Intvl.top(); + Instruction *BeforeI = Intvl.bottom(); + // Walk down the chain looking for a mem-dep candidate instruction. + while (!DGNode::isMemDepNodeCandidate(I) && I != BeforeI) + I = I->getNextNode(); + if (!DGNode::isMemDepNodeCandidate(I)) + return nullptr; + return cast(DAG.getNode(I)); +} + +MemDGNode * +MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval &Intvl, + const DependencyGraph &DAG) { + Instruction *I = Intvl.bottom(); + Instruction *AfterI = Intvl.top(); + // Walk up the chain looking for a mem-dep candidate instruction. + while (!DGNode::isMemDepNodeCandidate(I) && I != AfterI) + I = I->getPrevNode(); + if (!DGNode::isMemDepNodeCandidate(I)) + return nullptr; + return cast(DAG.getNode(I)); +} + Interval MemDGNodeIntervalBuilder::make(const Interval &Instrs, DependencyGraph &DAG) { - // If top or bottom instructions are not mem-dep candidate nodes we need to - // walk down/up the chain and find the mem-dep ones. - Instruction *MemTopI = Instrs.top(); - Instruction *MemBotI = Instrs.bottom(); - while (!DGNode::isMemDepNodeCandidate(MemTopI) && MemTopI != MemBotI) - MemTopI = MemTopI->getNextNode(); - while (!DGNode::isMemDepNodeCandidate(MemBotI) && MemBotI != MemTopI) - MemBotI = MemBotI->getPrevNode(); + auto *TopMemN = getTopMemDGNode(Instrs, DAG); // If we couldn't find a mem node in range TopN - BotN then it's empty. - if (!DGNode::isMemDepNodeCandidate(MemTopI)) + if (TopMemN == nullptr) return {}; + auto *BotMemN = getBotMemDGNode(Instrs, DAG); + assert(BotMemN != nullptr && "TopMemN should be null too!"); // Now that we have the mem-dep nodes, create and return the range. - return Interval(cast(DAG.getNode(MemTopI)), - cast(DAG.getNode(MemBotI))); + return Interval(TopMemN, BotMemN); } DependencyGraph::DependencyType diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index e2f16919a5cdd..b425e5a8ad214 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -305,6 +305,20 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { auto *S0N = cast(DAG.getNode(S0)); auto *S1N = cast(DAG.getNode(S1)); + // Check getTopMemDGNode(). + using B = sandboxir::MemDGNodeIntervalBuilder; + using InstrInterval = sandboxir::Interval; + EXPECT_EQ(B::getTopMemDGNode(InstrInterval(S0, S0), DAG), S0N); + EXPECT_EQ(B::getTopMemDGNode(InstrInterval(S0, Ret), DAG), S0N); + EXPECT_EQ(B::getTopMemDGNode(InstrInterval(Add0, Add1), DAG), S0N); + EXPECT_EQ(B::getTopMemDGNode(InstrInterval(Add0, Add0), DAG), nullptr); + + // Check getBotMemDGNode(). + EXPECT_EQ(B::getBotMemDGNode(InstrInterval(S1, S1), DAG), S1N); + EXPECT_EQ(B::getBotMemDGNode(InstrInterval(Add0, S1), DAG), S1N); + EXPECT_EQ(B::getBotMemDGNode(InstrInterval(Add0, Ret), DAG), S1N); + EXPECT_EQ(B::getBotMemDGNode(InstrInterval(Ret, Ret), DAG), nullptr); + // Check empty range. EXPECT_THAT(sandboxir::MemDGNodeIntervalBuilder::makeEmpty(), testing::ElementsAre());