Skip to content

Conversation

@vporpo
Copy link
Contributor

@vporpo vporpo commented Nov 20, 2024

The DAG maintains a chain of MemDGNodes that links together all the nodes that may touch memroy.
Whenever a new instruction gets created we need to make sure that this chain gets updated. If the new instruction touches memory then its corresponding MemDGNode should be inserted into the chain.

The DAG maintains a chain of MemDGNodes that links together all the
nodes that may touch memroy.
Whenever a new instruction gets created we need to make sure that this
chain gets updated. If the new instruction touches memory then its
corresponding MemDGNode should be inserted into the chain.
@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2024

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

The DAG maintains a chain of MemDGNodes that links together all the nodes that may touch memroy.
Whenever a new instruction gets created we need to make sure that this chain gets updated. If the new instruction touches memory then its corresponding MemDGNode should be inserted into the chain.


Full diff: https://github.com/llvm/llvm-project/pull/116896.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h (+12-6)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+45)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+33-8)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 68a2daca1403df..911ee3e839521c 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -329,13 +329,19 @@ class DependencyGraph {
   /// chain.
   void createNewNodes(const Interval<Instruction> &NewInterval);
 
+  /// Helper for `notify*Instr()`. \Returns the first MemDGNode that comes
+  /// before \p N, including or excluding \p N based on \p IncludingN, or
+  /// nullptr if not found.
+  MemDGNode *getMemDGNodeBefore(DGNode *N, bool IncludingN) const;
+  /// Helper for `notifyMoveInstr()`. \Returns the first MemDGNode that comes
+  /// after \p N, including or excluding \p N based on \p IncludingN, or nullptr
+  /// if not found.
+  MemDGNode *getMemDGNodeAfter(DGNode *N, bool IncludingN) const;
+
   /// Called by the callbacks when a new instruction \p I has been created.
-  void notifyCreateInstr(Instruction *I) {
-    getOrCreateNode(I);
-    // TODO: Update the dependencies for the new node.
-    // TODO: Update the MemDGNode chain to include the new node if needed.
-  }
-  /// Called by the callbacks when instruction \p I is about to get deleted.
+  void notifyCreateInstr(Instruction *I);
+  /// Called by the callbacks when instruction \p I is about to get
+  /// deleted.
   void notifyEraseInstr(Instruction *I) {
     InstrToNodeMap.erase(I);
     // TODO: Update the dependencies.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 4b0e12c28f07b7..5cf44ba9dcbaaa 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -325,6 +325,51 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
   setDefUseUnscheduledSuccs(NewInterval);
 }
 
+MemDGNode *DependencyGraph::getMemDGNodeBefore(DGNode *N,
+                                               bool IncludingN) const {
+  auto *I = N->getInstruction();
+  for (auto *PrevI = IncludingN ? I : I->getPrevNode(); PrevI != nullptr;
+       PrevI = PrevI->getPrevNode()) {
+    auto *PrevN = getNodeOrNull(PrevI);
+    if (PrevN == nullptr)
+      return nullptr;
+    if (auto *PrevMemN = dyn_cast<MemDGNode>(PrevN))
+      return PrevMemN;
+  }
+  return nullptr;
+}
+
+MemDGNode *DependencyGraph::getMemDGNodeAfter(DGNode *N,
+                                              bool IncludingN) const {
+  auto *I = N->getInstruction();
+  for (auto *NextI = IncludingN ? I : I->getNextNode(); NextI != nullptr;
+       NextI = NextI->getNextNode()) {
+    auto *NextN = getNodeOrNull(NextI);
+    if (NextN == nullptr)
+      return nullptr;
+    if (auto *NextMemN = dyn_cast<MemDGNode>(NextN))
+      return NextMemN;
+  }
+  return nullptr;
+}
+
+void DependencyGraph::notifyCreateInstr(Instruction *I) {
+  auto *MemN = dyn_cast<MemDGNode>(getOrCreateNode(I));
+  // TODO: Update the dependencies for the new node.
+
+  // Update the MemDGNode chain if this is a memory node.
+  if (MemN != nullptr) {
+    if (auto *PrevMemN = getMemDGNodeBefore(MemN, /*IncludingN=*/false)) {
+      PrevMemN->NextMemN = MemN;
+      MemN->PrevMemN = PrevMemN;
+    }
+    if (auto *NextMemN = getMemDGNodeAfter(MemN, /*IncludingN=*/false)) {
+      NextMemN->PrevMemN = MemN;
+      MemN->NextMemN = NextMemN;
+    }
+  }
+}
+
 Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
   if (Instrs.empty())
     return {};
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e6bb4b4684d262..1130c9c63c71da 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -814,21 +814,46 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
   auto *BB = &*F->begin();
   auto It = BB->begin();
   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
-  [[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
   auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
   // Check new instruction callback.
   sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
-  DAG.extend({S1, S3});
+  DAG.extend({S1, Ret});
   auto *Arg = F->getArg(3);
   auto *Ptr = S1->getPointerOperand();
-  sandboxir::StoreInst *NewS =
-      sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
-                                   /*IsVolatile=*/true, Ctx);
-  auto *NewSN = DAG.getNode(NewS);
-  EXPECT_TRUE(NewSN != nullptr);
+  {
+    sandboxir::StoreInst *NewS =
+        sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
+                                     /*IsVolatile=*/true, Ctx);
+    auto *NewSN = DAG.getNode(NewS);
+    EXPECT_TRUE(NewSN != nullptr);
+
+    // Check the MemDGNode chain.
+    auto *S2MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
+    auto *NewMemSN = cast<sandboxir::MemDGNode>(NewSN);
+    auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    EXPECT_EQ(S2MemN->getNextNode(), NewMemSN);
+    EXPECT_EQ(NewMemSN->getPrevNode(), S2MemN);
+    EXPECT_EQ(NewMemSN->getNextNode(), S3MemN);
+    EXPECT_EQ(S3MemN->getPrevNode(), NewMemSN);
+  }
+
+  {
+    // Also check if new node is at the end of the BB, after Ret.
+    sandboxir::StoreInst *NewS =
+        sandboxir::StoreInst::create(Arg, Ptr, Align(8), BB->end(),
+                                     /*IsVolatile=*/true, Ctx);
+    // Check the MemDGNode chain.
+    auto *S3MemN = cast<sandboxir::MemDGNode>(DAG.getNode(S3));
+    auto *NewMemSN = cast<sandboxir::MemDGNode>(DAG.getNode(NewS));
+    EXPECT_EQ(S3MemN->getNextNode(), NewMemSN);
+    EXPECT_EQ(NewMemSN->getPrevNode(), S3MemN);
+    EXPECT_EQ(NewMemSN->getNextNode(), nullptr);
+  }
+
   // TODO: Check the dependencies to/from NewSN after they land.
-  // TODO: Check the MemDGNode chain.
 }
 
 TEST_F(DependencyGraphTest, EraseInstrCallback) {

@vporpo vporpo merged commit eeb55d3 into llvm:main Dec 6, 2024
9 of 11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants