diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 765b65c4971be..68a2daca1403d 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -117,7 +117,7 @@ class DGNode { assert(!isMemDepNodeCandidate(I) && "Expected Non-Mem instruction, "); } DGNode(const DGNode &Other) = delete; - virtual ~DGNode() = default; + virtual ~DGNode(); /// \Returns the number of unscheduled successors. unsigned getNumUnscheduledSuccs() const { return UnscheduledSuccs; } void decrUnscheduledSuccs() { @@ -292,6 +292,7 @@ class DependencyGraph { Context *Ctx = nullptr; std::optional CreateInstrCB; + std::optional EraseInstrCB; std::unique_ptr BatchAA; @@ -334,6 +335,12 @@ class DependencyGraph { // TODO: Update the dependencies for the new node. // TODO: Update the MemDGNode chain to include the new node if needed. } + /// Called by the callbacks when instruction \p I is about to get deleted. + void notifyEraseInstr(Instruction *I) { + InstrToNodeMap.erase(I); + // TODO: Update the dependencies. + // TODO: Update the MemDGNode chain to remove the node if needed. + } public: /// This constructor also registers callbacks. @@ -341,10 +348,14 @@ class DependencyGraph { : Ctx(&Ctx), BatchAA(std::make_unique(AA)) { CreateInstrCB = Ctx.registerCreateInstrCallback( [this](Instruction *I) { notifyCreateInstr(I); }); + EraseInstrCB = Ctx.registerEraseInstrCallback( + [this](Instruction *I) { notifyEraseInstr(I); }); } ~DependencyGraph() { if (CreateInstrCB) Ctx->unregisterCreateInstrCallback(*CreateInstrCB); + if (EraseInstrCB) + Ctx->unregisterEraseInstrCallback(*EraseInstrCB); } DGNode *getNode(Instruction *I) const { diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h index 022fd71df67dc..3959f84c601e0 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h @@ -69,6 +69,10 @@ class SchedBundle { private: ContainerTy Nodes; + /// Called by the DGNode destructor to avoid accessing freed memory. + void eraseFromBundle(DGNode *N) { Nodes.erase(find(Nodes, N)); } + friend DGNode::~DGNode(); // For eraseFromBundle(). + public: SchedBundle() = default; SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) { diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index 6217c9fecf45d..4b0e12c28f07b 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -10,6 +10,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/SandboxIR/Instruction.h" #include "llvm/SandboxIR/Utils.h" +#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h" namespace llvm::sandboxir { @@ -58,6 +59,12 @@ bool PredIterator::operator==(const PredIterator &Other) const { return OpIt == Other.OpIt && MemIt == Other.MemIt; } +DGNode::~DGNode() { + if (SB == nullptr) + return; + SB->eraseFromBundle(this); +} + #ifndef NDEBUG void DGNode::print(raw_ostream &OS, bool PrintDeps) const { OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n"; diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index 206f6c5b4c135..e6bb4b4684d26 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -830,3 +830,31 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { // TODO: Check the dependencies to/from NewSN after they land. // TODO: Check the MemDGNode chain. } + +TEST_F(DependencyGraphTest, EraseInstrCallback) { + parseIR(C, R"IR( +define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { + store i8 %v1, ptr %ptr + store i8 %v2, ptr %ptr + store i8 %v3, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *S1 = cast(&*It++); + auto *S2 = cast(&*It++); + auto *S3 = cast(&*It++); + + // Check erase instruction callback. + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); + DAG.extend({S1, S3}); + S2->eraseFromParent(); + auto *DeletedN = DAG.getNodeOrNull(S2); + EXPECT_TRUE(DeletedN == nullptr); + // TODO: Check the dependencies to/from NewSN after they land. + // TODO: Check the MemDGNode chain. +}