-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[SandboxVec][DAG] Register move instr callback #120146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch implements the move instruction notifier for the DAG. Whenever an instruction moves the notifier will maintain the DAG.
|
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-vectorizers Author: vporpo (vporpo) ChangesThis patch implements the move instruction notifier for the DAG. Whenever an instruction moves the notifier will maintain the DAG. Full diff: https://github.com/llvm/llvm-project/pull/120146.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index b1cad2421bc0d2..f423e1ee456cd1 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -220,6 +220,14 @@ class MemDGNode final : public DGNode {
void setNextNode(MemDGNode *N) { NextMemN = N; }
void setPrevNode(MemDGNode *N) { PrevMemN = N; }
friend class DependencyGraph; // For setNextNode(), setPrevNode().
+ void detachFromChain() {
+ if (PrevMemN != nullptr)
+ PrevMemN->NextMemN = NextMemN;
+ if (NextMemN != nullptr)
+ NextMemN->PrevMemN = PrevMemN;
+ PrevMemN = nullptr;
+ NextMemN = nullptr;
+ }
public:
MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) {
@@ -293,6 +301,7 @@ class DependencyGraph {
Context *Ctx = nullptr;
std::optional<Context::CallbackID> CreateInstrCB;
std::optional<Context::CallbackID> EraseInstrCB;
+ std::optional<Context::CallbackID> MoveInstrCB;
std::unique_ptr<BatchAAResults> BatchAA;
@@ -343,6 +352,9 @@ class DependencyGraph {
/// Called by the callbacks when instruction \p I is about to get
/// deleted.
void notifyEraseInstr(Instruction *I);
+ /// Called by the callbacks when instruction \p I is about to be moved to
+ /// \p To.
+ void notifyMoveInstr(Instruction *I, const BBIterator &To);
public:
/// This constructor also registers callbacks.
@@ -352,12 +364,18 @@ class DependencyGraph {
[this](Instruction *I) { notifyCreateInstr(I); });
EraseInstrCB = Ctx.registerEraseInstrCallback(
[this](Instruction *I) { notifyEraseInstr(I); });
+ MoveInstrCB = Ctx.registerMoveInstrCallback(
+ [this](Instruction *I, const BBIterator &To) {
+ notifyMoveInstr(I, To);
+ });
}
~DependencyGraph() {
if (CreateInstrCB)
Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
if (EraseInstrCB)
Ctx->unregisterEraseInstrCallback(*EraseInstrCB);
+ if (MoveInstrCB)
+ Ctx->unregisterMoveInstrCallback(*MoveInstrCB);
}
DGNode *getNode(Instruction *I) const {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 25f2665d450d13..ba62c45a4e704e 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -370,6 +370,52 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
}
}
+void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
+ // Early return if `I` doesn't actually move.
+ BasicBlock *BB = To.getNodeParent();
+ if (To != BB->end() && &*To == I->getNextNode())
+ return;
+
+ // Maintain the DAGInterval.
+ DAGInterval.notifyMoveInstr(I, To);
+
+ // TODO: Perhaps check if this is legal by checking the dependencies?
+
+ // Update the MemDGNode chain to reflect the instr movement if necessary.
+ DGNode *N = getNodeOrNull(I);
+ if (N == nullptr)
+ return;
+ MemDGNode *MemN = dyn_cast<MemDGNode>(N);
+ if (MemN == nullptr)
+ return;
+ // First detach it from the existing chain.
+ MemN->detachFromChain();
+ // Now insert it back into the chain at the new location.
+ if (To != BB->end()) {
+ DGNode *ToN = getNodeOrNull(&*To);
+ if (ToN != nullptr) {
+ MemDGNode *PrevMemN = getMemDGNodeBefore(ToN, /*IncludingN=*/false);
+ MemDGNode *NextMemN = getMemDGNodeAfter(ToN, /*IncludingN=*/true);
+ MemN->PrevMemN = PrevMemN;
+ if (PrevMemN != nullptr)
+ PrevMemN->NextMemN = MemN;
+ MemN->NextMemN = NextMemN;
+ if (NextMemN != nullptr)
+ NextMemN->PrevMemN = MemN;
+ }
+ } else {
+ // MemN becomes the last instruction in the BB.
+ auto *TermN = getNodeOrNull(BB->getTerminator());
+ if (TermN != nullptr) {
+ MemDGNode *PrevMemN = getMemDGNodeBefore(TermN, /*IncludingN=*/false);
+ PrevMemN->NextMemN = MemN;
+ MemN->PrevMemN = PrevMemN;
+ } else {
+ // The terminator is outside the DAG interval so do nothing.
+ }
+ }
+}
+
void DependencyGraph::notifyEraseInstr(Instruction *I) {
// Update the MemDGNode chain if this is a memory node.
if (auto *MemN = dyn_cast_or_null<MemDGNode>(getNodeOrNull(I))) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 8c73ee1def8ae1..3fa4de501f3f5d 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -801,7 +801,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
TEST_F(DependencyGraphTest, CreateInstrCallback) {
parseIR(C, R"IR(
-define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
store i8 %v1, ptr %ptr
store i8 %v2, ptr %ptr
store i8 %v3, ptr %ptr
@@ -893,3 +893,36 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
// TODO: Check the dependencies to/from NewSN after they land.
}
+
+TEST_F(DependencyGraphTest, MoveInstrCallback) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+ %ld0 = load i8, ptr %ptr2
+ store i8 %v1, ptr %ptr
+ store i8 %v2, ptr %ptr
+ store i8 %v3, ptr %ptr
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+ auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+ DAG.extend({Ld, S3});
+ auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
+ auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+ auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+ EXPECT_EQ(S1N->getPrevNode(), LdN);
+ S1->moveBefore(Ld);
+ EXPECT_EQ(S1N->getPrevNode(), nullptr);
+ EXPECT_EQ(S1N->getNextNode(), LdN);
+ EXPECT_EQ(LdN->getPrevNode(), S1N);
+ EXPECT_EQ(LdN->getNextNode(), S2N);
+}
|
slackito
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me.
This patch implements the move instruction notifier for the DAG. Whenever an instruction moves the notifier will maintain the DAG.