diff --git a/llvm/include/llvm/SandboxIR/Region.h b/llvm/include/llvm/SandboxIR/Region.h index 67411f3fb741d..8133e01734ea7 100644 --- a/llvm/include/llvm/SandboxIR/Region.h +++ b/llvm/include/llvm/SandboxIR/Region.h @@ -63,6 +63,11 @@ class Region { Context &Ctx; + /// ID (for later deregistration) of the "create instruction" callback. + Context::CallbackID CreateInstCB; + /// ID (for later deregistration) of the "erase instruction" callback. + Context::CallbackID EraseInstCB; + // TODO: Add cost modeling. // TODO: Add a way to encode/decode region info to/from metadata. diff --git a/llvm/lib/SandboxIR/Region.cpp b/llvm/lib/SandboxIR/Region.cpp index b6292f3b24b87..1455012440f90 100644 --- a/llvm/lib/SandboxIR/Region.cpp +++ b/llvm/lib/SandboxIR/Region.cpp @@ -15,9 +15,17 @@ Region::Region(Context &Ctx) : Ctx(Ctx) { LLVMContext &LLVMCtx = Ctx.LLVMCtx; auto *RegionStrMD = MDString::get(LLVMCtx, RegionStr); RegionMDN = MDNode::getDistinct(LLVMCtx, {RegionStrMD}); + + CreateInstCB = Ctx.registerCreateInstrCallback( + [this](Instruction *NewInst) { add(NewInst); }); + EraseInstCB = Ctx.registerEraseInstrCallback( + [this](Instruction *ErasedInst) { remove(ErasedInst); }); } -Region::~Region() {} +Region::~Region() { + Ctx.unregisterCreateInstrCallback(CreateInstCB); + Ctx.unregisterEraseInstrCallback(EraseInstCB); +} void Region::add(Instruction *I) { Insts.insert(I); diff --git a/llvm/unittests/SandboxIR/RegionTest.cpp b/llvm/unittests/SandboxIR/RegionTest.cpp index a2efe551c8ff2..47368f93a32c0 100644 --- a/llvm/unittests/SandboxIR/RegionTest.cpp +++ b/llvm/unittests/SandboxIR/RegionTest.cpp @@ -81,6 +81,37 @@ define i8 @foo(i8 %v0, i8 %v1) { #endif } +TEST_F(RegionTest, CallbackUpdates) { + parseIR(C, R"IR( +define i8 @foo(i8 %v0, i8 %v1, ptr %ptr) { + %t0 = add i8 %v0, 1 + %t1 = add i8 %t0, %v1 + ret i8 %t0 +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *Ptr = F->getArg(2); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *T0 = cast(&*It++); + auto *T1 = cast(&*It++); + auto *Ret = cast(&*It++); + sandboxir::Region Rgn(Ctx); + Rgn.add(T0); + Rgn.add(T1); + + // Test creation. + auto *NewI = sandboxir::StoreInst::create(T0, Ptr, /*Align=*/std::nullopt, + Ret->getIterator(), Ctx); + EXPECT_THAT(Rgn.insts(), testing::ElementsAre(T0, T1, NewI)); + + // Test deletion. + T1->eraseFromParent(); + EXPECT_THAT(Rgn.insts(), testing::ElementsAre(T0, NewI)); +} + TEST_F(RegionTest, MetadataFromIR) { parseIR(C, R"IR( define i8 @foo(i8 %v0, i8 %v1) {