Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions llvm/include/llvm/SandboxIR/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,34 @@
#ifndef LLVM_SANDBOXIR_CONTEXT_H
#define LLVM_SANDBOXIR_CONTEXT_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/SandboxIR/Tracker.h"
#include "llvm/SandboxIR/Type.h"

namespace llvm::sandboxir {

class Module;
class Value;
class Argument;
class BBIterator;
class Constant;
class Module;
class Value;

class Context {
public:
// A RemoveInstrCallback receives the instruction about to be removed.
using RemoveInstrCallback = std::function<void(Instruction *)>;
// A InsertInstrCallback receives the instruction about to be created.
using InsertInstrCallback = std::function<void(Instruction *)>;
// A MoveInstrCallback receives the instruction about to be moved, the
// destination BB and an iterator pointing to the insertion position.
using MoveInstrCallback =
std::function<void(Instruction *, const BBIterator &)>;

/// An ID for a registered callback. Used for deregistration.
using CallbackID = int;
Copy link
Contributor

@vporpo vporpo Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use an ID and not the raw pointer of the callback itself as the ID ?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-determinism

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pointers to the std::function objects stored in a map/vector can get invalidated when registering/unregistering more callbacks, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are we going to get non-determinism ? The ID itself is not printed anywhere, it's just used as a key to remove the callback if needed. If it's the raw pointer of the function it should still work just as well. Or am I missing something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pointers to the std::function objects stored in a map/vector can get invalidated when registering/unregistering more callbacks, right?

True but we can get around that in multiple ways, one of which is by allocating unique pointers and getting the raw pointer.

What I don't like about the integer IDs is that:

