Skip to content

Commit 56ffef4

Browse files
committed
[SandboxIR] Add callbacks for instruction insert/remove/move ops.
1 parent 0a53f43 commit 56ffef4

File tree

4 files changed

+205
-7
lines changed

4 files changed

+205
-7
lines changed

llvm/include/llvm/SandboxIR/Context.h

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,31 @@
99
#ifndef LLVM_SANDBOXIR_CONTEXT_H
1010
#define LLVM_SANDBOXIR_CONTEXT_H
1111

12+
#include "llvm/ADT/DenseMap.h"
13+
#include "llvm/ADT/SmallVector.h"
1214
#include "llvm/IR/LLVMContext.h"
1315
#include "llvm/SandboxIR/Tracker.h"
1416
#include "llvm/SandboxIR/Type.h"
1517

1618
namespace llvm::sandboxir {
1719

18-
class Module;
19-
class Value;
2020
class Argument;
21+
class BBIterator;
2122
class Constant;
23+
class Module;
24+
class Value;
2225

2326
class Context {
27+
public:
28+
// A RemoveInstrCallback receives the instruction about to be removed.
29+
using RemoveInstrCallback = std::function<void(Instruction *)>;
30+
// A InsertInstrCallback receives the instruction about to be created.
31+
using InsertInstrCallback = std::function<void(Instruction *)>;
32+
// A MoveInstrCallback receives the instruction about to be moved, the
33+
// destination BB and an iterator pointing to the insertion position.
34+
using MoveInstrCallback =
35+
std::function<void(Instruction *, const BBIterator &)>;
36+
2437
protected:
2538
LLVMContext &LLVMCtx;
2639
friend class Type; // For LLVMCtx.
@@ -48,6 +61,21 @@ class Context {
4861
/// Type objects.
4962
DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;
5063

64+
/// Callbacks called when an IR instruction is about to get removed. Keys are
65+
/// used as IDs for deregistration.
66+
DenseMap<int, RemoveInstrCallback> RemoveInstrCallbacks;
67+
/// Callbacks called when an IR instruction is about to get inserted. Keys are
68+
/// used as IDs for deregistration.
69+
DenseMap<int, InsertInstrCallback> InsertInstrCallbacks;
70+
/// Callbacks called when an IR instruction is about to get moved. Keys are
71+
/// used as IDs for deregistration.
72+
DenseMap<int, MoveInstrCallback> MoveInstrCallbacks;
73+
74+
/// A counter used for assigning callback IDs during registration. The same
75+
/// counter is used for all kinds of callbacks so we can detect mismatched
76+
/// registration/deregistration.
77+
static int NextCallbackId;
78+
5179
/// Remove \p V from the maps and returns the unique_ptr.
5280
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
5381
/// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
@@ -70,6 +98,10 @@ class Context {
7098
Constant *getOrCreateConstant(llvm::Constant *LLVMC);
7199
friend class Utils; // For getMemoryBase
72100

101+
void runRemoveInstrCallbacks(Instruction *I);
102+
void runInsertInstrCallbacks(Instruction *I);
103+
void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);
104+
73105
// Friends for getOrCreateConstant().
74106
#define DEF_CONST(ID, CLASS) friend class CLASS;
75107
#include "llvm/SandboxIR/Values.def"
@@ -198,6 +230,26 @@ class Context {
198230

199231
/// \Returns the number of values registered with Context.
200232
size_t getNumValues() const { return LLVMValueToValueMap.size(); }
233+
234+
/// Register a callback that gets called when a SandboxIR instruction is about
235+
/// to be removed from its parent. Note that this will also be called when
236+
/// reverting the creation of an instruction.
237+
/// \Returns a callback ID for later deregistration.
238+
int registerRemoveInstrCallback(RemoveInstrCallback CB);
239+
void unregisterRemoveInstrCallback(int CallbackId);
240+
241+
/// Register a callback that gets called right after a SandboxIR instruction
242+
/// is created. Note that this will also be called when reverting the removal
243+
/// of an instruction.
244+
/// \Returns a callback ID for later deregistration.
245+
int registerInsertInstrCallback(InsertInstrCallback CB);
246+
void unregisterInsertInstrCallback(int CallbackId);
247+
248+
/// Register a callback that gets called when a SandboxIR instruction is about
249+
/// to be moved. Note that this will also be called when reverting a move.
250+
/// \Returns a callback ID for later deregistration.
251+
int registerMoveInstrCallback(MoveInstrCallback CB);
252+
void unregisterMoveInstrCallback(int CallbackId);
201253
};
202254

203255
} // namespace llvm::sandboxir

llvm/lib/SandboxIR/Context.cpp

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,20 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
3535
assert(VPtr->getSubclassID() != Value::ClassID::User &&
3636
"Can't register a user!");
3737

38+
Value *V = VPtr.get();
39+
[[maybe_unused]] auto Pair =
40+
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
41+
assert(Pair.second && "Already exists!");
42+
3843
// Track creation of instructions.
3944
// Please note that we don't allow the creation of detached instructions,
4045
// meaning that the instructions need to be inserted into a block upon
4146
// creation. This is why the tracker class combines creation and insertion.
42-
if (auto *I = dyn_cast<Instruction>(VPtr.get()))
47+
if (auto *I = dyn_cast<Instruction>(V)) {
4348
getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
49+
runInsertInstrCallbacks(I);
50+
}
4451

45-
Value *V = VPtr.get();
46-
[[maybe_unused]] auto Pair =
47-
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
48-
assert(Pair.second && "Already exists!");
4952
return V;
5053
}
5154

@@ -660,4 +663,57 @@ Module *Context::createModule(llvm::Module *LLVMM) {
660663
return M;
661664
}
662665

666+
void Context::runRemoveInstrCallbacks(Instruction *I) {
667+
for (const auto &CBEntry : RemoveInstrCallbacks) {
668+
CBEntry.second(I);
669+
}
670+
}
671+
672+
void Context::runInsertInstrCallbacks(Instruction *I) {
673+
for (auto &CBEntry : InsertInstrCallbacks) {
674+
CBEntry.second(I);
675+
}
676+
}
677+
678+
void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
679+
for (auto &CBEntry : MoveInstrCallbacks) {
680+
CBEntry.second(I, WhereIt);
681+
}
682+
}
683+
684+
int Context::NextCallbackId = 0;
685+
686+
int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
687+
int Id = NextCallbackId++;
688+
RemoveInstrCallbacks[Id] = CB;
689+
return Id;
690+
}
691+
void Context::unregisterRemoveInstrCallback(int CallbackId) {
692+
[[maybe_unused]] bool erased = RemoveInstrCallbacks.erase(CallbackId);
693+
assert(erased &&
694+
"Callback id not found in RemoveInstrCallbacks during deregistration");
695+
}
696+
697+
int Context::registerInsertInstrCallback(InsertInstrCallback CB) {
698+
int Id = NextCallbackId++;
699+
InsertInstrCallbacks[Id] = CB;
700+
return Id;
701+
}
702+
void Context::unregisterInsertInstrCallback(int CallbackId) {
703+
[[maybe_unused]] bool erased = InsertInstrCallbacks.erase(CallbackId);
704+
assert(erased &&
705+
"Callback id not found in InsertInstrCallbacks during deregistration");
706+
}
707+
708+
int Context::registerMoveInstrCallback(MoveInstrCallback CB) {
709+
int Id = NextCallbackId++;
710+
MoveInstrCallbacks[Id] = CB;
711+
return Id;
712+
}
713+
void Context::unregisterMoveInstrCallback(int CallbackId) {
714+
[[maybe_unused]] bool erased = MoveInstrCallbacks.erase(CallbackId);
715+
assert(erased &&
716+
"Callback id not found in MoveInstrCallbacks during deregistration");
717+
}
718+
663719
} // namespace llvm::sandboxir

