Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,14 @@ class MemDGNode final : public DGNode {
friend class PredIterator; // For MemPreds.
/// Creates both edges: this<->N.
void setNextNode(MemDGNode *N) {
assert(N != this && "About to point to self!");
NextMemN = N;
if (NextMemN != nullptr)
NextMemN->PrevMemN = this;
}
/// Creates both edges: N<->this.
void setPrevNode(MemDGNode *N) {
assert(N != this && "About to point to self!");
PrevMemN = N;
if (PrevMemN != nullptr)
PrevMemN->NextMemN = this;
Expand Down Expand Up @@ -348,13 +350,15 @@ class DependencyGraph {
void createNewNodes(const Interval<Instruction> &NewInterval);

/// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
/// before \p N, including or excluding \p N based on \p IncludingN, or
/// nullptr if not found.
MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
/// before \p N, skipping \p SkipN, including or excluding \p N based on
/// \p IncludingN, or nullptr if not found.
MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN,
MemDGNode *SkipN = nullptr) const;
/// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
/// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
/// if not found.
MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
/// after \p N, skipping \p SkipN, including or excluding \p N based on \p
/// IncludingN, or nullptr if not found.
MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN,
MemDGNode *SkipN = nullptr) const;

/// Called by the callbacks when a new instruction \p I has been created.
void notifyCreateInstr(Instruction *I);
Expand Down
70 changes: 50 additions & 20 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,29 +325,31 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
setDefUseUnscheduledSuccs(NewInterval);
}

MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
bool IncludingN) const {
MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N, bool IncludingN,
MemDGNode *SkipN) const {
auto *I = N->getInstruction();
for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
PrevI = PrevI->getPrevNode()) {
auto *PrevN = getNodeOrNull(PrevI);
if (PrevN == nullptr)
return nullptr;
if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
auto *PrevMemN = dyn_cast<MemDGNode>(PrevN);
if (PrevMemN != nullptr && PrevMemN != SkipN)
return PrevMemN;
}
return nullptr;
}

MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
bool IncludingN) const {
MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N, bool IncludingN,
MemDGNode *SkipN) const {
auto *I = N->getInstruction();
for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
NextI = NextI->getNextNode()) {
auto *NextN = getNodeOrNull(NextI);
if (NextN == nullptr)
return nullptr;
if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
auto *NextMemN = dyn_cast<MemDGNode>(NextN);
if (NextMemN != nullptr && NextMemN != SkipN)
return NextMemN;
}
return nullptr;
Expand Down Expand Up @@ -377,6 +379,20 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
!(To == BB->end() && std::next(I->getIterator()) == BB->end()) &&
"Should not have been called if destination is same as origin.");

// TODO: We can only handle fully internal movements within DAGInterval or at
// the borders, i.e., right before the top or right after the bottom.
assert(To.getNodeParent() == I->getParent() &&
"TODO: We don't support movement across BBs!");
assert(
(To == std::next(DAGInterval.bottom()->getIterator()) ||
(To != BB->end() && std::next(To) == DAGInterval.top()->getIterator()) ||
(To != BB->end() && DAGInterval.contains(&*To))) &&
"TODO: To should be either within the DAGInterval or right "
"before/after it.");

// Make a copy of the DAGInterval before we update it.
auto OrigDAGInterval = DAGInterval;

// Maintain the DAGInterval.
DAGInterval.notifyMoveInstr(I, To);

Expand All @@ -389,23 +405,37 @@ void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
MemDGNode *MemN = dyn_cast<MemDGNode>(N);
if (MemN == nullptr)
return;
// First detach it from the existing chain.

// First safely detach it from the existing chain.
MemN->detachFromChain();

// Now insert it back into the chain at the new location.
if (To != BB->end()) {
DGNode *ToN = getNodeOrNull(&*To);
if (ToN != nullptr) {
MemN->setPrevNode(getMemDGNodeBefore(ToN, /*IncludingN=*/false));
MemN->setNextNode(getMemDGNodeAfter(ToN, /*IncludingN=*/true));
}
//
// We won't always have a DGNode to insert before it. If `To` is BB->end() or
// if it points to an instr after DAGInterval.bottom() then we will have to
// find a node to insert *after*.
//
// BB: BB:
// I1 I1 ^
// I2 I2 | DAGInteval [I1 to I3]
// I3 I3 V
// I4 I4 <- `To` == right after DAGInterval
// <- `To` == BB->end()
//
if (To == BB->end() ||
To == std::next(OrigDAGInterval.bottom()->getIterator())) {
// If we don't have a node to insert before, find a node to insert after and
// update the chain.
DGNode *InsertAfterN = getNode(&*std::prev(To));
MemN->setPrevNode(
getMemDGNodeBefore(InsertAfterN, /*IncludingN=*/true, /*SkipN=*/MemN));
} else {
// MemN becomes the last instruction in the BB.
auto *TermN = getNodeOrNull(BB->getTerminator());
if (TermN != nullptr) {
MemN->setPrevNode(getMemDGNodeBefore(TermN, /*IncludingN=*/false));
} else {
// The terminator is outside the DAG interval so do nothing.
}
// We have a node to insert before, so update the chain.
DGNode *BeforeToN = getNode(&*To);
MemN->setPrevNode(
getMemDGNodeBefore(BeforeToN, /*IncludingN=*/false, /*SkipN=*/MemN));
MemN->setNextNode(
getMemDGNodeAfter(BeforeToN, /*IncludingN=*/true, /*SkipN=*/MemN));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -926,3 +926,46 @@ define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
EXPECT_EQ(LdN->getPrevNode(), S1N);
EXPECT_EQ(LdN->getNextNode(), S2N);
}

// Check that the mem chain is maintained correctly when the move destination is
// not a mem node.
TEST_F(DependencyGraphTest, MoveInstrCallbackWithNonMemInstrs) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %arg) {
%ld = load i8, ptr %ptr
%zext1 = zext i8 %arg to i32
%zext2 = zext i8 %arg to i32
store i8 %v1, ptr %ptr
store i8 %v2, ptr %ptr
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *Ld = cast<sandboxir::LoadInst>(&*It++);
[[maybe_unused]] auto *Zext1 = cast<sandboxir::CastInst>(&*It++);
auto *Zext2 = cast<sandboxir::CastInst>(&*It++);
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *S2 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
DAG.extend({Ld, S2});
auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
EXPECT_EQ(LdN->getNextNode(), S1N);
EXPECT_EQ(S1N->getNextNode(), S2N);

S1->moveBefore(Zext2);
EXPECT_EQ(LdN->getNextNode(), S1N);
EXPECT_EQ(S1N->getNextNode(), S2N);

// Try move right after the end of the DAGInterval.
S1->moveBefore(Ret);
EXPECT_EQ(S2N->getNextNode(), S1N);
EXPECT_EQ(S1N->getNextNode(), nullptr);
}
Loading