  • they are just being used as a key and they don't really convey anything else. They are not even being used for debugging purposes.
  • they need the static? counter

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is another layer of indirection and an extra heap allocation better than keeping an integer around?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you use raw pointers as keys in maps, you will get non-determinism.

The context should own the counter.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is another layer of indirection and an extra heap allocation better than keeping an integer around?

Yeah, it's definitely not great, and we should try to make the code that runs the callbacks as fast as possible.

Since @tschuett brought up non-determinism, should we try to make sure that the callbacks are run in the same order as registered? Because we are currently not only not getting that order but since we are iterating over the DenseMap we are also getting a non-deterministic order. Using a MapVector will fix non-determinism, but perhaps we should try to stick to registration order, and try to change it later if it turns out that it's not needed. I think we could use a vector instead of a map, because removal of a callback shouldn't happen too often (and we won't have too many callbacks) so a linear-time search should be fine. Wdyt?


protected:
LLVMContext &LLVMCtx;
friend class Type; // For LLVMCtx.
Expand Down Expand Up @@ -48,6 +64,21 @@ class Context {
/// Type objects.
DenseMap<llvm::Type *, std::unique_ptr<Type, TypeDeleter>> LLVMTypeToTypeMap;

/// Callbacks called when an IR instruction is about to get removed. Keys are
/// used as IDs for deregistration.
DenseMap<CallbackID, RemoveInstrCallback> RemoveInstrCallbacks;
/// Callbacks called when an IR instruction is about to get inserted. Keys are
/// used as IDs for deregistration.
DenseMap<CallbackID, InsertInstrCallback> InsertInstrCallbacks;
/// Callbacks called when an IR instruction is about to get moved. Keys are
/// used as IDs for deregistration.
DenseMap<CallbackID, MoveInstrCallback> MoveInstrCallbacks;

/// A counter used for assigning callback IDs during registration. The same
/// counter is used for all kinds of callbacks so we can detect mismatched
/// registration/deregistration.
CallbackID NextCallbackID = 0;

/// Remove \p V from the maps and returns the unique_ptr.
std::unique_ptr<Value> detachLLVMValue(llvm::Value *V);
/// Remove \p SBV from all SandboxIR maps and stop owning it. This effectively
Expand All @@ -70,6 +101,10 @@ class Context {
Constant *getOrCreateConstant(llvm::Constant *LLVMC);
friend class Utils; // For getMemoryBase

void runRemoveInstrCallbacks(Instruction *I);
void runInsertInstrCallbacks(Instruction *I);
void runMoveInstrCallbacks(Instruction *I, const BBIterator &Where);

// Friends for getOrCreateConstant().
#define DEF_CONST(ID, CLASS) friend class CLASS;
#include "llvm/SandboxIR/Values.def"
Expand Down Expand Up @@ -198,6 +233,26 @@ class Context {

/// \Returns the number of values registered with Context.
size_t getNumValues() const { return LLVMValueToValueMap.size(); }

/// Register a callback that gets called when a SandboxIR instruction is about
/// to be removed from its parent. Note that this will also be called when
/// reverting the creation of an instruction.
/// \Returns a callback ID for later deregistration.
CallbackID registerRemoveInstrCallback(RemoveInstrCallback CB);
void unregisterRemoveInstrCallback(CallbackID ID);

/// Register a callback that gets called right after a SandboxIR instruction
/// is created. Note that this will also be called when reverting the removal
/// of an instruction.
/// \Returns a callback ID for later deregistration.
CallbackID registerInsertInstrCallback(InsertInstrCallback CB);
void unregisterInsertInstrCallback(CallbackID ID);

/// Register a callback that gets called when a SandboxIR instruction is about
/// to be moved. Note that this will also be called when reverting a move.
/// \Returns a callback ID for later deregistration.
CallbackID registerMoveInstrCallback(MoveInstrCallback CB);
void unregisterMoveInstrCallback(CallbackID ID);
};

} // namespace llvm::sandboxir
Expand Down
61 changes: 56 additions & 5 deletions llvm/lib/SandboxIR/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,20 @@ Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
assert(VPtr->getSubclassID() != Value::ClassID::User &&
"Can't register a user!");

Value *V = VPtr.get();
[[maybe_unused]] auto Pair =
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
assert(Pair.second && "Already exists!");

// Track creation of instructions.
// Please note that we don't allow the creation of detached instructions,
// meaning that the instructions need to be inserted into a block upon
// creation. This is why the tracker class combines creation and insertion.
if (auto *I = dyn_cast<Instruction>(VPtr.get()))
if (auto *I = dyn_cast<Instruction>(V)) {
getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
runInsertInstrCallbacks(I);
}

Value *V = VPtr.get();
[[maybe_unused]] auto Pair =
LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
assert(Pair.second && "Already exists!");
return V;
}

Expand Down Expand Up @@ -660,4 +663,52 @@ Module *Context::createModule(llvm::Module *LLVMM) {
return M;
}

void Context::runRemoveInstrCallbacks(Instruction *I) {
for (const auto &CBEntry : RemoveInstrCallbacks)
CBEntry.second(I);
}

void Context::runInsertInstrCallbacks(Instruction *I) {
for (auto &CBEntry : InsertInstrCallbacks)
CBEntry.second(I);
}

void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
for (auto &CBEntry : MoveInstrCallbacks)
CBEntry.second(I, WhereIt);
}

int Context::registerRemoveInstrCallback(RemoveInstrCallback CB) {
CallbackID ID = NextCallbackID++;
RemoveInstrCallbacks[ID] = CB;
return ID;
}
void Context::unregisterRemoveInstrCallback(CallbackID ID) {
[[maybe_unused]] bool Erased = RemoveInstrCallbacks.erase(ID);
assert(Erased &&
"Callback ID not found in RemoveInstrCallbacks during deregistration");
}

int Context::registerInsertInstrCallback(InsertInstrCallback CB) {
CallbackID ID = NextCallbackID++;
InsertInstrCallbacks[ID] = CB;
return ID;
}
void Context::unregisterInsertInstrCallback(CallbackID ID) {
[[maybe_unused]] bool Erased = InsertInstrCallbacks.erase(ID);
assert(Erased &&
"Callback ID not found in InsertInstrCallbacks during deregistration");
}

int Context::registerMoveInstrCallback(MoveInstrCallback CB) {
CallbackID ID = NextCallbackID++;
MoveInstrCallbacks[ID] = CB;
return ID;
}
void Context::unregisterMoveInstrCallback(CallbackID ID) {
[[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID);
assert(Erased &&
"Callback ID not found in MoveInstrCallbacks during deregistration");
}

} // namespace llvm::sandboxir
5 changes: 5 additions & 0 deletions llvm/lib/SandboxIR/Instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ Instruction *Instruction::getPrevNode() const {
}

void Instruction::removeFromParent() {
Ctx.runRemoveInstrCallbacks(this);

Ctx.getTracker().emplaceIfTracking<RemoveFromParent>(this);

// Detach all the LLVM IR instructions from their parent BB.
Expand All @@ -73,6 +75,8 @@ void Instruction::removeFromParent() {

void Instruction::eraseFromParent() {
assert(users().empty() && "Still connected to users, can't erase!");

Ctx.runRemoveInstrCallbacks(this);
std::unique_ptr<Value> Detached = Ctx.detach(this);
auto LLVMInstrs = getLLVMInstrs();

Expand Down Expand Up @@ -100,6 +104,7 @@ void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) {
// Destination is same as origin, nothing to do.
return;

Ctx.runMoveInstrCallbacks(this, WhereIt);
Ctx.getTracker().emplaceIfTracking<MoveInstr>(this);

auto *LLVMBB = cast<llvm::BasicBlock>(BB.Val);
Expand Down
85 changes: 85 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/SourceMgr.h"
#include "gmock/gmock-matchers.h"
#include "gmock/gmock-more-matchers.h"
#include "gtest/gtest.h"

using namespace llvm;
Expand Down Expand Up @@ -5962,3 +5963,87 @@ TEST_F(SandboxIRTest, CheckClassof) {
EXPECT_NE(&sandboxir::CLASS::classof, &sandboxir::Instruction::classof);
#include "llvm/SandboxIR/Values.def"
}

TEST_F(SandboxIRTest, InstructionCallbacks) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %val) {
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
sandboxir::Argument *Ptr = F.getArg(0);
sandboxir::Argument *Val = F.getArg(1);
sandboxir::Instruction *Ret = &BB.front();

SmallVector<sandboxir::Instruction *> Inserted;
auto InsertCbId = Ctx.registerInsertInstrCallback(
[&Inserted](sandboxir::Instruction *I) { Inserted.push_back(I); });

SmallVector<sandboxir::Instruction *> Removed;
auto RemoveCbId = Ctx.registerRemoveInstrCallback(
[&Removed](sandboxir::Instruction *I) { Removed.push_back(I); });

// Keep the moved instruction and the instruction pointed by the Where
// iterator so we can check both callback arguments work as expected.
SmallVector<std::pair<sandboxir::Instruction *, sandboxir::Instruction *>>
Moved;
auto MoveCbId = Ctx.registerMoveInstrCallback(
[&Moved](sandboxir::Instruction *I, const sandboxir::BBIterator &Where) {
// Use a nullptr to signal "move to end" to keep it single. We only
// have a basic block in this test case anyway.
if (Where == Where.getNodeParent()->end())
Moved.push_back(std::make_pair(I, nullptr));
else
Moved.push_back(std::make_pair(I, &*Where));
});

Ctx.save();
auto *NewI = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
Ret->getIterator(), Ctx);
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
EXPECT_THAT(Removed, testing::IsEmpty());
EXPECT_THAT(Moved, testing::IsEmpty());

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A drive-by comment: How about testing insertBefore and insertAfter?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed this offline because @vporpo's prototype only called them on create/erase, but called them insert/remove. We only need create/erase/move callbacks for the vectorizer, so we're only adding those for now, and I have renamed them to create/erase callbacks. This leaves space to add proper insert/remove callbacks in the future if we (or some future user of Sandbox IR) needs them.

Thus, no need to test those here as they have no associated callbacks for now.

Ret->moveBefore(NewI);
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
EXPECT_THAT(Removed, testing::IsEmpty());
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));

Ret->eraseFromParent();
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
EXPECT_THAT(Removed, testing::ElementsAre(Ret));
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));

