From 26a263ae6232db7cf1592f51cec17cdfef7cd344 Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Thu, 7 Nov 2024 12:46:53 -0800 Subject: [PATCH] [SandboxVec][DAG] Register move instr callback This patch implements the move instruction notifier for the DAG. Whenever an instruction moves the notifier will maintain the DAG. --- .../SandboxVectorizer/DependencyGraph.h | 18 ++++++++ .../SandboxVectorizer/DependencyGraph.cpp | 46 +++++++++++++++++++ .../SandboxVectorizer/DependencyGraphTest.cpp | 35 +++++++++++++- 3 files changed, 98 insertions(+), 1 deletion(-) diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index b1cad2421bc0d..f423e1ee456cd 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -220,6 +220,14 @@ class MemDGNode final : public DGNode { void setNextNode(MemDGNode *N) { NextMemN = N; } void setPrevNode(MemDGNode *N) { PrevMemN = N; } friend class DependencyGraph; // For setNextNode(), setPrevNode(). + void detachFromChain() { + if (PrevMemN != nullptr) + PrevMemN->NextMemN = NextMemN; + if (NextMemN != nullptr) + NextMemN->PrevMemN = PrevMemN; + PrevMemN = nullptr; + NextMemN = nullptr; + } public: MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) { @@ -293,6 +301,7 @@ class DependencyGraph { Context *Ctx = nullptr; std::optional CreateInstrCB; std::optional EraseInstrCB; + std::optional MoveInstrCB; std::unique_ptr BatchAA; @@ -343,6 +352,9 @@ class DependencyGraph { /// Called by the callbacks when instruction \p I is about to get /// deleted. void notifyEraseInstr(Instruction *I); + /// Called by the callbacks when instruction \p I is about to be moved to + /// \p To. + void notifyMoveInstr(Instruction *I, const BBIterator &To); public: /// This constructor also registers callbacks. @@ -352,12 +364,18 @@ class DependencyGraph { [this](Instruction *I) { notifyCreateInstr(I); }); EraseInstrCB = Ctx.registerEraseInstrCallback( [this](Instruction *I) { notifyEraseInstr(I); }); + MoveInstrCB = Ctx.registerMoveInstrCallback( + [this](Instruction *I, const BBIterator &To) { + notifyMoveInstr(I, To); + }); } ~DependencyGraph() { if (CreateInstrCB) Ctx->unregisterCreateInstrCallback(*CreateInstrCB); if (EraseInstrCB) Ctx->unregisterEraseInstrCallback(*EraseInstrCB); + if (MoveInstrCB) + Ctx->unregisterMoveInstrCallback(*MoveInstrCB); } DGNode *getNode(Instruction *I) const { diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index 25f2665d450d1..ba62c45a4e704 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -370,6 +370,52 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) { } } +void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) { + // Early return if `I` doesn't actually move. + BasicBlock *BB = To.getNodeParent(); + if (To != BB->end() && &*To == I->getNextNode()) + return; + + // Maintain the DAGInterval. + DAGInterval.notifyMoveInstr(I, To); + + // TODO: Perhaps check if this is legal by checking the dependencies? + + // Update the MemDGNode chain to reflect the instr movement if necessary. + DGNode *N = getNodeOrNull(I); + if (N == nullptr) + return; + MemDGNode *MemN = dyn_cast(N); + if (MemN == nullptr) + return; + // First detach it from the existing chain. + MemN->detachFromChain(); + // Now insert it back into the chain at the new location. + if (To != BB->end()) { + DGNode *ToN = getNodeOrNull(&*To); + if (ToN != nullptr) { + MemDGNode *PrevMemN = getMemDGNodeBefore(ToN, /*IncludingN=*/false); + MemDGNode *NextMemN = getMemDGNodeAfter(ToN, /*IncludingN=*/true); + MemN->PrevMemN = PrevMemN; + if (PrevMemN != nullptr) + PrevMemN->NextMemN = MemN; + MemN->NextMemN = NextMemN; + if (NextMemN != nullptr) + NextMemN->PrevMemN = MemN; + } + } else { + // MemN becomes the last instruction in the BB. + auto *TermN = getNodeOrNull(BB->getTerminator()); + if (TermN != nullptr) { + MemDGNode *PrevMemN = getMemDGNodeBefore(TermN, /*IncludingN=*/false); + PrevMemN->NextMemN = MemN; + MemN->PrevMemN = PrevMemN; + } else { + // The terminator is outside the DAG interval so do nothing. + } + } +} + void DependencyGraph::notifyEraseInstr(Instruction *I) { // Update the MemDGNode chain if this is a memory node. if (auto *MemN = dyn_cast_or_null(getNodeOrNull(I))) { diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index 8c73ee1def8ae..3fa4de501f3f5 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -801,7 +801,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) { TEST_F(DependencyGraphTest, CreateInstrCallback) { parseIR(C, R"IR( -define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { +define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { store i8 %v1, ptr %ptr store i8 %v2, ptr %ptr store i8 %v3, ptr %ptr @@ -893,3 +893,36 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { // TODO: Check the dependencies to/from NewSN after they land. } + +TEST_F(DependencyGraphTest, MoveInstrCallback) { + parseIR(C, R"IR( +define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) { + %ld0 = load i8, ptr %ptr2 + store i8 %v1, ptr %ptr + store i8 %v2, ptr %ptr + store i8 %v3, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + auto It = BB->begin(); + auto *Ld = cast(&*It++); + auto *S1 = cast(&*It++); + auto *S2 = cast(&*It++); + auto *S3 = cast(&*It++); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx); + DAG.extend({Ld, S3}); + auto *LdN = cast(DAG.getNode(Ld)); + auto *S1N = cast(DAG.getNode(S1)); + auto *S2N = cast(DAG.getNode(S2)); + EXPECT_EQ(S1N->getPrevNode(), LdN); + S1->moveBefore(Ld); + EXPECT_EQ(S1N->getPrevNode(), nullptr); + EXPECT_EQ(S1N->getNextNode(), LdN); + EXPECT_EQ(LdN->getPrevNode(), S1N); + EXPECT_EQ(LdN->getNextNode(), S2N); +}