From b9828f73952b939612aac9cd9536f7a7f039604c Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Thu, 30 Jan 2025 13:15:25 -0800 Subject: [PATCH] [SandboxIR] SetUse callback This patch implements a callback mechanism similar to the existing ones, but for getting notified whenever a Use edge gets updated. This is going to be used in a follow up patch by the Dependency Graph. --- llvm/include/llvm/SandboxIR/Context.h | 15 ++++- llvm/lib/SandboxIR/Context.cpp | 18 ++++++ llvm/lib/SandboxIR/User.cpp | 13 +++-- llvm/lib/SandboxIR/Value.cpp | 8 ++- llvm/unittests/SandboxIR/SandboxIRTest.cpp | 66 ++++++++++++++++++++++ 5 files changed, 111 insertions(+), 9 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index a88b0003f55bd..714d1ec78f452 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -26,6 +26,7 @@ class BBIterator; class Constant; class Module; class Value; +class Use; class Context { public: @@ -37,6 +38,8 @@ class Context { // destination BB and an iterator pointing to the insertion position. using MoveInstrCallback = std::function; + // A SetUseCallback receives the Use that is about to get its source set. + using SetUseCallback = std::function; /// An ID for a registered callback. Used for deregistration. A dedicated type /// is employed so as to keep IDs opaque to the end user; only Context should @@ -98,6 +101,9 @@ class Context { /// Callbacks called when an IR instruction is about to get moved. Keys are /// used as IDs for deregistration. MapVector MoveInstrCallbacks; + /// Callbacks called when a Use gets its source set. Keys are used as IDs for + /// deregistration. + MapVector SetUseCallbacks; /// A counter used for assigning callback IDs during registration. The same /// counter is used for all kinds of callbacks so we can detect mismatched @@ -129,6 +135,10 @@ class Context { void runEraseInstrCallbacks(Instruction *I); void runCreateInstrCallbacks(Instruction *I); void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where); + void runSetUseCallbacks(const Use &U, Value *NewSrc); + + friend class User; // For runSetUseCallbacks(). + friend class Value; // For runSetUseCallbacks(). // Friends for getOrCreateConstant(). #define DEF_CONST(ID, CLASS) friend class CLASS; @@ -281,7 +291,10 @@ class Context { CallbackID registerMoveInstrCallback(MoveInstrCallback CB); void unregisterMoveInstrCallback(CallbackID ID); - // TODO: Add callbacks for instructions inserted/removed if needed. + /// Register a callback that gets called when a Use gets set. + /// \Returns a callback ID for later deregistration. + CallbackID registerSetUseCallback(SetUseCallback CB); + void unregisterSetUseCallback(CallbackID ID); }; } // namespace sandboxir diff --git a/llvm/lib/SandboxIR/Context.cpp b/llvm/lib/SandboxIR/Context.cpp index 830f2832853fe..6a397b02d6bde 100644 --- a/llvm/lib/SandboxIR/Context.cpp +++ b/llvm/lib/SandboxIR/Context.cpp @@ -687,6 +687,11 @@ void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) { CBEntry.second(I, WhereIt); } +void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) { + for (auto &CBEntry : SetUseCallbacks) + CBEntry.second(U, NewSrc); +} + // An arbitrary limit, to check for accidental misuse. We expect a small number // of callbacks to be registered at a time, but we can increase this number if // we discover we needed more. @@ -732,4 +737,17 @@ void Context::unregisterMoveInstrCallback(CallbackID ID) { "Callback ID not found in MoveInstrCallbacks during deregistration"); } +Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) { + assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks && + "SetUseCallbacks size limit exceeded"); + CallbackID ID{NextCallbackID++}; + SetUseCallbacks[ID] = CB; + return ID; +} +void Context::unregisterSetUseCallback(CallbackID ID) { + [[maybe_unused]] bool Erased = SetUseCallbacks.erase(ID); + assert(Erased && + "Callback ID not found in SetUseCallbacks during deregistration"); +} + } // namespace llvm::sandboxir diff --git a/llvm/lib/SandboxIR/User.cpp b/llvm/lib/SandboxIR/User.cpp index d7e4656e6e90e..43fd565e23836 100644 --- a/llvm/lib/SandboxIR/User.cpp +++ b/llvm/lib/SandboxIR/User.cpp @@ -90,17 +90,20 @@ bool User::classof(const Value *From) { void User::setOperand(unsigned OperandIdx, Value *Operand) { assert(isa(Val) && "No operands!"); - Ctx.getTracker().emplaceIfTracking(getOperandUse(OperandIdx)); + const auto &U = getOperandUse(OperandIdx); + Ctx.getTracker().emplaceIfTracking(U); + Ctx.runSetUseCallbacks(U, Operand); // We are delegating to llvm::User::setOperand(). cast(Val)->setOperand(OperandIdx, Operand->Val); } bool User::replaceUsesOfWith(Value *FromV, Value *ToV) { auto &Tracker = Ctx.getTracker(); - if (Tracker.isTracking()) { - for (auto OpIdx : seq(0, getNumOperands())) { - auto Use = getOperandUse(OpIdx); - if (Use.get() == FromV) + for (auto OpIdx : seq(0, getNumOperands())) { + auto Use = getOperandUse(OpIdx); + if (Use.get() == FromV) { + Ctx.runSetUseCallbacks(Use, ToV); + if (Tracker.isTracking()) Tracker.emplaceIfTracking(Use); } } diff --git a/llvm/lib/SandboxIR/Value.cpp b/llvm/lib/SandboxIR/Value.cpp index b9d91c7e11f74..e39bbc44bca00 100644 --- a/llvm/lib/SandboxIR/Value.cpp +++ b/llvm/lib/SandboxIR/Value.cpp @@ -51,7 +51,7 @@ void Value::replaceUsesWithIf( llvm::Value *OtherVal = OtherV->Val; // We are delegating RUWIf to LLVM IR's RUWIf. Val->replaceUsesWithIf( - OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool { + OtherVal, [&ShouldReplace, this, OtherV](llvm::Use &LLVMUse) -> bool { User *DstU = cast_or_null(Ctx.getValue(LLVMUse.getUser())); if (DstU == nullptr) return false; @@ -59,6 +59,7 @@ void Value::replaceUsesWithIf( if (!ShouldReplace(UseToReplace)) return false; Ctx.getTracker().emplaceIfTracking(UseToReplace); + Ctx.runSetUseCallbacks(UseToReplace, OtherV); return true; }); } @@ -67,8 +68,9 @@ void Value::replaceAllUsesWith(Value *Other) { assert(getType() == Other->getType() && "Replacing with Value of different type!"); auto &Tracker = Ctx.getTracker(); - if (Tracker.isTracking()) { - for (auto Use : uses()) + for (auto Use : uses()) { + Ctx.runSetUseCallbacks(Use, Other); + if (Tracker.isTracking()) Tracker.track(std::make_unique(Use)); } // We are delegating RAUW to LLVM IR's RAUW. diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 9eeac9b60372f..2ad33659c609b 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -6081,6 +6081,72 @@ TEST_F(SandboxIRTest, InstructionCallbacks) { EXPECT_THAT(Moved, testing::IsEmpty()); } +// Check callbacks when we set a Use. +TEST_F(SandboxIRTest, SetUseCallbacks) { + parseIR(C, R"IR( +define void @foo(i8 %v0, i8 %v1) { + %add0 = add i8 %v0, %v1 + %add1 = add i8 %add0, %v1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *Arg0 = F->getArg(0); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *Add0 = cast(&*It++); + auto *Add1 = cast(&*It++); + + SmallVector> UsesSet; + auto Id = Ctx.registerSetUseCallback( + [&UsesSet](sandboxir::Use U, sandboxir::Value *NewSrc) { + UsesSet.push_back({U, NewSrc}); + }); + + // Now change %add1 operand to not use %add0. + Add1->setOperand(0, Arg0); + EXPECT_EQ(UsesSet.size(), 1u); + EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get()); + EXPECT_EQ(UsesSet[0].second, Arg0); + // Restore to previous state. + Add1->setOperand(0, Add0); + UsesSet.clear(); + + // RAUW + Add0->replaceAllUsesWith(Arg0); + EXPECT_EQ(UsesSet.size(), 1u); + EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get()); + EXPECT_EQ(UsesSet[0].second, Arg0); + // Restore to previous state. + Add1->setOperand(0, Add0); + UsesSet.clear(); + + // RUWIf + Add0->replaceUsesWithIf(Arg0, [](const auto &U) { return true; }); + EXPECT_EQ(UsesSet.size(), 1u); + EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get()); + EXPECT_EQ(UsesSet[0].second, Arg0); + // Restore to previous state. + Add1->setOperand(0, Add0); + UsesSet.clear(); + + // RUOW + Add1->replaceUsesOfWith(Add0, Arg0); + EXPECT_EQ(UsesSet.size(), 1u); + EXPECT_EQ(UsesSet[0].first.get(), Add1->getOperandUse(0).get()); + EXPECT_EQ(UsesSet[0].second, Arg0); + // Restore to previous state. + Add1->setOperand(0, Add0); + UsesSet.clear(); + + // Check unregister. + Ctx.unregisterSetUseCallback(Id); + Add0->replaceAllUsesWith(Arg0); + EXPECT_TRUE(UsesSet.empty()); +} + TEST_F(SandboxIRTest, FunctionObjectAlreadyExists) { parseIR(C, R"IR( define void @foo() {