NewI->eraseFromParent();
EXPECT_THAT(Inserted, testing::ElementsAre(NewI));
EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI));
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI)));

// Check that after revert the callbacks have been called for the inverse
// operations of the changes made so far.
Ctx.revert();
EXPECT_THAT(Inserted, testing::ElementsAre(NewI, NewI, Ret));
EXPECT_THAT(Removed, testing::ElementsAre(Ret, NewI, NewI));
EXPECT_THAT(Moved, testing::ElementsAre(std::make_pair(Ret, NewI),
std::make_pair(Ret, nullptr)));

// Check that deregistration works. Do an operation of each type after
// deregistering callbacks and check.
Inserted.clear();
Removed.clear();
Moved.clear();
Ctx.unregisterInsertInstrCallback(InsertCbId);
Ctx.unregisterRemoveInstrCallback(RemoveCbId);
Ctx.unregisterMoveInstrCallback(MoveCbId);
auto *NewI2 = sandboxir::StoreInst::create(Val, Ptr, /*Align=*/std::nullopt,
Ret->getIterator(), Ctx);
Ret->moveBefore(NewI2);
Ret->eraseFromParent();
EXPECT_THAT(Inserted, testing::IsEmpty());
EXPECT_THAT(Removed, testing::IsEmpty());
EXPECT_THAT(Moved, testing::IsEmpty());
}
Loading