llvm/lib/SandboxIR/Instruction.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ Instruction *Instruction::getPrevNode() const {
6464
}
6565

6666
void Instruction::removeFromParent() {
67+
Ctx.runRemoveInstrCallbacks(this);
68+
6769
Ctx.getTracker().emplaceIfTracking<RemoveFromParent>(this);
6870

6971
// Detach all the LLVM IR instructions from their parent BB.
@@ -73,6 +75,8 @@ void Instruction::removeFromParent() {
7375

7476
void Instruction::eraseFromParent() {
7577
assert(users().empty() && "Still connected to users, can't erase!");
78+
79+
Ctx.runRemoveInstrCallbacks(this);
7680
std::unique_ptr<Value> Detached = Ctx.detach(this);
7781
auto LLVMInstrs = getLLVMInstrs();
7882

@@ -100,6 +104,7 @@ void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) {
100104
// Destination is same as origin, nothing to do.
101105
return;
102106

107+
Ctx.runMoveInstrCallbacks(this, WhereIt);
103108
Ctx.getTracker().emplaceIfTracking<MoveInstr>(this);
104109

105110
auto *LLVMBB = cast<llvm::BasicBlock>(BB.Val);

llvm/unittests/SandboxIR/SandboxIRTest.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/SandboxIR/Value.h"
2323
#include "llvm/Support/SourceMgr.h"
2424
#include "gmock/gmock-matchers.h"
25+
#include "gmock/gmock-more-matchers.h"
2526
#include "gtest/gtest.h"
2627

2728
using namespace llvm;
@@ -5962,3 +5963,87 @@ TEST_F(SandboxIRTest, CheckClassof) {
59625963
EXPECT_NE(&sandboxir::CLASS::classof, &sandboxir::Instruction::classof);
59635964
#include "llvm/SandboxIR/Values.def"
59645965
}
5966+
5967+
TEST_F(SandboxIRTest, InstructionCallbacks) {
5968+
parseIR(C, R"IR(
5969+
define void @foo(ptr %ptr, i8 %val) {
5970+
ret void
5971+
}
5972+
)IR");
5973+
Function &LLVMF = *M->getFunction("foo");
5974+
sandboxir::Context Ctx(C);
5975+
5976+
auto &F = *Ctx.createFunction(&LLVMF);
5977+
auto &BB = *F.begin();
5978+
sandboxir::Argument *Ptr = F.getArg(0);
5979+
sandboxir::Argument *Val = F.getArg(1);
5980+
sandboxir::Instruction *Ret = &BB.front();
5981+
5982+
SmallVector<sandboxir::Instruction *> Inserted;
5983+
int InsertCbId = Ctx.registerInsertInstrCallback(
5984+
[&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); });
5985+
5986+
SmallVector<sandboxir::Instruction *> Removed;
5987+
int RemoveCbId = Ctx.registerRemoveInstrCallback(
5988+
[&Removed](sandboxir::Instruction *I) { Removed.push_back(I); });
5989+
5990+
// Keep the moved instruction and the instruction pointed by the Where
5991+
// iterator so we can check both callback arguments work as expected.
5992+
SmallVector<std::pair<sandboxir::Instruction *, sandboxir::Instruction *>>
5993+
Moved;
5994+
int MoveCbId = Ctx.registerMoveInstrCallback(
5995+
[&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) {
5996+
// Use a nullptr to signal "move to end" to keep it single. We only
5997+
// have a basic block in this test case anyway.
5998+
if (Where == Where.getNodeParent()->end())
5999+
Moved.push_back(std::make_pair(I, nullptr));
6000+
else
6001+
Moved.push_back(std::make_pair(I, &*Where));
6002+
});
6003+
6004+
Ctx.save();
6005+
auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
6006+
Ret->getIterator(), Ctx);
6007+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
6008+
EXPECT_THAT(Removed, testing::IsEmpty());
6009+
EXPECT_THAT(Moved, testing::IsEmpty());
6010+
6011+
Ret->moveBefore(NewI);
6012+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
6013+
EXPECT_THAT(Removed, testing::IsEmpty());
6014+
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
6015+
6016+
Ret->eraseFromParent();
6017+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
6018+
EXPECT_THAT(Removed, testing::ElementsAre(Ret));
6019+
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
6020+
6021+
NewI->eraseFromParent();
6022+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
6023+
EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI));
6024+
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));
6025+
6026+
// Check that after revert the callbacks have been called for the inverse
6027+
// operations of the changes made so far.
6028+
Ctx.revert();
6029+
EXPECT_THAT(Inserted, testing::ElementsAre(NewI, NewI, Ret));
6030+
EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI));
6031+
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI),
6032+
std::make_pair(Ret, nullptr)));
6033+
6034+
// Check that deregistration works. Do an operation of each type after
6035+
// deregistering callbacks and check.
6036+
Inserted.clear();
6037+
Removed.clear();
6038+
Moved.clear();
6039+
Ctx.unregisterInsertInstrCallback(InsertCbId);
6040+
Ctx.unregisterRemoveInstrCallback(RemoveCbId);
6041+
Ctx.unregisterMoveInstrCallback(MoveCbId);
6042+
auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
6043+
Ret->getIterator(), Ctx);
6044+
Ret->moveBefore(NewI2);
6045+
Ret->eraseFromParent();
6046+
EXPECT_THAT(Inserted, testing::IsEmpty());
6047+
EXPECT_THAT(Removed, testing::IsEmpty());
6048+
EXPECT_THAT(Moved, testing::IsEmpty());
6049+
}

0 commit comments

Comments
 (0)