Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class DGNode {
assert(!isMemDepNodeCandidate(I) && "Expected Non-Mem instruction, ");
}
DGNode(const DGNode &Other) = delete;
virtual ~DGNode() = default;
virtual ~DGNode();
/// \Returns the number of unscheduled successors.
unsigned getNumUnscheduledSuccs() const { return UnscheduledSuccs; }
void decrUnscheduledSuccs() {
Expand Down Expand Up @@ -292,6 +292,7 @@ class DependencyGraph {

Context *Ctx = nullptr;
std::optional<Context::CallbackID> CreateInstrCB;
std::optional<Context::CallbackID> EraseInstrCB;

std::unique_ptr<BatchAAResults> BatchAA;

Expand Down Expand Up @@ -334,17 +335,27 @@ class DependencyGraph {
// TODO: Update the dependencies for the new node.
// TODO: Update the MemDGNode chain to include the new node if needed.
}
/// Called by the callbacks when instruction \p I is about to get deleted.
void notifyEraseInstr(Instruction *I) {
InstrToNodeMap.erase(I);
// TODO: Update the dependencies.
// TODO: Update the MemDGNode chain to remove the node if needed.
}

public:
/// This constructor also registers callbacks.
DependencyGraph(AAResults &AA, Context &Ctx)
: Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
CreateInstrCB = Ctx.registerCreateInstrCallback(
[this](Instruction *I) { notifyCreateInstr(I); });
EraseInstrCB = Ctx.registerEraseInstrCallback(
[this](Instruction *I) { notifyEraseInstr(I); });
}
~DependencyGraph() {
if (CreateInstrCB)
Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
if (EraseInstrCB)
Ctx->unregisterEraseInstrCallback(*EraseInstrCB);
}

DGNode *getNode(Instruction *I) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ class SchedBundle {
private:
ContainerTy Nodes;

/// Called by the DGNode destructor to avoid accessing freed memory.
void eraseFromBundle(DGNode *N) { Nodes.erase(find(Nodes, N)); }
friend DGNode::~DGNode(); // For eraseFromBundle().

public:
SchedBundle() = default;
SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"

namespace llvm::sandboxir {

Expand Down Expand Up @@ -58,6 +59,12 @@ bool PredIterator::operator==(const PredIterator &Other) const {
return OpIt == Other.OpIt && MemIt == Other.MemIt;
}

DGNode::~DGNode() {
if (SB == nullptr)
return;
SB->eraseFromBundle(this);
}

#ifndef NDEBUG
void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -830,3 +830,31 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
// TODO: Check the dependencies to/from NewSN after they land.
// TODO: Check the MemDGNode chain.
}

TEST_F(DependencyGraphTest, EraseInstrCallback) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
store i8 %v1, ptr %ptr
store i8 %v2, ptr %ptr
store i8 %v3, ptr %ptr
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *S2 = cast<sandboxir::StoreInst>(&*It++);
auto *S3 = cast<sandboxir::StoreInst>(&*It++);

// Check erase instruction callback.
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({S1, S3});
S2->eraseFromParent();
auto *DeletedN = DAG.getNodeOrNull(S2);
EXPECT_TRUE(DeletedN == nullptr);
// TODO: Check the dependencies to/from NewSN after they land.
// TODO: Check the MemDGNode chain.
}
Loading