Skip to content

Commit 7a38445

Browse files
authored
[SandboxVec][DAG] Register move instr callback (#120146)
This patch implements the move instruction notifier for the DAG. Whenever an instruction moves the notifier will maintain the DAG.
1 parent 665d79f commit 7a38445

File tree

3 files changed

+98
-1
lines changed

3 files changed

+98
-1
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,14 @@ class MemDGNode final : public DGNode {
220220
void setNextNode(MemDGNode *N) { NextMemN = N; }
221221
void setPrevNode(MemDGNode *N) { PrevMemN = N; }
222222
friend class DependencyGraph; // For setNextNode(), setPrevNode().
223+
void detachFromChain() {
224+
if (PrevMemN != nullptr)
225+
PrevMemN->NextMemN = NextMemN;
226+
if (NextMemN != nullptr)
227+
NextMemN->PrevMemN = PrevMemN;
228+
PrevMemN = nullptr;
229+
NextMemN = nullptr;
230+
}
223231

224232
public:
225233
MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) {
@@ -293,6 +301,7 @@ class DependencyGraph {
293301
Context *Ctx = nullptr;
294302
std::optional<Context::CallbackID> CreateInstrCB;
295303
std::optional<Context::CallbackID> EraseInstrCB;
304+
std::optional<Context::CallbackID> MoveInstrCB;
296305

297306
std::unique_ptr<BatchAAResults> BatchAA;
298307

@@ -343,6 +352,9 @@ class DependencyGraph {
343352
/// Called by the callbacks when instruction \p I is about to get
344353
/// deleted.
345354
void notifyEraseInstr(Instruction *I);
355+
/// Called by the callbacks when instruction \p I is about to be moved to
356+
/// \p To.
357+
void notifyMoveInstr(Instruction *I, const BBIterator &To);
346358

347359
public:
348360
/// This constructor also registers callbacks.
@@ -352,12 +364,18 @@ class DependencyGraph {
352364
[this](Instruction *I) { notifyCreateInstr(I); });
353365
EraseInstrCB = Ctx.registerEraseInstrCallback(
354366
[this](Instruction *I) { notifyEraseInstr(I); });
367+
MoveInstrCB = Ctx.registerMoveInstrCallback(
368+
[this](Instruction *I, const BBIterator &To) {
369+
notifyMoveInstr(I, To);
370+
});
355371
}
356372
~DependencyGraph() {
357373
if (CreateInstrCB)
358374
Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
359375
if (EraseInstrCB)
360376
Ctx->unregisterEraseInstrCallback(*EraseInstrCB);
377+
if (MoveInstrCB)
378+
Ctx->unregisterMoveInstrCallback(*MoveInstrCB);
361379
}
362380

363381
DGNode *getNode(Instruction *I) const {

llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,52 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
370370
}
371371
}
372372

373+
void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
374+
// Early return if `I` doesn't actually move.
375+
BasicBlock *BB = To.getNodeParent();
376+
if (To != BB->end() && &*To == I->getNextNode())
377+
return;
378+
379+
// Maintain the DAGInterval.
380+
DAGInterval.notifyMoveInstr(I, To);
381+
382+
// TODO: Perhaps check if this is legal by checking the dependencies?
383+
384+
// Update the MemDGNode chain to reflect the instr movement if necessary.
385+
DGNode *N = getNodeOrNull(I);
386+
if (N == nullptr)
387+
return;
388+
MemDGNode *MemN = dyn_cast<MemDGNode>(N);
389+
if (MemN == nullptr)
390+
return;
391+
// First detach it from the existing chain.
392+
MemN->detachFromChain();
393+
// Now insert it back into the chain at the new location.
394+
if (To != BB->end()) {
395+
DGNode *ToN = getNodeOrNull(&*To);
396+
if (ToN != nullptr) {
397+
MemDGNode *PrevMemN = getMemDGNodeBefore(ToN, /*IncludingN=*/false);
398+
MemDGNode *NextMemN = getMemDGNodeAfter(ToN, /*IncludingN=*/true);
399+
MemN->PrevMemN = PrevMemN;
400+
if (PrevMemN != nullptr)
401+
PrevMemN->NextMemN = MemN;
402+
MemN->NextMemN = NextMemN;
403+
if (NextMemN != nullptr)
404+
NextMemN->PrevMemN = MemN;
405+
}
406+
} else {
407+
// MemN becomes the last instruction in the BB.
408+
auto *TermN = getNodeOrNull(BB->getTerminator());
409+
if (TermN != nullptr) {
410+
MemDGNode *PrevMemN = getMemDGNodeBefore(TermN, /*IncludingN=*/false);
411+
PrevMemN->NextMemN = MemN;
412+
MemN->PrevMemN = PrevMemN;
413+
} else {
414+
// The terminator is outside the DAG interval so do nothing.
415+
}
416+
}
417+
}
418+
373419
void DependencyGraph::notifyEraseInstr(Instruction *I) {
374420
// Update the MemDGNode chain if this is a memory node.
375421
if (auto *MemN = dyn_cast_or_null<MemDGNode>(getNodeOrNull(I))) {

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
801801

802802
TEST_F(DependencyGraphTest, CreateInstrCallback) {
803803
parseIR(C, R"IR(
804-
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
804+
define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
805805
store i8 %v1, ptr %ptr
806806
store i8 %v2, ptr %ptr
807807
store i8 %v3, ptr %ptr
@@ -893,3 +893,36 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
893893

894894
// TODO: Check the dependencies to/from NewSN after they land.
895895
}
896+
897+
TEST_F(DependencyGraphTest, MoveInstrCallback) {
898+
parseIR(C, R"IR(
899+
define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
900+
%ld0 = load i8, ptr %ptr2
901+
store i8 %v1, ptr %ptr
902+
store i8 %v2, ptr %ptr
903+
store i8 %v3, ptr %ptr
904+
ret void
905+
}
906+
)IR");
907+
llvm::Function *LLVMF = &*M->getFunction("foo");
908+
sandboxir::Context Ctx(C);
909+
auto *F = Ctx.createFunction(LLVMF);
910+
auto *BB = &*F->begin();
911+
auto It = BB->begin();
912+
auto *Ld = cast<sandboxir::LoadInst>(&*It++);
913+
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
914+
auto *S2 = cast<sandboxir::StoreInst>(&*It++);
915+
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
916+
917+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
918+
DAG.extend({Ld, S3});
919+
auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
920+
auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
921+
auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
922+
EXPECT_EQ(S1N->getPrevNode(), LdN);
923+
S1->moveBefore(Ld);
924+
EXPECT_EQ(S1N->getPrevNode(), nullptr);
925+
EXPECT_EQ(S1N->getNextNode(), LdN);
926+
EXPECT_EQ(LdN->getPrevNode(), S1N);
927+
EXPECT_EQ(LdN->getNextNode(), S2N);
928+
}

0 commit comments

Comments
 (0)