diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 5fa57efc1462e..0da52c4236d77 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -284,6 +284,10 @@ class DependencyGraph { /// \p DstN. void scanAndAddDeps(MemDGNode &DstN, const Interval &SrcScanRange); + /// Create DAG nodes for instrs in \p NewInterval and update the MemNode + /// chain. + void createNewNodes(const Interval &NewInterval); + public: DependencyGraph(AAResults &AA) : BatchAA(std::make_unique(AA)) {} @@ -309,8 +313,10 @@ class DependencyGraph { return It->second.get(); } /// Build/extend the dependency graph such that it includes \p Instrs. Returns - /// the interval spanning \p Instrs. + /// the range of instructions added to the DAG. Interval extend(ArrayRef Instrs); + /// \Returns the range of instructions included in the DAG. + Interval getInterval() const { return DAGInterval; } #ifndef NDEBUG void print(raw_ostream &OS) const; LLVM_DUMP_METHOD void dump() const; diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index 0cd2240e7ff1b..db58069de4705 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -215,17 +215,11 @@ void DependencyGraph::scanAndAddDeps(MemDGNode &DstN, } } -Interval DependencyGraph::extend(ArrayRef Instrs) { - if (Instrs.empty()) - return {}; - - Interval InstrInterval(Instrs); - - DGNode *LastN = getOrCreateNode(InstrInterval.top()); - // Create DGNodes for all instrs in Interval to avoid future Instruction to - // DGNode lookups. +void DependencyGraph::createNewNodes(const Interval &NewInterval) { + // Create Nodes only for the new sections of the DAG. + DGNode *LastN = getOrCreateNode(NewInterval.top()); MemDGNode *LastMemN = dyn_cast(LastN); - for (Instruction &I : drop_begin(InstrInterval)) { + for (Instruction &I : drop_begin(NewInterval)) { auto *N = getOrCreateNode(&I); // Build the Mem node chain. if (auto *MemN = dyn_cast(N)) { @@ -235,16 +229,109 @@ Interval DependencyGraph::extend(ArrayRef Instrs) { LastMemN = MemN; } } + // Link new MemDGNode chain with the old one, if any. + if (!DAGInterval.empty()) { + // TODO: Implement Interval::comesBefore() to replace this check. + bool NewIsAbove = NewInterval.bottom()->comesBefore(DAGInterval.top()); + assert( + (NewIsAbove || DAGInterval.bottom()->comesBefore(NewInterval.top())) && + "Expected NewInterval below DAGInterval."); + const auto &TopInterval = NewIsAbove ? NewInterval : DAGInterval; + const auto &BotInterval = NewIsAbove ? DAGInterval : NewInterval; + MemDGNode *LinkTopN = + MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this); + MemDGNode *LinkBotN = + MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this); + assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!"); + if (LinkTopN != nullptr && LinkBotN != nullptr) { + LinkTopN->setNextNode(LinkBotN); + LinkBotN->setPrevNode(LinkTopN); + } +#ifndef NDEBUG + // TODO: Remove this once we've done enough testing. + // Check that the chain is well formed. + auto UnionIntvl = DAGInterval.getUnionInterval(NewInterval); + MemDGNode *ChainTopN = + MemDGNodeIntervalBuilder::getTopMemDGNode(UnionIntvl, *this); + MemDGNode *ChainBotN = + MemDGNodeIntervalBuilder::getBotMemDGNode(UnionIntvl, *this); + if (ChainTopN != nullptr && ChainBotN != nullptr) { + for (auto *N = ChainTopN->getNextNode(), *LastN = ChainTopN; N != nullptr; + LastN = N, N = N->getNextNode()) { + assert(N == LastN->getNextNode() && "Bad chain!"); + assert(N->getPrevNode() == LastN && "Bad chain!"); + } + } +#endif // NDEBUG + } +} + +Interval DependencyGraph::extend(ArrayRef Instrs) { + if (Instrs.empty()) + return {}; + + Interval InstrsInterval(Instrs); + Interval Union = DAGInterval.getUnionInterval(InstrsInterval); + auto NewInterval = Union.getSingleDiff(DAGInterval); + if (NewInterval.empty()) + return {}; + + createNewNodes(NewInterval); + // Create the dependencies. - auto DstRange = MemDGNodeIntervalBuilder::make(InstrInterval, *this); - if (!DstRange.empty()) { - for (MemDGNode &DstN : drop_begin(DstRange)) { - auto SrcRange = Interval(DstRange.top(), DstN.getPrevNode()); + // + // 1. DAGInterval empty 2. New is below Old 3. New is above old + // ------------------------ ------------------- ------------------- + // Scan: DstN: Scan: + // +---+ -ScanTopN +---+DstTopN -ScanTopN + // | | | |New| | + // |Old| | +---+ -ScanBotN + // | | | +---+ + // DstN: Scan: +---+DstN: | | | + // +---+DstTopN -ScanTopN +---+DstTopN | |Old| + // |New| | |New| | | | + // +---+DstBotN -ScanBotN +---+DstBotN -ScanBotN +---+DstBotN + + // 1. This is a new DAG. + if (DAGInterval.empty()) { + assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!"); + auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); + if (!DstRange.empty()) { + for (MemDGNode &DstN : drop_begin(DstRange)) { + auto SrcRange = Interval(DstRange.top(), DstN.getPrevNode()); + scanAndAddDeps(DstN, SrcRange); + } + } + } + // 2. The new section is below the old section. + else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) { + auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this); + auto SrcRangeFull = MemDGNodeIntervalBuilder::make( + DAGInterval.getUnionInterval(NewInterval), *this); + for (MemDGNode &DstN : DstRange) { + auto SrcRange = + Interval(SrcRangeFull.top(), DstN.getPrevNode()); scanAndAddDeps(DstN, SrcRange); } } + // 3. The new section is above the old section. + else if (NewInterval.bottom()->comesBefore(DAGInterval.top())) { + auto DstRange = MemDGNodeIntervalBuilder::make( + NewInterval.getUnionInterval(DAGInterval), *this); + auto SrcRangeFull = MemDGNodeIntervalBuilder::make(NewInterval, *this); + if (!DstRange.empty()) { + for (MemDGNode &DstN : drop_begin(DstRange)) { + auto SrcRange = + Interval(SrcRangeFull.top(), DstN.getPrevNode()); + scanAndAddDeps(DstN, SrcRange); + } + } + } else { + llvm_unreachable("We don't expect extending in both directions!"); + } - return InstrInterval; + DAGInterval = Union; + return NewInterval; } #ifndef NDEBUG diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index 7e2be25fa25ae..3dbf03e4ba44e 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -681,3 +681,70 @@ define void @foo() { EXPECT_FALSE(memDependency(StackSaveN, AllocaN)); EXPECT_FALSE(memDependency(AllocaN, StackRestoreN)); } + +TEST_F(DependencyGraphTest, Extend) { + parseIR(C, R"IR( +define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) { + store i8 %v1, ptr %ptr + store i8 %v2, ptr %ptr + store i8 %v3, ptr %ptr + store i8 %v4, ptr %ptr + store i8 %v5, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *S1 = cast(&*It++); + auto *S2 = cast(&*It++); + auto *S3 = cast(&*It++); + auto *S4 = cast(&*It++); + auto *S5 = cast(&*It++); + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + { + // Scenario 1: Build new DAG + auto NewIntvl = DAG.extend({S3, S3}); + EXPECT_EQ(NewIntvl, sandboxir::Interval(S3, S3)); + EXPECT_EQ(DAG.getInterval().top(), S3); + EXPECT_EQ(DAG.getInterval().bottom(), S3); + [[maybe_unused]] auto *S3N = cast(DAG.getNode(S3)); + } + { + // Scenario 2: Extend below + auto NewIntvl = DAG.extend({S5, S5}); + EXPECT_EQ(NewIntvl, sandboxir::Interval(S4, S5)); + auto *S3N = cast(DAG.getNode(S3)); + auto *S4N = cast(DAG.getNode(S4)); + auto *S5N = cast(DAG.getNode(S5)); + EXPECT_TRUE(S4N->hasMemPred(S3N)); + EXPECT_TRUE(S5N->hasMemPred(S4N)); + EXPECT_TRUE(S5N->hasMemPred(S3N)); + } + { + // Scenario 3: Extend above + auto NewIntvl = DAG.extend({S1, S2}); + EXPECT_EQ(NewIntvl, sandboxir::Interval(S1, S2)); + auto *S1N = cast(DAG.getNode(S1)); + auto *S2N = cast(DAG.getNode(S2)); + auto *S3N = cast(DAG.getNode(S3)); + auto *S4N = cast(DAG.getNode(S4)); + auto *S5N = cast(DAG.getNode(S5)); + + EXPECT_TRUE(S2N->hasMemPred(S1N)); + + EXPECT_TRUE(S3N->hasMemPred(S2N)); + EXPECT_TRUE(S3N->hasMemPred(S1N)); + + EXPECT_TRUE(S4N->hasMemPred(S3N)); + EXPECT_TRUE(S4N->hasMemPred(S2N)); + EXPECT_TRUE(S4N->hasMemPred(S1N)); + + EXPECT_TRUE(S5N->hasMemPred(S4N)); + EXPECT_TRUE(S5N->hasMemPred(S3N)); + EXPECT_TRUE(S5N->hasMemPred(S2N)); + EXPECT_TRUE(S5N->hasMemPred(S1N)); + } +}