-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[SandboxVec][DAG] Fix MemDGNode chain maintenance when move destination is non-mem #124227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…on is non mem This patch fixes a bug in the maintenance of the MemDGNode chain of the DAG. Whenever we move a memory instruction, the DAG gets notified about the move and maintains the chain of memory nodes. The bug was that if the destination of the move was not a memory instruction, then the memory node's next node would end up pointing to itself.
|
@llvm/pr-subscribers-llvm-transforms Author: vporpo (vporpo) Changes…on is non mem This patch fixes a bug in the maintenance of the MemDGNode chain of the DAG. Whenever we move a memory instruction, the DAG gets notified about the move and maintains the chain of memory nodes. The bug was that if the destination of the move was not a memory instruction, then the memory node's next node would end up pointing to itself. Full diff: https://github.com/llvm/llvm-project/pull/124227.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index b2d7c9b8aa8bbc..6e3f99d78b9329 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -218,12 +218,14 @@ class MemDGNode final : public DGNode {
friend class PredIterator; // For MemPreds.
/// Creates both edges: this<->N.
void setNextNode(MemDGNode *N) {
+ assert(N != this && "About to point to self!");
NextMemN = N;
if (NextMemN != nullptr)
NextMemN->PrevMemN = this;
}
/// Creates both edges: N<->this.
void setPrevNode(MemDGNode *N) {
+ assert(N != this && "About to point to self!");
PrevMemN = N;
if (PrevMemN != nullptr)
PrevMemN->NextMemN = this;
@@ -348,13 +350,15 @@ class DependencyGraph {
void createNewNodes(const Interval<Instruction> &NewInterval);
/// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
- /// before \p N, including or excluding \p N based on \p IncludingN, or
- /// nullptr if not found.
- MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
+ /// before \p N, skipping \p SkipN, including or excluding \p N based on
+ /// \p IncludingN, or nullptr if not found.
+ MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN = nullptr) const;
/// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
- /// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
- /// if not found.
- MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
+ /// after \p N, skipping \p SkipN, including or excluding \p N based on \p
+ /// IncludingN, or nullptr if not found.
+ MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN = nullptr) const;
/// Called by the callbacks when a new instruction \p I has been created.
void notifyCreateInstr(Instruction *I);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index f080111f08d45e..390a5e9688cc78 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -325,29 +325,31 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
setDefUseUnscheduledSuccs(NewInterval);
}
-MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
- bool IncludingN) const {
+MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN) const {
auto *I = N->getInstruction();
for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
PrevI = PrevI->getPrevNode()) {
auto *PrevN = getNodeOrNull(PrevI);
if (PrevN == nullptr)
return nullptr;
- if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
+ auto *PrevMemN = dyn_cast<MemDGNode>(PrevN);
+ if (PrevMemN != nullptr && PrevMemN != SkipN)
return PrevMemN;
}
return nullptr;
}
-MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
- bool IncludingN) const {
+MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN) const {
auto *I = N->getInstruction();
for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
NextI = NextI->getNextNode()) {
auto *NextN = getNodeOrNull(NextI);
if (NextN == nullptr)
return nullptr;
- if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
+ auto *NextMemN = dyn_cast<MemDGNode>(NextN);
+ if (NextMemN != nullptr && NextMemN != SkipN)
return NextMemN;
}
return nullptr;
@@ -377,6 +379,20 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
!(To == BB->end() && std::next(I->getIterator()) == BB->end()) &&
"Should not have been called if destination is same as origin.");
+ // TODO: We can only handle fully internal movements within DAGInterval or at
+ // the borders, i.e., right before the top or right after the bottom.
+ assert(To.getNodeParent() == I->getParent() &&
+ "TODO: We don't support movement across BBs!");
+ assert(
+ (To == std::next(DAGInterval.bottom()->getIterator()) ||
+ (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) ||
+ (To != BB->end() && DAGInterval.contains(&*To))) &&
+ "TODO: To should be either within the DAGInterval or right "
+ "before/after it.");
+
+ // Make a copy of the DAGInterval before we update it.
+ auto OrigDAGInterval = DAGInterval;
+
// Maintain the DAGInterval.
DAGInterval.notifyMoveInstr(I, To);
@@ -389,23 +405,37 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
MemDGNode *MemN = dyn_cast<MemDGNode>(N);
if (MemN == nullptr)
return;
- // First detach it from the existing chain.
+
+ // First safely detach it from the existing chain.
MemN->detachFromChain();
+
// Now insert it back into the chain at the new location.
- if (To != BB->end()) {
- DGNode *ToN = getNodeOrNull(&*To);
- if (ToN != nullptr) {
- MemN->setPrevNode(getMemDGNodeBefore(ToN, /*IncludingN=*/false));
- MemN->setNextNode(getMemDGNodeAfter(ToN, /*IncludingN=*/true));
- }
+ //
+ // We won't always have a DGNode to insert before it. If `To` is BB->end() or
+ // if it points to an instr after DAGInterval.bottom() then we will have to
+ // find a node to insert *after*.
+ //
+ // BB: BB:
+ // I1 I1 ^
+ // I2 I2 | DAGInteval [I1 to I3]
+ // I3 I3 V
+ // I4 I4 <- `To` == right after DAGInterval
+ // <- `To` == BB->end()
+ //
+ if (To == BB->end() ||
+ To == std::next(OrigDAGInterval.bottom()->getIterator())) {
+ // If we don't have a node to insert before, find a node to insert after and
+ // update the chain.
+ DGNode *InsertAfterN = getNode(&*std::prev(To));
+ MemN->setPrevNode(
+ getMemDGNodeBefore(InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN));
} else {
- // MemN becomes the last instruction in the BB.
- auto *TermN = getNodeOrNull(BB->getTerminator());
- if (TermN != nullptr) {
- MemN->setPrevNode(getMemDGNodeBefore(TermN, /*IncludingN=*/false));
- } else {
- // The terminator is outside the DAG interval so do nothing.
- }
+ // We have a node to insert before, so update the chain.
+ DGNode *BeforeToN = getNode(&*To);
+ MemN->setPrevNode(
+ getMemDGNodeBefore(BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN));
+ MemN->setNextNode(
+ getMemDGNodeAfter(BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN));
}
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 3fa4de501f3f5d..29fc05a7f256a2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -926,3 +926,46 @@ define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
EXPECT_EQ(LdN->getPrevNode(), S1N);
EXPECT_EQ(LdN->getNextNode(), S2N);
}
+
+// Check that the mem chain is maintained correctly when the move destination is
+// not a mem node.
+TEST_F(DependencyGraphTest, MoveInstrCallbackWithNonMemInstrs) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %arg) {
+ %ld = load i8, ptr %ptr
+ %zext1 = zext i8 %arg to i32
+ %zext2 = zext i8 %arg to i32
+ store i8 %v1, ptr %ptr
+ store i8 %v2, ptr %ptr
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+ [[maybe_unused]] auto *Zext1 = cast<sandboxir::CastInst>(&*It++);
+ auto *Zext2 = cast<sandboxir::CastInst>(&*It++);
+ auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+ auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+ DAG.extend({Ld, S2});
+ auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
+ auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+ auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+ EXPECT_EQ(LdN->getNextNode(), S1N);
+ EXPECT_EQ(S1N->getNextNode(), S2N);
+
+ S1->moveBefore(Zext2);
+ EXPECT_EQ(LdN->getNextNode(), S1N);
+ EXPECT_EQ(S1N->getNextNode(), S2N);
+
+ // Try move right after the end of the DAGInterval.
+ S1->moveBefore(Ret);
+ EXPECT_EQ(S2N->getNextNode(), S1N);
+ EXPECT_EQ(S1N->getNextNode(), nullptr);
+}
|
|
@llvm/pr-subscribers-vectorizers Author: vporpo (vporpo) Changes…on is non mem This patch fixes a bug in the maintenance of the MemDGNode chain of the DAG. Whenever we move a memory instruction, the DAG gets notified about the move and maintains the chain of memory nodes. The bug was that if the destination of the move was not a memory instruction, then the memory node's next node would end up pointing to itself. Full diff: https://github.com/llvm/llvm-project/pull/124227.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index b2d7c9b8aa8bbc..6e3f99d78b9329 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -218,12 +218,14 @@ class MemDGNode final : public DGNode {
friend class PredIterator; // For MemPreds.
/// Creates both edges: this<->N.
void setNextNode(MemDGNode *N) {
+ assert(N != this && "About to point to self!");
NextMemN = N;
if (NextMemN != nullptr)
NextMemN->PrevMemN = this;
}
/// Creates both edges: N<->this.
void setPrevNode(MemDGNode *N) {
+ assert(N != this && "About to point to self!");
PrevMemN = N;
if (PrevMemN != nullptr)
PrevMemN->NextMemN = this;
@@ -348,13 +350,15 @@ class DependencyGraph {
void createNewNodes(const Interval<Instruction> &NewInterval);
/// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
- /// before \p N, including or excluding \p N based on \p IncludingN, or
- /// nullptr if not found.
- MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
+ /// before \p N, skipping \p SkipN, including or excluding \p N based on
+ /// \p IncludingN, or nullptr if not found.
+ MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN = nullptr) const;
/// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
- /// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
- /// if not found.
- MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
+ /// after \p N, skipping \p SkipN, including or excluding \p N based on \p
+ /// IncludingN, or nullptr if not found.
+ MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN = nullptr) const;
/// Called by the callbacks when a new instruction \p I has been created.
void notifyCreateInstr(Instruction *I);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index f080111f08d45e..390a5e9688cc78 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -325,29 +325,31 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
setDefUseUnscheduledSuccs(NewInterval);
}
-MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
- bool IncludingN) const {
+MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN) const {
auto *I = N->getInstruction();
for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
PrevI = PrevI->getPrevNode()) {
auto *PrevN = getNodeOrNull(PrevI);
if (PrevN == nullptr)
return nullptr;
- if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
+ auto *PrevMemN = dyn_cast<MemDGNode>(PrevN);
+ if (PrevMemN != nullptr && PrevMemN != SkipN)
return PrevMemN;
}
return nullptr;
}
-MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
- bool IncludingN) const {
+MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
+ MemDGNode *SkipN) const {
auto *I = N->getInstruction();
for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
NextI = NextI->getNextNode()) {
auto *NextN = getNodeOrNull(NextI);
if (NextN == nullptr)
return nullptr;
- if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
+ auto *NextMemN = dyn_cast<MemDGNode>(NextN);
+ if (NextMemN != nullptr && NextMemN != SkipN)
return NextMemN;
}
return nullptr;
@@ -377,6 +379,20 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
!(To == BB->end() && std::next(I->getIterator()) == BB->end()) &&
"Should not have been called if destination is same as origin.");
+ // TODO: We can only handle fully internal movements within DAGInterval or at
+ // the borders, i.e., right before the top or right after the bottom.
+ assert(To.getNodeParent() == I->getParent() &&
+ "TODO: We don't support movement across BBs!");
+ assert(
+ (To == std::next(DAGInterval.bottom()->getIterator()) ||
+ (To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) ||
+ (To != BB->end() && DAGInterval.contains(&*To))) &&
+ "TODO: To should be either within the DAGInterval or right "
+ "before/after it.");
+
+ // Make a copy of the DAGInterval before we update it.
+ auto OrigDAGInterval = DAGInterval;
+
// Maintain the DAGInterval.
DAGInterval.notifyMoveInstr(I, To);
@@ -389,23 +405,37 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
MemDGNode *MemN = dyn_cast<MemDGNode>(N);
if (MemN == nullptr)
return;
- // First detach it from the existing chain.
+
+ // First safely detach it from the existing chain.
MemN->detachFromChain();
+
// Now insert it back into the chain at the new location.
- if (To != BB->end()) {
- DGNode *ToN = getNodeOrNull(&*To);
- if (ToN != nullptr) {
- MemN->setPrevNode(getMemDGNodeBefore(ToN, /*IncludingN=*/false));
- MemN->setNextNode(getMemDGNodeAfter(ToN, /*IncludingN=*/true));
- }
+ //
+ // We won't always have a DGNode to insert before it. If `To` is BB->end() or
+ // if it points to an instr after DAGInterval.bottom() then we will have to
+ // find a node to insert *after*.
+ //
+ // BB: BB:
+ // I1 I1 ^
+ // I2 I2 | DAGInteval [I1 to I3]
+ // I3 I3 V
+ // I4 I4 <- `To` == right after DAGInterval
+ // <- `To` == BB->end()
+ //
+ if (To == BB->end() ||
+ To == std::next(OrigDAGInterval.bottom()->getIterator())) {
+ // If we don't have a node to insert before, find a node to insert after and
+ // update the chain.
+ DGNode *InsertAfterN = getNode(&*std::prev(To));
+ MemN->setPrevNode(
+ getMemDGNodeBefore(InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN));
} else {
- // MemN becomes the last instruction in the BB.
- auto *TermN = getNodeOrNull(BB->getTerminator());
- if (TermN != nullptr) {
- MemN->setPrevNode(getMemDGNodeBefore(TermN, /*IncludingN=*/false));
- } else {
- // The terminator is outside the DAG interval so do nothing.
- }
+ // We have a node to insert before, so update the chain.
+ DGNode *BeforeToN = getNode(&*To);
+ MemN->setPrevNode(
+ getMemDGNodeBefore(BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN));
+ MemN->setNextNode(
+ getMemDGNodeAfter(BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN));
}
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 3fa4de501f3f5d..29fc05a7f256a2 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -926,3 +926,46 @@ define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
EXPECT_EQ(LdN->getPrevNode(), S1N);
EXPECT_EQ(LdN->getNextNode(), S2N);
}
+
+// Check that the mem chain is maintained correctly when the move destination is
+// not a mem node.
+TEST_F(DependencyGraphTest, MoveInstrCallbackWithNonMemInstrs) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %arg) {
+ %ld = load i8, ptr %ptr
+ %zext1 = zext i8 %arg to i32
+ %zext2 = zext i8 %arg to i32
+ store i8 %v1, ptr %ptr
+ store i8 %v2, ptr %ptr
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+ [[maybe_unused]] auto *Zext1 = cast<sandboxir::CastInst>(&*It++);
+ auto *Zext2 = cast<sandboxir::CastInst>(&*It++);
+ auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+ auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+ auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+ sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+ DAG.extend({Ld, S2});
+ auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
+ auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+ auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+ EXPECT_EQ(LdN->getNextNode(), S1N);
+ EXPECT_EQ(S1N->getNextNode(), S2N);
+
+ S1->moveBefore(Zext2);
+ EXPECT_EQ(LdN->getNextNode(), S1N);
+ EXPECT_EQ(S1N->getNextNode(), S2N);
+
+ // Try move right after the end of the DAGInterval.
+ S1->moveBefore(Ret);
+ EXPECT_EQ(S2N->getNextNode(), S1N);
+ EXPECT_EQ(S1N->getNextNode(), nullptr);
+}
|
This patch fixes a bug in the maintenance of the MemDGNode chain of the DAG. Whenever we move a memory instruction, the DAG gets notified about the move and maintains the chain of memory nodes. The bug was that if the destination of the move was not a memory instruction, then the memory node's next node would end up pointing to itself.