Skip to content

Conversation

@vporpo
Copy link
Contributor

@vporpo vporpo commented Dec 16, 2024

This patch implements the move instruction notifier for the DAG. Whenever an instruction moves the notifier will maintain the DAG.

This patch implements the move instruction notifier for the DAG.
Whenever an instruction moves the notifier will maintain the DAG.
@llvmbot
Copy link
Member

llvmbot commented Dec 16, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-vectorizers

Author: vporpo (vporpo)

Changes

This patch implements the move instruction notifier for the DAG. Whenever an instruction moves the notifier will maintain the DAG.


Full diff: https://github.com/llvm/llvm-project/pull/120146.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+18)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+46)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+34-1)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index b1cad2421bc0d2..f423e1ee456cd1 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -220,6 +220,14 @@ class MemDGNode final : public DGNode {
   void setNextNode(MemDGNode *N) { NextMemN = N; }
   void setPrevNode(MemDGNode *N) { PrevMemN = N; }
   friend class DependencyGraph; // For setNextNode(), setPrevNode().
+  void detachFromChain() {
+    if (PrevMemN != nullptr)
+      PrevMemN->NextMemN = NextMemN;
+    if (NextMemN != nullptr)
+      NextMemN->PrevMemN = PrevMemN;
+    PrevMemN = nullptr;
+    NextMemN = nullptr;
+  }
 
 public:
   MemDGNode(Instruction *I) : DGNode(I, DGNodeID::MemDGNode) {
@@ -293,6 +301,7 @@ class DependencyGraph {
   Context *Ctx = nullptr;
   std::optional<Context::CallbackID> CreateInstrCB;
   std::optional<Context::CallbackID> EraseInstrCB;
+  std::optional<Context::CallbackID> MoveInstrCB;
 
   std::unique_ptr<BatchAAResults> BatchAA;
 
@@ -343,6 +352,9 @@ class DependencyGraph {
   /// Called by the callbacks when instruction \p I is about to get
   /// deleted.
   void notifyEraseInstr(Instruction *I);
+  /// Called by the callbacks when instruction \p I is about to be moved to
+  /// \p To.
+  void notifyMoveInstr(Instruction *I, const BBIterator &To);
 
 public:
   /// This constructor also registers callbacks.
@@ -352,12 +364,18 @@ class DependencyGraph {
         [this](Instruction *I) { notifyCreateInstr(I); });
     EraseInstrCB = Ctx.registerEraseInstrCallback(
         [this](Instruction *I) { notifyEraseInstr(I); });
+    MoveInstrCB = Ctx.registerMoveInstrCallback(
+        [this](Instruction *I, const BBIterator &To) {
+          notifyMoveInstr(I, To);
+        });
   }
   ~DependencyGraph() {
     if (CreateInstrCB)
       Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
     if (EraseInstrCB)
       Ctx->unregisterEraseInstrCallback(*EraseInstrCB);
+    if (MoveInstrCB)
+      Ctx->unregisterMoveInstrCallback(*MoveInstrCB);
   }
 
   DGNode *getNode(Instruction *I) const {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 25f2665d450d13..ba62c45a4e704e 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -370,6 +370,52 @@ void DependencyGraph::notifyCreateInstr(Instruction *I) {
   }
 }
 
+void DependencyGraph::notifyMoveInstr(Instruction *I, const BBIterator &To) {
+  // Early return if `I` doesn't actually move.
+  BasicBlock *BB = To.getNodeParent();
+  if (To != BB->end() && &*To == I->getNextNode())
+    return;
+
+  // Maintain the DAGInterval.
+  DAGInterval.notifyMoveInstr(I, To);
+
+  // TODO: Perhaps check if this is legal by checking the dependencies?
+
+  // Update the MemDGNode chain to reflect the instr movement if necessary.
+  DGNode *N = getNodeOrNull(I);
+  if (N == nullptr)
+    return;
+  MemDGNode *MemN = dyn_cast<MemDGNode>(N);
+  if (MemN == nullptr)
+    return;
+  // First detach it from the existing chain.
+  MemN->detachFromChain();
+  // Now insert it back into the chain at the new location.
+  if (To != BB->end()) {
+    DGNode *ToN = getNodeOrNull(&*To);
+    if (ToN != nullptr) {
+      MemDGNode *PrevMemN = getMemDGNodeBefore(ToN, /*IncludingN=*/false);
+      MemDGNode *NextMemN = getMemDGNodeAfter(ToN, /*IncludingN=*/true);
+      MemN->PrevMemN = PrevMemN;
+      if (PrevMemN != nullptr)
+        PrevMemN->NextMemN = MemN;
+      MemN->NextMemN = NextMemN;
+      if (NextMemN != nullptr)
+        NextMemN->PrevMemN = MemN;
+    }
+  } else {
+    // MemN becomes the last instruction in the BB.
+    auto *TermN = getNodeOrNull(BB->getTerminator());
+    if (TermN != nullptr) {
+      MemDGNode *PrevMemN = getMemDGNodeBefore(TermN, /*IncludingN=*/false);
+      PrevMemN->NextMemN = MemN;
+      MemN->PrevMemN = PrevMemN;
+    } else {
+      // The terminator is outside the DAG interval so do nothing.
+    }
+  }
+}
+
 void DependencyGraph::notifyEraseInstr(Instruction *I) {
   // Update the MemDGNode chain if this is a memory node.
   if (auto *MemN = dyn_cast_or_null<MemDGNode>(getNodeOrNull(I))) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index 8c73ee1def8ae1..3fa4de501f3f5d 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -801,7 +801,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
 
 TEST_F(DependencyGraphTest, CreateInstrCallback) {
   parseIR(C, R"IR(
-define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+define void @foo(ptr %ptr, ptr noalias %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
   store i8 %v1, ptr %ptr
   store i8 %v2, ptr %ptr
   store i8 %v3, ptr %ptr
@@ -893,3 +893,36 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
 
   // TODO: Check the dependencies to/from NewSN after they land.
 }
+
+TEST_F(DependencyGraphTest, MoveInstrCallback) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, ptr %ptr2, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
+  %ld0 = load i8, ptr %ptr2
+  store i8 %v1, ptr %ptr
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *Ld = cast<sandboxir::LoadInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  DAG.extend({Ld, S3});
+  auto *LdN = cast<sandboxir::MemDGNode>(DAG.getNode(Ld));
+  auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
+  auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+  EXPECT_EQ(S1N->getPrevNode(), LdN);
+  S1->moveBefore(Ld);
+  EXPECT_EQ(S1N->getPrevNode(), nullptr);
+  EXPECT_EQ(S1N->getNextNode(), LdN);
+  EXPECT_EQ(LdN->getPrevNode(), S1N);
+  EXPECT_EQ(LdN->getNextNode(), S2N);
+}

Copy link
Collaborator

@slackito slackito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me.

@vporpo vporpo merged commit 7a38445 into llvm:main Dec 21, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants