From 56ffef4a97b888e1c53153f71301befdc3cfd24d Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Fri, 18 Oct 2024 12:00:06 -0700 Subject: [PATCH 1/9] [SandboxIR] Add callbacks for instruction insert/remove/move ops. --- llvm/include/llvm/SandboxIR/Context.h | 56 +++++++++++++- llvm/lib/SandboxIR/Context.cpp | 66 +++++++++++++++-- llvm/lib/SandboxIR/Instruction.cpp | 5 ++ llvm/unittests/SandboxIR/SandboxIRTest.cpp | 85 ++++++++++++++++++++++ 4 files changed, 205 insertions(+), 7 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index 1285598a1c028..836988639a14b 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -9,18 +9,31 @@ #ifndef LLVM_SANDBOXIR_CONTEXT_H #define LLVM_SANDBOXIR_CONTEXT_H +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/IR/LLVMContext.h" #include "llvm/SandboxIR/Tracker.h" #include "llvm/SandboxIR/Type.h" namespace llvm::sandboxir { -class Module; -class Value; class Argument; +class BBIterator; class Constant; +class Module; +class Value; class Context { +public: + // A RemoveInstrCallback receives the instruction about to be removed. + using RemoveInstrCallback = std::function; + // A InsertInstrCallback receives the instruction about to be created. + using InsertInstrCallback = std::function; + // A MoveInstrCallback receives the instruction about to be moved, the + // destination BB and an iterator pointing to the insertion position. + using MoveInstrCallback = + std::function; + protected: LLVMContext &LLVMCtx; friend class Type; // For LLVMCtx. @@ -48,6 +61,21 @@ class Context { /// Type objects. DenseMap> LLVMTypeToTypeMap; + /// Callbacks called when an IR instruction is about to get removed. Keys are + /// used as IDs for deregistration. + DenseMap RemoveInstrCallbacks; + /// Callbacks called when an IR instruction is about to get inserted. Keys are + /// used as IDs for deregistration. + DenseMap InsertInstrCallbacks; + /// Callbacks called when an IR instruction is about to get moved. Keys are + /// used as IDs for deregistration. + DenseMap MoveInstrCallbacks; + + /// A counter used for assigning callback IDs during registration. The same + /// counter is used for all kinds of callbacks so we can detect mismatched + /// registration/deregistration. + static int NextCallbackId; + /// Remove \p V from the maps and returns the unique_ptr. std::unique_ptr detachLLVMValue(llvm::Value *V); /// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively @@ -70,6 +98,10 @@ class Context { Constant *getOrCreateConstant(llvm::Constant *LLVMC); friend class Utils; // For getMemoryBase + void runRemoveInstrCallbacks(Instruction *I); + void runInsertInstrCallbacks(Instruction *I); + void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where); + // Friends for getOrCreateConstant(). #define DEF_CONST(ID, CLASS) friend class CLASS; #include "llvm/SandboxIR/Values.def" @@ -198,6 +230,26 @@ class Context { /// \Returns the number of values registered with Context. size_t getNumValues() const { return LLVMValueToValueMap.size(); } + + /// Register a callback that gets called when a SandboxIR instruction is about + /// to be removed from its parent. Note that this will also be called when + /// reverting the creation of an instruction. + /// \Returns a callback ID for later deregistration. + int registerRemoveInstrCallback(RemoveInstrCallback CB); + void unregisterRemoveInstrCallback(int CallbackId); + + /// Register a callback that gets called right after a SandboxIR instruction + /// is created. Note that this will also be called when reverting the removal + /// of an instruction. + /// \Returns a callback ID for later deregistration. + int registerInsertInstrCallback(InsertInstrCallback CB); + void unregisterInsertInstrCallback(int CallbackId); + + /// Register a callback that gets called when a SandboxIR instruction is about + /// to be moved. Note that this will also be called when reverting a move. + /// \Returns a callback ID for later deregistration. + int registerMoveInstrCallback(MoveInstrCallback CB); + void unregisterMoveInstrCallback(int CallbackId); }; } // namespace llvm::sandboxir diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index 486e935bc35fb..e13f833f1ba29 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -35,17 +35,20 @@ Value *Context::registerValue(std::unique_ptr &&VPtr) { assert(VPtr->getSubclassID() != Value::ClassID::User && "Can't register a user!"); + Value *V = VPtr.get(); + [[maybe_unused]] auto Pair = + LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)}); + assert(Pair.second && "Already exists!"); + // Track creation of instructions. // Please note that we don't allow the creation of detached instructions, // meaning that the instructions need to be inserted into a block upon // creation. This is why the tracker class combines creation and insertion. - if (auto *I = dyn_cast(VPtr.get())) + if (auto *I = dyn_cast(V)) { getTracker().emplaceIfTracking(I); + runInsertInstrCallbacks(I); + } - Value *V = VPtr.get(); - [[maybe_unused]] auto Pair = - LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)}); - assert(Pair.second && "Already exists!"); return V; } @@ -660,4 +663,57 @@ Module *Context::createModule(llvm::Module *LLVMM) { return M; } +void Context::runRemoveInstrCallbacks(Instruction *I) { + for (const auto &CBEntry : RemoveInstrCallbacks) { + CBEntry.second(I); + } +} + +void Context::runInsertInstrCallbacks(Instruction *I) { + for (auto &CBEntry : InsertInstrCallbacks) { + CBEntry.second(I); + } +} + +void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { + for (auto &CBEntry : MoveInstrCallbacks) { + CBEntry.second(I, WhereIt); + } +} + +int Context::NextCallbackId = 0; + +int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) { + int Id = NextCallbackId++; + RemoveInstrCallbacks[Id] = CB; + return Id; +} +void Context::unregisterRemoveInstrCallback(int CallbackId) { + [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(CallbackId); + assert(erased && + "Callback id not found in RemoveInstrCallbacks during deregistration"); +} + +int Context::registerInsertInstrCallback(InsertInstrCallback CB) { + int Id = NextCallbackId++; + InsertInstrCallbacks[Id] = CB; + return Id; +} +void Context::unregisterInsertInstrCallback(int CallbackId) { + [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(CallbackId); + assert(erased && + "Callback id not found in InsertInstrCallbacks during deregistration"); +} + +int Context::registerMoveInstrCallback(MoveInstrCallback CB) { + int Id = NextCallbackId++; + MoveInstrCallbacks[Id] = CB; + return Id; +} +void Context::unregisterMoveInstrCallback(int CallbackId) { + [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(CallbackId); + assert(erased && + "Callback id not found in MoveInstrCallbacks during deregistration"); +} + } // namespace llvm::sandboxir diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp index d80d10370e32d..ddeb78eea19f7 100644 --- a/llvm/lib/SandboxIR/Instruction.cpp +++ b/llvm/lib/SandboxIR/Instruction.cpp @@ -64,6 +64,8 @@ Instruction *Instruction::getPrevNode() const { } void Instruction::removeFromParent() { + Ctx.runRemoveInstrCallbacks(this); + Ctx.getTracker().emplaceIfTracking(this); // Detach all the LLVM IR instructions from their parent BB. @@ -73,6 +75,8 @@ void Instruction::removeFromParent() { void Instruction::eraseFromParent() { assert(users().empty() && "Still connected to users, can't erase!"); + + Ctx.runRemoveInstrCallbacks(this); std::unique_ptr Detached = Ctx.detach(this); auto LLVMInstrs = getLLVMInstrs(); @@ -100,6 +104,7 @@ void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) { // Destination is same as origin, nothing to do. return; + Ctx.runMoveInstrCallbacks(this, WhereIt); Ctx.getTracker().emplaceIfTracking(this); auto *LLVMBB = cast(BB.Val); diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 97113b303f72e..786580d1046a6 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -22,6 +22,7 @@ #include "llvm/SandboxIR/Value.h" #include "llvm/Support/SourceMgr.h" #include "gmock/gmock-matchers.h" +#include "gmock/gmock-more-matchers.h" #include "gtest/gtest.h" using namespace llvm; @@ -5962,3 +5963,87 @@ TEST_F(SandboxIRTest, CheckClassof) { EXPECT_NE(&sandboxir::CLASS::classof, &sandboxir::Instruction::classof); #include "llvm/SandboxIR/Values.def" } + +TEST_F(SandboxIRTest, InstructionCallbacks) { + parseIR(C, R"IR( + define void @foo(ptr %ptr, i8 %val) { + ret void + } + )IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + + auto &F = *Ctx.createFunction(&LLVMF); + auto &BB = *F.begin(); + sandboxir::Argument *Ptr = F.getArg(0); + sandboxir::Argument *Val = F.getArg(1); + sandboxir::Instruction *Ret = &BB.front(); + + SmallVector Inserted; + int InsertCbId = Ctx.registerInsertInstrCallback( + [&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); }); + + SmallVector Removed; + int RemoveCbId = Ctx.registerRemoveInstrCallback( + [&Removed](sandboxir::Instruction *I) { Removed.push_back(I); }); + + // Keep the moved instruction and the instruction pointed by the Where + // iterator so we can check both callback arguments work as expected. + SmallVector> + Moved; + int MoveCbId = Ctx.registerMoveInstrCallback( + [&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) { + // Use a nullptr to signal "move to end" to keep it single. We only + // have a basic block in this test case anyway. + if (Where == Where.getNodeParent()->end()) + Moved.push_back(std::make_pair(I, nullptr)); + else + Moved.push_back(std::make_pair(I, &*Where)); + }); + + Ctx.save(); + auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt, + Ret->getIterator(), Ctx); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); + EXPECT_THAT(Removed, testing::IsEmpty()); + EXPECT_THAT(Moved, testing::IsEmpty()); + + Ret->moveBefore(NewI); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); + EXPECT_THAT(Removed, testing::IsEmpty()); + EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI))); + + Ret->eraseFromParent(); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); + EXPECT_THAT(Removed, testing::ElementsAre(Ret)); + EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI))); + + NewI->eraseFromParent(); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); + EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI)); + EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI))); + + // Check that after revert the callbacks have been called for the inverse + // operations of the changes made so far. + Ctx.revert(); + EXPECT_THAT(Inserted, testing::ElementsAre(NewI, NewI, Ret)); + EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI)); + EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI), + std::make_pair(Ret, nullptr))); + + // Check that deregistration works. Do an operation of each type after + // deregistering callbacks and check. + Inserted.clear(); + Removed.clear(); + Moved.clear(); + Ctx.unregisterInsertInstrCallback(InsertCbId); + Ctx.unregisterRemoveInstrCallback(RemoveCbId); + Ctx.unregisterMoveInstrCallback(MoveCbId); + auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt, + Ret->getIterator(), Ctx); + Ret->moveBefore(NewI2); + Ret->eraseFromParent(); + EXPECT_THAT(Inserted, testing::IsEmpty()); + EXPECT_THAT(Removed, testing::IsEmpty()); + EXPECT_THAT(Moved, testing::IsEmpty()); +} From f361d5cc0cde6c4ddfc85bc719025cf84b0818bd Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Fri, 18 Oct 2024 13:44:21 -0700 Subject: [PATCH 2/9] Address some review feedback. - Introduced `CallbackID` typedef for callback ids rather than a plain int. - Remove unnecessary braces. --- llvm/include/llvm/SandboxIR/Context.h | 11 +++-- llvm/lib/SandboxIR/Context.cpp | 47 ++++++++++------------ llvm/unittests/SandboxIR/SandboxIRTest.cpp | 8 ++-- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index 836988639a14b..2da80481f9b6c 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -34,6 +34,9 @@ class Context { using MoveInstrCallback = std::function; + /// An ID for a registered callback. Used for deregistration. + using CallbackID = int; + protected: LLVMContext &LLVMCtx; friend class Type; // For LLVMCtx. @@ -63,18 +66,18 @@ class Context { /// Callbacks called when an IR instruction is about to get removed. Keys are /// used as IDs for deregistration. - DenseMap RemoveInstrCallbacks; + DenseMap RemoveInstrCallbacks; /// Callbacks called when an IR instruction is about to get inserted. Keys are /// used as IDs for deregistration. - DenseMap InsertInstrCallbacks; + DenseMap InsertInstrCallbacks; /// Callbacks called when an IR instruction is about to get moved. Keys are /// used as IDs for deregistration. - DenseMap MoveInstrCallbacks; + DenseMap MoveInstrCallbacks; /// A counter used for assigning callback IDs during registration. The same /// counter is used for all kinds of callbacks so we can detect mismatched /// registration/deregistration. - static int NextCallbackId; + static CallbackID NextCallbackID; /// Remove \p V from the maps and returns the unique_ptr. std::unique_ptr detachLLVMValue(llvm::Value *V); diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index e13f833f1ba29..f22945a076aa2 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -664,56 +664,53 @@ Module *Context::createModule(llvm::Module *LLVMM) { } void Context::runRemoveInstrCallbacks(Instruction *I) { - for (const auto &CBEntry : RemoveInstrCallbacks) { + for (const auto &CBEntry : RemoveInstrCallbacks) CBEntry.second(I); - } } void Context::runInsertInstrCallbacks(Instruction *I) { - for (auto &CBEntry : InsertInstrCallbacks) { + for (auto &CBEntry : InsertInstrCallbacks) CBEntry.second(I); - } } void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { - for (auto &CBEntry : MoveInstrCallbacks) { + for (auto &CBEntry : MoveInstrCallbacks) CBEntry.second(I, WhereIt); - } } -int Context::NextCallbackId = 0; +Context::CallbackID Context::NextCallbackID = 0; int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) { - int Id = NextCallbackId++; - RemoveInstrCallbacks[Id] = CB; - return Id; + CallbackID ID = NextCallbackID++; + RemoveInstrCallbacks[ID] = CB; + return ID; } -void Context::unregisterRemoveInstrCallback(int CallbackId) { - [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(CallbackId); +void Context::unregisterRemoveInstrCallback(CallbackID ID) { + [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(ID); assert(erased && - "Callback id not found in RemoveInstrCallbacks during deregistration"); + "Callback ID not found in RemoveInstrCallbacks during deregistration"); } int Context::registerInsertInstrCallback(InsertInstrCallback CB) { - int Id = NextCallbackId++; - InsertInstrCallbacks[Id] = CB; - return Id; + CallbackID ID = NextCallbackID++; + InsertInstrCallbacks[ID] = CB; + return ID; } -void Context::unregisterInsertInstrCallback(int CallbackId) { - [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(CallbackId); +void Context::unregisterInsertInstrCallback(CallbackID ID) { + [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(ID); assert(erased && - "Callback id not found in InsertInstrCallbacks during deregistration"); + "Callback ID not found in InsertInstrCallbacks during deregistration"); } int Context::registerMoveInstrCallback(MoveInstrCallback CB) { - int Id = NextCallbackId++; - MoveInstrCallbacks[Id] = CB; - return Id; + CallbackID ID = NextCallbackID++; + MoveInstrCallbacks[ID] = CB; + return ID; } -void Context::unregisterMoveInstrCallback(int CallbackId) { - [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(CallbackId); +void Context::unregisterMoveInstrCallback(CallbackID ID) { + [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(ID); assert(erased && - "Callback id not found in MoveInstrCallbacks during deregistration"); + "Callback ID not found in MoveInstrCallbacks during deregistration"); } } // namespace llvm::sandboxir diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 786580d1046a6..268c6e0712c50 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -5980,18 +5980,18 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { sandboxir::Instruction *Ret = &BB.front(); SmallVector Inserted; - int InsertCbId = Ctx.registerInsertInstrCallback( + auto InsertCbId = Ctx.registerInsertInstrCallback( [&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); }); SmallVector Removed; - int RemoveCbId = Ctx.registerRemoveInstrCallback( + auto RemoveCbId = Ctx.registerRemoveInstrCallback( [&Removed](sandboxir::Instruction *I) { Removed.push_back(I); }); // Keep the moved instruction and the instruction pointed by the Where // iterator so we can check both callback arguments work as expected. SmallVector> Moved; - int MoveCbId = Ctx.registerMoveInstrCallback( + auto MoveCbId = Ctx.registerMoveInstrCallback( [&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) { // Use a nullptr to signal "move to end" to keep it single. We only // have a basic block in this test case anyway. @@ -6040,7 +6040,7 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { Ctx.unregisterRemoveInstrCallback(RemoveCbId); Ctx.unregisterMoveInstrCallback(MoveCbId); auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt, - Ret->getIterator(), Ctx); + Ret->getIterator(), Ctx); Ret->moveBefore(NewI2); Ret->eraseFromParent(); EXPECT_THAT(Inserted, testing::IsEmpty()); From e84e9e6a3f889f8bbbe1a8d7074e9213ae9904c3 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Fri, 18 Oct 2024 14:10:25 -0700 Subject: [PATCH 3/9] Another round of feedback. - Updated callback (de)registration method signatures in header to use the CallbackID type instead of int. - Corrected case in some variable names to follow the style guide. --- llvm/include/llvm/SandboxIR/Context.h | 12 ++++++------ llvm/lib/SandboxIR/Context.cpp | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index 2da80481f9b6c..cc3572db6039d 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -238,21 +238,21 @@ class Context { /// to be removed from its parent. Note that this will also be called when /// reverting the creation of an instruction. /// \Returns a callback ID for later deregistration. - int registerRemoveInstrCallback(RemoveInstrCallback CB); - void unregisterRemoveInstrCallback(int CallbackId); + CallbackID registerRemoveInstrCallback(RemoveInstrCallback CB); + void unregisterRemoveInstrCallback(CallbackID ID); /// Register a callback that gets called right after a SandboxIR instruction /// is created. Note that this will also be called when reverting the removal /// of an instruction. /// \Returns a callback ID for later deregistration. - int registerInsertInstrCallback(InsertInstrCallback CB); - void unregisterInsertInstrCallback(int CallbackId); + CallbackID registerInsertInstrCallback(InsertInstrCallback CB); + void unregisterInsertInstrCallback(CallbackID ID); /// Register a callback that gets called when a SandboxIR instruction is about /// to be moved. Note that this will also be called when reverting a move. /// \Returns a callback ID for later deregistration. - int registerMoveInstrCallback(MoveInstrCallback CB); - void unregisterMoveInstrCallback(int CallbackId); + CallbackID registerMoveInstrCallback(MoveInstrCallback CB); + void unregisterMoveInstrCallback(CallbackID ID); }; } // namespace llvm::sandboxir diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index f22945a076aa2..66ec08757ca31 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -686,8 +686,8 @@ int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) { return ID; } void Context::unregisterRemoveInstrCallback(CallbackID ID) { - [[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(ID); - assert(erased && + [[maybe_unused]] bool Erased = RemoveInstrCallbacks.erase(ID); + assert(Erased && "Callback ID not found in RemoveInstrCallbacks during deregistration"); } @@ -697,8 +697,8 @@ int Context::registerInsertInstrCallback(InsertInstrCallback CB) { return ID; } void Context::unregisterInsertInstrCallback(CallbackID ID) { - [[maybe_unused]] bool erased = InsertInstrCallbacks.erase(ID); - assert(erased && + [[maybe_unused]] bool Erased = InsertInstrCallbacks.erase(ID); + assert(Erased && "Callback ID not found in InsertInstrCallbacks during deregistration"); } @@ -708,8 +708,8 @@ int Context::registerMoveInstrCallback(MoveInstrCallback CB) { return ID; } void Context::unregisterMoveInstrCallback(CallbackID ID) { - [[maybe_unused]] bool erased = MoveInstrCallbacks.erase(ID); - assert(erased && + [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID); + assert(Erased && "Callback ID not found in MoveInstrCallbacks during deregistration"); } From fa875d7995cbcfb850e84f3a7dbab0a9e071ae41 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Fri, 18 Oct 2024 14:29:40 -0700 Subject: [PATCH 4/9] Make NextCallbackID not static --- llvm/include/llvm/SandboxIR/Context.h | 2 +- llvm/lib/SandboxIR/Context.cpp | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index cc3572db6039d..f1ee6b6de222f 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -77,7 +77,7 @@ class Context { /// A counter used for assigning callback IDs during registration. The same /// counter is used for all kinds of callbacks so we can detect mismatched /// registration/deregistration. - static CallbackID NextCallbackID; + CallbackID NextCallbackID = 0; /// Remove \p V from the maps and returns the unique_ptr. std::unique_ptr detachLLVMValue(llvm::Value *V); diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index 66ec08757ca31..213ad7f5c6d8a 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -678,8 +678,6 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { CBEntry.second(I, WhereIt); } -Context::CallbackID Context::NextCallbackID = 0; - int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) { CallbackID ID = NextCallbackID++; RemoveInstrCallbacks[ID] = CB; From 906854107b4dfda37ed95db066fab308a8a2c6e8 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Fri, 18 Oct 2024 18:07:23 -0700 Subject: [PATCH 5/9] Switched callbacks to MapVector to iterate in registration order. Added test to check that registration order is respected when invoking registered callbacks. --- llvm/include/llvm/SandboxIR/Context.h | 7 ++++--- llvm/unittests/SandboxIR/SandboxIRTest.cpp | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index f1ee6b6de222f..b8e1f667f1467 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -10,6 +10,7 @@ #define LLVM_SANDBOXIR_CONTEXT_H #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/LLVMContext.h" #include "llvm/SandboxIR/Tracker.h" @@ -66,13 +67,13 @@ class Context { /// Callbacks called when an IR instruction is about to get removed. Keys are /// used as IDs for deregistration. - DenseMap RemoveInstrCallbacks; + MapVector RemoveInstrCallbacks; /// Callbacks called when an IR instruction is about to get inserted. Keys are /// used as IDs for deregistration. - DenseMap InsertInstrCallbacks; + MapVector InsertInstrCallbacks; /// Callbacks called when an IR instruction is about to get moved. Keys are /// used as IDs for deregistration. - DenseMap MoveInstrCallbacks; + MapVector MoveInstrCallbacks; /// A counter used for assigning callback IDs during registration. The same /// counter is used for all kinds of callbacks so we can detect mismatched diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 268c6e0712c50..5bad56b406447 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -6001,12 +6001,22 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { Moved.push_back(std::make_pair(I, &*Where)); }); + // Two more insertion callbacks, to check that they're called in registration + // order. + SmallVector Order; + auto CheckOrderInsertCbId1 = Ctx.registerInsertInstrCallback( + [&Order](sandboxir::Instruction *I) { Order.push_back(1); }); + + auto CheckOrderInsertCbId2 = Ctx.registerInsertInstrCallback( + [&Order](sandboxir::Instruction *I) { Order.push_back(2); }); + Ctx.save(); auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt, Ret->getIterator(), Ctx); EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); EXPECT_THAT(Removed, testing::IsEmpty()); EXPECT_THAT(Moved, testing::IsEmpty()); + EXPECT_THAT(Order, testing::ElementsAre(1, 2)); Ret->moveBefore(NewI); EXPECT_THAT(Inserted, testing::ElementsAre(NewI)); @@ -6030,6 +6040,7 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI)); EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI), std::make_pair(Ret, nullptr))); + EXPECT_THAT(Order, testing::ElementsAre(1, 2, 1, 2, 1, 2)); // Check that deregistration works. Do an operation of each type after // deregistering callbacks and check. @@ -6039,6 +6050,8 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { Ctx.unregisterInsertInstrCallback(InsertCbId); Ctx.unregisterRemoveInstrCallback(RemoveCbId); Ctx.unregisterMoveInstrCallback(MoveCbId); + Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId1); + Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId2); auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt, Ret->getIterator(), Ctx); Ret->moveBefore(NewI2); From 18ed35d02462594640a5061e04b94db07791fcb6 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Fri, 25 Oct 2024 17:17:16 -0700 Subject: [PATCH 6/9] Address more review comments, plus a couple of things discussed offline. - Remove insert/remove callbacks to create/erase. Don't call the erase callback on remove. This way we have a consistent model, but we don't have insert/remove callbacks, just create/erase/move. The missing callbacks can still be added if needed in the future. - Changed callback ids to uint64 in the unlikely case they could overflow a 32-bit integer in a large compile, causing hard-to-debug errors. --- llvm/include/llvm/SandboxIR/Context.h | 38 +++++++++++++--------- llvm/lib/SandboxIR/Context.cpp | 34 ++++++++++--------- llvm/lib/SandboxIR/Instruction.cpp | 4 +-- llvm/unittests/SandboxIR/SandboxIRTest.cpp | 16 ++++----- 4 files changed, 49 insertions(+), 43 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index b8e1f667f1467..f2056de87cb94 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -16,6 +16,8 @@ #include "llvm/SandboxIR/Tracker.h" #include "llvm/SandboxIR/Type.h" +#include + namespace llvm::sandboxir { class Argument; @@ -26,17 +28,19 @@ class Value; class Context { public: - // A RemoveInstrCallback receives the instruction about to be removed. - using RemoveInstrCallback = std::function; - // A InsertInstrCallback receives the instruction about to be created. - using InsertInstrCallback = std::function; + // A EraseInstrCallback receives the instruction about to be erased. + using EraseInstrCallback = std::function; + // A CreateInstrCallback receives the instruction about to be created. + using CreateInstrCallback = std::function; // A MoveInstrCallback receives the instruction about to be moved, the // destination BB and an iterator pointing to the insertion position. using MoveInstrCallback = std::function; - /// An ID for a registered callback. Used for deregistration. - using CallbackID = int; + /// An ID for a registered callback. Used for deregistration. Using a 64-bit + /// integer so we don't have to worry about the unlikely case of overflowing + /// a 32-bit counter. + using CallbackID = uint64_t; protected: LLVMContext &LLVMCtx; @@ -65,12 +69,12 @@ class Context { /// Type objects. DenseMap> LLVMTypeToTypeMap; - /// Callbacks called when an IR instruction is about to get removed. Keys are + /// Callbacks called when an IR instruction is about to get erased. Keys are /// used as IDs for deregistration. - MapVector RemoveInstrCallbacks; - /// Callbacks called when an IR instruction is about to get inserted. Keys are + MapVector EraseInstrCallbacks; + /// Callbacks called when an IR instruction is about to get created. Keys are /// used as IDs for deregistration. - MapVector InsertInstrCallbacks; + MapVector CreateInstrCallbacks; /// Callbacks called when an IR instruction is about to get moved. Keys are /// used as IDs for deregistration. MapVector MoveInstrCallbacks; @@ -102,8 +106,8 @@ class Context { Constant *getOrCreateConstant(llvm::Constant *LLVMC); friend class Utils; // For getMemoryBase - void runRemoveInstrCallbacks(Instruction *I); - void runInsertInstrCallbacks(Instruction *I); + void runEraseInstrCallbacks(Instruction *I); + void runCreateInstrCallbacks(Instruction *I); void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where); // Friends for getOrCreateConstant(). @@ -239,21 +243,23 @@ class Context { /// to be removed from its parent. Note that this will also be called when /// reverting the creation of an instruction. /// \Returns a callback ID for later deregistration. - CallbackID registerRemoveInstrCallback(RemoveInstrCallback CB); - void unregisterRemoveInstrCallback(CallbackID ID); + CallbackID registerEraseInstrCallback(EraseInstrCallback CB); + void unregisterEraseInstrCallback(CallbackID ID); /// Register a callback that gets called right after a SandboxIR instruction /// is created. Note that this will also be called when reverting the removal /// of an instruction. /// \Returns a callback ID for later deregistration. - CallbackID registerInsertInstrCallback(InsertInstrCallback CB); - void unregisterInsertInstrCallback(CallbackID ID); + CallbackID registerCreateInstrCallback(CreateInstrCallback CB); + void unregisterCreateInstrCallback(CallbackID ID); /// Register a callback that gets called when a SandboxIR instruction is about /// to be moved. Note that this will also be called when reverting a move. /// \Returns a callback ID for later deregistration. CallbackID registerMoveInstrCallback(MoveInstrCallback CB); void unregisterMoveInstrCallback(CallbackID ID); + + // TODO: Add callbacks for instructions inserted/removed if needed. }; } // namespace llvm::sandboxir diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index 213ad7f5c6d8a..c0b3559692856 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -46,7 +46,7 @@ Value *Context::registerValue(std::unique_ptr &&VPtr) { // creation. This is why the tracker class combines creation and insertion. if (auto *I = dyn_cast(V)) { getTracker().emplaceIfTracking(I); - runInsertInstrCallbacks(I); + runCreateInstrCallbacks(I); } return V; @@ -663,13 +663,13 @@ Module *Context::createModule(llvm::Module *LLVMM) { return M; } -void Context::runRemoveInstrCallbacks(Instruction *I) { - for (const auto &CBEntry : RemoveInstrCallbacks) +void Context::runEraseInstrCallbacks(Instruction *I) { + for (const auto &CBEntry : EraseInstrCallbacks) CBEntry.second(I); } -void Context::runInsertInstrCallbacks(Instruction *I) { - for (auto &CBEntry : InsertInstrCallbacks) +void Context::runCreateInstrCallbacks(Instruction *I) { + for (auto &CBEntry : CreateInstrCallbacks) CBEntry.second(I); } @@ -678,29 +678,31 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { CBEntry.second(I, WhereIt); } -int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) { +Context::CallbackID +Context::registerEraseInstrCallback(EraseInstrCallback CB) { CallbackID ID = NextCallbackID++; - RemoveInstrCallbacks[ID] = CB; + EraseInstrCallbacks[ID] = CB; return ID; } -void Context::unregisterRemoveInstrCallback(CallbackID ID) { - [[maybe_unused]] bool Erased = RemoveInstrCallbacks.erase(ID); +void Context::unregisterEraseInstrCallback(CallbackID ID) { + [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(ID); assert(Erased && - "Callback ID not found in RemoveInstrCallbacks during deregistration"); + "Callback ID not found in EraseInstrCallbacks during deregistration"); } -int Context::registerInsertInstrCallback(InsertInstrCallback CB) { +Context::CallbackID +Context::registerCreateInstrCallback(CreateInstrCallback CB) { CallbackID ID = NextCallbackID++; - InsertInstrCallbacks[ID] = CB; + CreateInstrCallbacks[ID] = CB; return ID; } -void Context::unregisterInsertInstrCallback(CallbackID ID) { - [[maybe_unused]] bool Erased = InsertInstrCallbacks.erase(ID); +void Context::unregisterCreateInstrCallback(CallbackID ID) { + [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(ID); assert(Erased && - "Callback ID not found in InsertInstrCallbacks during deregistration"); + "Callback ID not found in CreateInstrCallbacks during deregistration"); } -int Context::registerMoveInstrCallback(MoveInstrCallback CB) { +Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) { CallbackID ID = NextCallbackID++; MoveInstrCallbacks[ID] = CB; return ID; diff --git a/llvm/lib/SandboxIR/Instruction.cpp b/llvm/lib/SandboxIR/Instruction.cpp index ddeb78eea19f7..096b827541eea 100644 --- a/llvm/lib/SandboxIR/Instruction.cpp +++ b/llvm/lib/SandboxIR/Instruction.cpp @@ -64,8 +64,6 @@ Instruction *Instruction::getPrevNode() const { } void Instruction::removeFromParent() { - Ctx.runRemoveInstrCallbacks(this); - Ctx.getTracker().emplaceIfTracking(this); // Detach all the LLVM IR instructions from their parent BB. @@ -76,7 +74,7 @@ void Instruction::removeFromParent() { void Instruction::eraseFromParent() { assert(users().empty() && "Still connected to users, can't erase!"); - Ctx.runRemoveInstrCallbacks(this); + Ctx.runEraseInstrCallbacks(this); std::unique_ptr Detached = Ctx.detach(this); auto LLVMInstrs = getLLVMInstrs(); diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 5bad56b406447..99e14292a91b9 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -5980,11 +5980,11 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { sandboxir::Instruction *Ret = &BB.front(); SmallVector Inserted; - auto InsertCbId = Ctx.registerInsertInstrCallback( + auto InsertCbId = Ctx.registerCreateInstrCallback( [&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); }); SmallVector Removed; - auto RemoveCbId = Ctx.registerRemoveInstrCallback( + auto RemoveCbId = Ctx.registerEraseInstrCallback( [&Removed](sandboxir::Instruction *I) { Removed.push_back(I); }); // Keep the moved instruction and the instruction pointed by the Where @@ -6004,10 +6004,10 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { // Two more insertion callbacks, to check that they're called in registration // order. SmallVector Order; - auto CheckOrderInsertCbId1 = Ctx.registerInsertInstrCallback( + auto CheckOrderInsertCbId1 = Ctx.registerCreateInstrCallback( [&Order](sandboxir::Instruction *I) { Order.push_back(1); }); - auto CheckOrderInsertCbId2 = Ctx.registerInsertInstrCallback( + auto CheckOrderInsertCbId2 = Ctx.registerCreateInstrCallback( [&Order](sandboxir::Instruction *I) { Order.push_back(2); }); Ctx.save(); @@ -6047,11 +6047,11 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { Inserted.clear(); Removed.clear(); Moved.clear(); - Ctx.unregisterInsertInstrCallback(InsertCbId); - Ctx.unregisterRemoveInstrCallback(RemoveCbId); + Ctx.unregisterCreateInstrCallback(InsertCbId); + Ctx.unregisterEraseInstrCallback(RemoveCbId); Ctx.unregisterMoveInstrCallback(MoveCbId); - Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId1); - Ctx.unregisterInsertInstrCallback(CheckOrderInsertCbId2); + Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId1); + Ctx.unregisterCreateInstrCallback(CheckOrderInsertCbId2); auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt, Ret->getIterator(), Ctx); Ret->moveBefore(NewI2); From 35a20748234d9a448b1c4987ac2b8dc50f1e01df Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Fri, 25 Oct 2024 17:27:38 -0700 Subject: [PATCH 7/9] clang-format --- llvm/lib/SandboxIR/Context.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index c0b3559692856..0943f9526d024 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -678,8 +678,7 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { CBEntry.second(I, WhereIt); } -Context::CallbackID -Context::registerEraseInstrCallback(EraseInstrCallback CB) { +Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) { CallbackID ID = NextCallbackID++; EraseInstrCallbacks[ID] = CB; return ID; From 367e5b59c9be280e57ff6569c599bda44965c1b4 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Mon, 28 Oct 2024 18:17:38 -0700 Subject: [PATCH 8/9] Add assertion for max callbacks registered at once --- llvm/lib/SandboxIR/Context.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index 0943f9526d024..301b4b784016e 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -678,7 +678,13 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { CBEntry.second(I, WhereIt); } +// An arbitrary limit, to check for accidental misuse. We expect a small number +// of callbacks to be registered at a time, but we can increase this number if +// we discover we needed more. +static constexpr int MaxRegisteredCallbacks = 16; + Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) { + assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks && "EraseInstrCallbacks size limit exceeded"); CallbackID ID = NextCallbackID++; EraseInstrCallbacks[ID] = CB; return ID; @@ -691,6 +697,7 @@ void Context::unregisterEraseInstrCallback(CallbackID ID) { Context::CallbackID Context::registerCreateInstrCallback(CreateInstrCallback CB) { + assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks && "CreateInstrCallbacks size limit exceeded"); CallbackID ID = NextCallbackID++; CreateInstrCallbacks[ID] = CB; return ID; @@ -702,6 +709,7 @@ void Context::unregisterCreateInstrCallback(CallbackID ID) { } Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) { + assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks && "MoveInstrCallbacks size limit exceeded"); CallbackID ID = NextCallbackID++; MoveInstrCallbacks[ID] = CB; return ID; From 3ef32def0f94eadc792d2d89d44d5a00f5f69ac7 Mon Sep 17 00:00:00 2001 From: Jorge Gorbe Moya Date: Mon, 28 Oct 2024 18:18:38 -0700 Subject: [PATCH 9/9] more clang-format --- llvm/lib/SandboxIR/Context.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index 301b4b784016e..5e5cbbbc4515d 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -684,7 +684,8 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { static constexpr int MaxRegisteredCallbacks = 16; Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) { - assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks && "EraseInstrCallbacks size limit exceeded"); + assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks && + "EraseInstrCallbacks size limit exceeded"); CallbackID ID = NextCallbackID++; EraseInstrCallbacks[ID] = CB; return ID; @@ -697,7 +698,8 @@ void Context::unregisterEraseInstrCallback(CallbackID ID) { Context::CallbackID Context::registerCreateInstrCallback(CreateInstrCallback CB) { - assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks && "CreateInstrCallbacks size limit exceeded"); + assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks && + "CreateInstrCallbacks size limit exceeded"); CallbackID ID = NextCallbackID++; CreateInstrCallbacks[ID] = CB; return ID; @@ -709,7 +711,8 @@ void Context::unregisterCreateInstrCallback(CallbackID ID) { } Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) { - assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks && "MoveInstrCallbacks size limit exceeded"); + assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks && + "MoveInstrCallbacks size limit exceeded"); CallbackID ID = NextCallbackID++; MoveInstrCallbacks[ID] = CB; return ID;