diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 911ee3e839521..b1cad2421bc0d 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -342,11 +342,7 @@ class DependencyGraph { void notifyCreateInstr(Instruction *I); /// Called by the callbacks when instruction \p I is about to get /// deleted. - void notifyEraseInstr(Instruction *I) { - InstrToNodeMap.erase(I); - // TODO: Update the dependencies. - // TODO: Update the MemDGNode chain to remove the node if needed. - } + void notifyEraseInstr(Instruction *I); public: /// This constructor also registers callbacks. diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index 5cf44ba9dcbaa..25f2665d450d1 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -370,6 +370,22 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) { } } +void DependencyGraph::notifyEraseInstr(Instruction *I) { + // Update the MemDGNode chain if this is a memory node. + if (auto *MemN = dyn_cast_or_null(getNodeOrNull(I))) { + auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false); + auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false); + if (PrevMemN != nullptr) + PrevMemN->NextMemN = NextMemN; + if (NextMemN != nullptr) + NextMemN->PrevMemN = PrevMemN; + } + + InstrToNodeMap.erase(I); + + // TODO: Update the dependencies. +} + Interval DependencyGraph::extend(ArrayRef Instrs) { if (Instrs.empty()) return {}; diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index 1130c9c63c71d..8c73ee1def8ae 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -880,6 +880,16 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { S2->eraseFromParent(); auto *DeletedN = DAG.getNodeOrNull(S2); EXPECT_TRUE(DeletedN == nullptr); + + // Check the MemDGNode chain. + auto *S1MemN = cast(DAG.getNode(S1)); + auto *S3MemN = cast(DAG.getNode(S3)); + EXPECT_EQ(S1MemN->getNextNode(), S3MemN); + EXPECT_EQ(S3MemN->getPrevNode(), S1MemN); + + // Check the chain when we erase the top node. + S1->eraseFromParent(); + EXPECT_EQ(S3MemN->getPrevNode(), nullptr); + // TODO: Check the dependencies to/from NewSN after they land. - // TODO: Check the MemDGNode chain. }