joker-eph
Collaborator

This is still somewhat of a WIP; we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and better aligned with their definitions:

  • A RegionBranchPoint is either the parent (RegionBranchOpInterface) op or a RegionBranchTerminatorOpInterface operation in a nested region.
  • A RegionSuccessor is either one of the nested regions or the parent RegionBranchOpInterface op.
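
To make the two definitions above concrete, here is a minimal standalone sketch. It does not use MLIR: `ToyOp`, `ToyRegion`, `ToyBranchPoint`, `ToySuccessor` and `describe` are hypothetical names invented for illustration, and the sketch only mirrors the idea that a branch point is "parent or terminator" while a successor is "region or parent".

```cpp
#include <cassert>
#include <string>

// Toy stand-ins for IR objects. Hypothetical types, NOT the MLIR classes.
struct ToyOp {
  std::string name;
};
struct ToyRegion {
  unsigned number;
};

// Where a branch starts: the parent op, or a terminator nested in a region.
class ToyBranchPoint {
public:
  // Branching from the parent op is encoded as a null terminator.
  static ToyBranchPoint parent() { return ToyBranchPoint(); }

  explicit ToyBranchPoint(const ToyOp *terminator) : terminator(terminator) {
    assert(terminator && "terminator must not be null");
  }

  bool isParent() const { return terminator == nullptr; }
  const ToyOp *getTerminatorOrNull() const { return terminator; }

private:
  ToyBranchPoint() = default;
  const ToyOp *terminator = nullptr;
};

// Where a branch lands: one of the nested regions, or the parent op.
class ToySuccessor {
public:
  static ToySuccessor toRegion(const ToyRegion *region) {
    assert(region && "region must not be null");
    return ToySuccessor(region, nullptr);
  }
  static ToySuccessor toParent(const ToyOp *parentOp) {
    assert(parentOp && "parent op must not be null");
    return ToySuccessor(nullptr, parentOp);
  }

  bool isParent() const { return region == nullptr; }
  const ToyRegion *getRegionOrNull() const { return region; }
  const ToyOp *getParentOpOrNull() const { return parentOp; }

private:
  ToySuccessor(const ToyRegion *region, const ToyOp *parentOp)
      : region(region), parentOp(parentOp) {}
  const ToyRegion *region;
  const ToyOp *parentOp;
};

// Render one control-flow edge as text, e.g. "yield -> parent".
std::string describe(ToyBranchPoint from, ToySuccessor to) {
  std::string src =
      from.isParent() ? "parent" : from.getTerminatorOrNull()->name;
  std::string dst =
      to.isParent() ? "parent"
                    : "region #" + std::to_string(to.getRegionOrNull()->number);
  return src + " -> " + dst;
}
```

In this toy model the two kinds of objects are distinct types, so a "from" cannot be passed where a "to" is expected; that separation is the property the description above is aiming for.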

Some new methods with reasonable default implementations are added to help resolve the flow of values across the RegionBranchOpInterface.
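
As a rough illustration of what such helpers compute, the sketch below models a region-branching op as a flat list of edges and answers two questions: which branch points can reach a given successor, and which values they forward for a given input index. This is a simplified standalone model, not the MLIR implementation; `ToyEdge`, `kParent`, `makeToyLoop` and the free functions are hypothetical, and values are plain strings.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Marker for "the parent op", used on either end of an edge (assumption of
// this toy model; the real interface uses dedicated classes).
constexpr int kParent = -1;

// One control-flow edge of a toy region-branching op.
struct ToyEdge {
  int from;                          // kParent, or the id of a terminator
  int to;                            // kParent, or the index of a region
  std::vector<std::string> operands; // values forwarded along this edge
};

// Collect every branch point that may branch to `successor`.
std::vector<int> getPredecessors(const std::vector<ToyEdge> &edges,
                                 int successor) {
  std::vector<int> predecessors;
  for (const ToyEdge &edge : edges)
    if (edge.to == successor)
      predecessors.push_back(edge.from);
  return predecessors;
}

// For input number `index` of `successor`, collect the value that each
// predecessor forwards to it.
std::vector<std::string>
getPredecessorValues(const std::vector<ToyEdge> &edges, int successor,
                     std::size_t index) {
  std::vector<std::string> values;
  for (const ToyEdge &edge : edges)
    if (edge.to == successor)
      values.push_back(edge.operands.at(index));
  return values;
}

// A loop with one iteration argument: %b is the initial value and %c is the
// value yielded by the body terminator (terminator id 0, region index 0).
std::vector<ToyEdge> makeToyLoop() {
  return {
      {kParent, 0, {"%b"}},       // entering the body
      {kParent, kParent, {"%b"}}, // zero iterations: skip the body
      {0, 0, {"%c"}},             // yield branches back to the body
      {0, kParent, {"%c"}},       // yield exits the loop
  };
}
```

With this model, `getPredecessorValues(makeToyLoop(), 0, 0)` collects `%b` and `%c`, the two values that can flow into the body's iteration argument.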

It is still not trivial in the current state to walk the def-use chain backward with this interface. For example, given the 3rd block argument in the entry block of a for-loop, finding the matching operands requires knowing about the hidden loop iterator block argument and where the iter_args start. Unfortunately, the API is designed around forward tracking of the chain.
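
The index bookkeeping behind that example can be sketched as follows. This is a standalone illustration that assumes an `scf.for`-like layout: the op's operands are (lower bound, upper bound, step, init values...) and the entry block arguments are (induction variable, iteration arguments...). `ToyForLayout`, `ToySources` and `sourcesOfBlockArgument` are hypothetical names and not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>

// Assumed layout of a for-like loop (an assumption of this sketch).
struct ToyForLayout {
  std::size_t numControlOperands = 3; // lower bound, upper bound, step
  std::size_t numInductionVars = 1;   // the "hidden" leading block argument
  std::size_t numIterArgs = 0;        // number of loop-carried values
};

// Operand positions that feed a given iteration block argument.
struct ToySources {
  std::size_t initOperandIndex;  // position among the loop op's operands
  std::size_t yieldOperandIndex; // position among the terminator's operands
};

// Map an entry-block argument index back to the operands that define it.
// Returns nullopt for the induction variable and for out-of-range indices,
// since neither is fed by a forwarded operand in this model.
std::optional<ToySources> sourcesOfBlockArgument(const ToyForLayout &layout,
                                                 std::size_t blockArgIndex) {
  if (blockArgIndex < layout.numInductionVars)
    return std::nullopt;
  std::size_t iterIndex = blockArgIndex - layout.numInductionVars;
  if (iterIndex >= layout.numIterArgs)
    return std::nullopt;
  return ToySources{layout.numControlOperands + iterIndex, iterIndex};
}
```

For a loop with two iteration arguments, the 3rd block argument (index 2) maps to loop operand 4 and yield operand 1. The caller has to know both offsets up front, which is the op-specific knowledge the paragraph above describes.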

@llvmbot
Member

llvmbot commented Oct 1, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-emitc
@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir-gpu

Author: Mehdi Amini (joker-eph)

Patch is 197.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/161575.diff

38 Files Affected:

  • (modified) mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h (+3-3)
  • (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+1-1)
  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+7)
  • (modified) mlir/include/mlir/IR/Diagnostics.h (+2)
  • (modified) mlir/include/mlir/IR/Operation.h (+1)
  • (modified) mlir/include/mlir/IR/Region.h (+2)
  • (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+66-33)
  • (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+107-17)
  • (modified) mlir/include/mlir/TableGen/Interfaces.h (+6-1)
  • (modified) mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp (+222-103)
  • (modified) mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp (+6-7)
  • (modified) mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp (+3-1)
  • (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+3-3)
  • (modified) mlir/lib/Analysis/SliceWalk.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+28-21)
  • (modified) mlir/lib/Dialect/Async/IR/Async.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp (+7-4)
  • (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+5-3)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+1-1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+28-22)
  • (modified) mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp (-1)
  • (modified) mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp (-1)
  • (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+19-15)
  • (modified) mlir/lib/IR/Diagnostics.cpp (+4)
  • (modified) mlir/lib/IR/Region.cpp (+15)
  • (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+217-88)
  • (modified) mlir/lib/TableGen/Interfaces.cpp (+14-3)
  • (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+18-7)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (-1958)
  • (modified) mlir/test/Dialect/SCF/invalid.mlir (+4-4)
  • (modified) mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp (+71-7)
  • (modified) mlir/test/lib/Analysis/TestAliasAnalysis.cpp (+8-1)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+13-12)
  • (modified) mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (+19-15)
  • (modified) mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp (+23-7)
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 8bcfe51ad7cd1..3c87c453a4cf0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   /// itself.
   virtual void visitRegionBranchControlFlowTransfer(
       RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
-      RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+      RegionSuccessor regionTo, const AbstractDenseLattice &after,
       AbstractDenseLattice *before) {
     meet(before, after);
   }
@@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis
   /// and "to" regions.
   virtual void visitRegionBranchControlFlowTransfer(
       RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
-      RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
+      RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) {
     AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
         branch, regionFrom, regionTo, after, before);
   }
@@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis
   }
   void visitRegionBranchControlFlowTransfer(
       RegionBranchOpInterface branch, RegionBranchPoint regionForm,
-      RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+      RegionSuccessor regionTo, const AbstractDenseLattice &after,
       AbstractDenseLattice *before) final {
     visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
                                          static_cast<const LatticeT &>(after),
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 3f8874d02afad..72f717e163fb6 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
   /// and propagating therefrom.
   virtual void
   visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
-                        RegionBranchPoint successor,
+                        RegionSuccessor successor,
                         ArrayRef<AbstractSparseLattice *> lattices);
 };
 
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index fadd3fc10bfc4..48690151caf01 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -644,6 +644,13 @@ def ForallOp : SCF_Op<"forall", [
 
     /// Returns true if the mapping specified for this forall op is linear.
     bool usesLinearMapping();
+
+    /// RegionBranchOpInterface
+
+    OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
+      return getInits();
+    }
+
   }];
 }
 
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 7ff718ad7f241..a0a99f4953822 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -29,6 +29,7 @@ class MLIRContext;
 class Operation;
 class OperationName;
 class OpPrintingFlags;
+class OpWithFlags;
 class Type;
 class Value;
 
@@ -199,6 +200,7 @@ class Diagnostic {
 
   /// Stream in an Operation.
   Diagnostic &operator<<(Operation &op);
+  Diagnostic &operator<<(OpWithFlags op);
   Diagnostic &operator<<(Operation *op) { return *this << *op; }
   /// Append an operation with the given printing flags.
   Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 5569392cf0b41..b2019574a820d 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -1114,6 +1114,7 @@ class OpWithFlags {
       : op(op), theFlags(flags) {}
   OpPrintingFlags &flags() { return theFlags; }
   const OpPrintingFlags &flags() const { return theFlags; }
+  Operation *getOperation() const { return op; }
 
 private:
   Operation *op;
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 1fcb316750230..53d461df98710 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -379,6 +379,8 @@ class RegionRange
   friend RangeBaseT;
 };
 
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region &region);
+
 } // namespace mlir
 
 #endif // MLIR_IR_REGION_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index d63800c12d132..c8304829a4df8 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -15,10 +15,15 @@
 #define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 class BranchOpInterface;
 class RegionBranchOpInterface;
+class RegionBranchTerminatorOpInterface;
 
 /// This class models how operands are forwarded to block arguments in control
 /// flow. It consists of a number, denoting how many of the successors block
@@ -186,27 +191,46 @@ class RegionSuccessor {
 public:
   /// Initialize a successor that branches to another region of the parent
   /// operation.
+  /// TODO: the default value for the regionInputs is somehow broken.
+  /// A region successor should have its input correctly set.
   RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
-      : region(region), inputs(regionInputs) {}
+      : region(region), inputs(regionInputs) {
+    assert(region && "Region must not be null");
+  }
   /// Initialize a successor that branches back to/out of the parent operation.
-  RegionSuccessor(Operation::result_range results)
-      : inputs(ValueRange(results)) {}
-  /// Constructor with no arguments.
-  RegionSuccessor() : inputs(ValueRange()) {}
+  /// The target must be one of the recursive parent operations.
+  RegionSuccessor(Operation *successorOop, Operation::result_range results)
+      : successorOp(successorOop), inputs(ValueRange(results)) {
+    assert(successorOop && "Successor op must not be null");
+  }
 
   /// Return the given region successor. Returns nullptr if the successor is the
   /// parent operation.
   Region *getSuccessor() const { return region; }
 
   /// Return true if the successor is the parent operation.
-  bool isParent() const { return region == nullptr; }
+  bool isParent() const {
+    assert((region != nullptr || successorOp != nullptr) &&
+           "Region and successor op must not be null");
+    return region == nullptr;
+  }
 
   /// Return the inputs to the successor that are remapped by the exit values of
   /// the current region.
   ValueRange getSuccessorInputs() const { return inputs; }
 
+  bool operator==(RegionSuccessor rhs) const {
+    return region == rhs.region && successorOp == rhs.successorOp &&
+           inputs == rhs.inputs;
+  }
+
+  friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
+    return !(lhs == rhs);
+  }
+
 private:
   Region *region{nullptr};
+  Operation *successorOp{nullptr};
   ValueRange inputs;
 };
 
@@ -214,7 +238,8 @@ class RegionSuccessor {
 /// `RegionBranchOpInterface`.
 /// One can branch from one of two kinds of places:
 /// * The parent operation (aka the `RegionBranchOpInterface` implementation)
-/// * A region within the parent operation.
+/// * A RegionBranchTerminatorOpInterface inside a region within the parent
+//    operation.
 class RegionBranchPoint {
 public:
   /// Returns an instance of `RegionBranchPoint` representing the parent
@@ -223,55 +248,57 @@ class RegionBranchPoint {
 
   /// Creates a `RegionBranchPoint` that branches from the given region.
   /// The pointer must not be null.
-  RegionBranchPoint(Region *region) : maybeRegion(region) {
-    assert(region && "Region must not be null");
-  }
-
-  RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
+  inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);
 
   /// Explicitly stops users from constructing with `nullptr`.
   RegionBranchPoint(std::nullptr_t) = delete;
 
-  /// Constructs a `RegionBranchPoint` from the the target of a
-  /// `RegionSuccessor` instance.
-  RegionBranchPoint(RegionSuccessor successor) {
-    if (successor.isParent())
-      maybeRegion = nullptr;
-    else
-      maybeRegion = successor.getSuccessor();
-  }
-
-  /// Assigns a region being branched from.
-  RegionBranchPoint &operator=(Region &region) {
-    maybeRegion = &region;
-    return *this;
-  }
-
   /// Returns true if branching from the parent op.
-  bool isParent() const { return maybeRegion == nullptr; }
+  bool isParent() const { return predecessor == nullptr; }
 
   /// Returns the region if branching from a region.
   /// A null pointer otherwise.
-  Region *getRegionOrNull() const { return maybeRegion; }
+  Operation *getPredecessorOrNull() const { return predecessor; }
 
   /// Returns true if the two branch points are equal.
   friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
-    return lhs.maybeRegion == rhs.maybeRegion;
+    return lhs.predecessor == rhs.predecessor;
   }
 
 private:
   // Private constructor to encourage the use of `RegionBranchPoint::parent`.
-  constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
+  constexpr RegionBranchPoint() = default;
 
   /// Internal encoding. Uses nullptr for representing branching from the parent
-  /// op and the region being branched from otherwise.
-  Region *maybeRegion;
+  /// op and the region terminator being branched from otherwise.
+  Operation *predecessor = nullptr;
 };
 
 inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
   return !(lhs == rhs);
 }
 
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                     RegionBranchPoint point) {
+  if (point.isParent())
+    return os << "<from parent>";
+  return os
+         << "<region #"
+         << point.getPredecessorOrNull()->getParentRegion()->getRegionNumber()
+         << ", terminator "
+         << OpWithFlags(point.getPredecessorOrNull(),
+                        OpPrintingFlags().skipRegions())
+         << ">";
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                     RegionSuccessor successor) {
+  if (successor.isParent())
+    return os << "<to parent>";
+  return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
+            << " with " << successor.getSuccessorInputs().size() << " inputs>";
+}
+
 /// This class represents upper and lower bounds on the number of times a region
 /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
 /// zero, but the upper bound may not be known.
@@ -348,4 +375,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
 
+namespace mlir {
+inline RegionBranchPoint::RegionBranchPoint(
+    RegionBranchTerminatorOpInterface predecessor)
+    : predecessor(predecessor.getOperation()) {}
+} // namespace mlir
+
 #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index b8d08cc553caa..c9fe3354bc6a1 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -117,7 +117,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
 
 def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
   let description = [{
-    This interface provides information for region operations that exhibit
+    This interface provides information for region-holding operations that exhibit
     branching behavior between held regions. I.e., this interface allows for
     expressing control flow information for region holding operations.
 
@@ -126,12 +126,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     be side-effect free.
 
     A "region branch point" indicates a point from which a branch originates. It
-    can indicate either a region of this op or `RegionBranchPoint::parent()`. In
-    the latter case, the branch originates from outside of the op, i.e., when
-    first executing this op.
+    can indicate either a terminator in any of the recursively nested region of
+    this op or `RegionBranchPoint::parent()`. In the latter case, the branch
+    originates from outside of the op, i.e., when first executing this op.
 
     A "region successor" indicates the target of a branch. It can indicate
-    either a region of this op or this op. In the former case, the region
+    either a region of this op or this op itself. In the former case, the region
     successor is a region pointer and a range of block arguments to which the
     "successor operands" are forwarded to. In the latter case, the control flow
     leaves this op and the region successor is a range of results of this op to
@@ -151,10 +151,25 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     }
     ```
 
-    `scf.for` has one region. The region has two region successors: the region
-    itself and the `scf.for` op. %b is an entry successor operand. %c is a
-    successor operand. %a is a successor block argument. %r is a successor
-    result.
+    `scf.for` has one region. The `scf.yield` has two region successors: the
+    region body itself and the `scf.for` op. `%b` is an entry successor
+    operand. `%c` is a successor operand. `%a` is a successor block argument.
+    `%r` is a successor result.
+
+    
+    ```
+    %r = scf.loop iter_args(%a = %b)
+        -> tensor<5xf32> {
+      ...
+      scf.yield %c : tensor<5xf32>
+    }
+    ```
+
+    `scf.for` has one region. The `scf.yield` has two region successors: the
+    region body itself and the `scf.for` op. `%b` is an entry successor
+    operand. `%c` is a successor operand. `%a` is a successor block argument.
+    `%r` is a successor result.
+
   }];
   let cppNamespace = "::mlir";
 
@@ -162,16 +177,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     InterfaceMethod<[{
         Returns the operands of this operation that are forwarded to the region
         successor's block arguments or this operation's results when branching
-        to `point`. `point` is guaranteed to be among the successors that are
+        to `successor`. `successor` is guaranteed to be among the successors that are
         returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
 
         Example: In the above example, this method returns the operand %b of the
-        `scf.for` op, regardless of the value of `point`. I.e., this op always
+        `scf.for` op, regardless of the value of `successor`. I.e., this op always
         forwards the same operands, regardless of whether the loop has 0 or more
         iterations.
       }],
       "::mlir::OperandRange", "getEntrySuccessorOperands",
-      (ins "::mlir::RegionBranchPoint":$point), [{}],
+      (ins "::mlir::RegionSuccessor":$successor), [{}],
       /*defaultImplementation=*/[{
         auto operandEnd = this->getOperation()->operand_end();
         return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -224,6 +239,81 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       (ins "::mlir::RegionBranchPoint":$point,
            "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
     >,
+    InterfaceMethod<[{
+        Returns the potential region successors when branching from any
+        terminator in `region`.
+        These are the regions that may be selected during the flow of control.
+      }],
+      "void", "getSuccessorRegions",
+      (ins "::mlir::Region&":$region,
+           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+      [{}],
+      /*defaultImplementation=*/[{
+        for (::mlir::Block &block : region) {
+          if (block.empty())
+            continue;
+          if (auto terminator =
+                  dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+            $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+                                     regions);
+        }
+      }]>,
+    InterfaceMethod<[{
+        Returns the potential branching point (predecessors) for a given successor.
+      }],
+      "void", "getPredecessors",
+      (ins "::mlir::RegionSuccessor":$successor,
+           "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
+      [{}],
+      /*defaultImplementation=*/[{
+        ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+        $_op.getSuccessorRegions(RegionBranchPoint::parent(),
+                                 successors);
+        if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+            return succ.getSuccessor() == successor.getSuccessor() ||
+              (succ.isParent() && successor.isParent());
+          }))
+          predecessors.push_back(RegionBranchPoint::parent());
+        for (Region &region : $_op->getRegions()) {
+          for (::mlir::Block &block : region) {
+            if (block.empty())
+              continue;
+            if (auto terminator =
+                    dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
+              ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+              $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+                                       successors);
+              if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+                  return succ.getSuccessor() == successor.getSuccessor() ||
+                    (succ.isParent() && successor.isParent());
+                }))
+                predecessors.push_back(terminator);
+            }
+          }
+        }
+      }]>,
+    InterfaceMethod<[{
+        Returns the potential values across all (predecessors) for a given successor
+        input, modeled by its index.
+      }],
+      "void", "getPredecessorValues",
+      (ins "::mlir::RegionSuccessor":$successor,
+           "int":$index,
+           "::llvm::SmallVectorImpl<::mlir::Value> &":$predecessorValues),
+      [{}],
+      /*defaultImplementation=*/[{
+        ::llvm::SmallVector<::mlir::RegionBranchPoint> predecessors;
+        $_op.getPredecessors(successor, predecessors);
+        for (auto predecessor : predecessors) {
+          if (predecessor.isParent()) {
+            predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]);
+            continue;
+          }
+          auto terminator = cast<RegionBranchTerminatorOpInterface>(predecessor.getPredecessorOrNull());
+          predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]);
+
+        }
+      }]>,
     InterfaceMethod<[{
         Populates `invocationBounds` with the minimum and maximum number of
         times this operation will invoke the attached regions (assuming the
@@ -298,7 +388,7 @@ def RegionBranchTerminatorOpInterface :
         passing them to the region successor indicated by `point`.
       }],
       "::mlir::MutableOperandRange", "getMutableSuccessorOperands",
-      (ins "::mlir::RegionBranchPoint":$point)
+      (ins "::mlir::RegionSuccessor":$point)
     >,
     InterfaceMethod<[{
         Returns the potential region successors that are branched to after this
@@ -317,7 +407,7 @@ def RegionBranchTerminatorOpInterface :
       /*defaultImplementation=*/[{
         ::mlir::Operation *op = $_op;
         ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
-          .getSuccessorRegions(op->getParentRegion(), regions);
+          .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
       }]
     >,
   ];
@@ -337,8 +427,8 @@ def RegionBranchTerminatorOpInterface :
     // them to the region successor given by `index`.  If `index` is None, this
     // function returns the operands that are passed as a result to the parent
     // operation.
-    ::mlir::OperandRange getSucces...
[truncated]

@llvmbot
Member

llvmbot commented Oct 1, 2025

@llvm/pr-subscribers-mlir-shape

Author: Mehdi Amini (joker-eph)

@@ -15,10 +15,15 @@
 #define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 class BranchOpInterface;
 class RegionBranchOpInterface;
+class RegionBranchTerminatorOpInterface;
 
 /// This class models how operands are forwarded to block arguments in control
 /// flow. It consists of a number, denoting how many of the successors block
@@ -186,27 +191,46 @@ class RegionSuccessor {
 public:
   /// Initialize a successor that branches to another region of the parent
   /// operation.
+  /// TODO: the default value for the regionInputs is somehow broken.
+  /// A region successor should have its input correctly set.
   RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
-      : region(region), inputs(regionInputs) {}
+      : region(region), inputs(regionInputs) {
+    assert(region && "Region must not be null");
+  }
   /// Initialize a successor that branches back to/out of the parent operation.
-  RegionSuccessor(Operation::result_range results)
-      : inputs(ValueRange(results)) {}
-  /// Constructor with no arguments.
-  RegionSuccessor() : inputs(ValueRange()) {}
+  /// The target must be one of the recursive parent operations.
+  RegionSuccessor(Operation *successorOop, Operation::result_range results)
+      : successorOp(successorOop), inputs(ValueRange(results)) {
+    assert(successorOop && "Successor op must not be null");
+  }
 
   /// Return the given region successor. Returns nullptr if the successor is the
   /// parent operation.
   Region *getSuccessor() const { return region; }
 
   /// Return true if the successor is the parent operation.
-  bool isParent() const { return region == nullptr; }
+  bool isParent() const {
+    assert((region != nullptr || successorOp != nullptr) &&
+           "Region and successor op must not be null");
+    return region == nullptr;
+  }
 
   /// Return the inputs to the successor that are remapped by the exit values of
   /// the current region.
   ValueRange getSuccessorInputs() const { return inputs; }
 
+  bool operator==(RegionSuccessor rhs) const {
+    return region == rhs.region && successorOp == rhs.successorOp &&
+           inputs == rhs.inputs;
+  }
+
+  friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
+    return !(lhs == rhs);
+  }
+
 private:
   Region *region{nullptr};
+  Operation *successorOp{nullptr};
   ValueRange inputs;
 };
 
@@ -214,7 +238,8 @@ class RegionSuccessor {
 /// `RegionBranchOpInterface`.
 /// One can branch from one of two kinds of places:
 /// * The parent operation (aka the `RegionBranchOpInterface` implementation)
-/// * A region within the parent operation.
+/// * A RegionBranchTerminatorOpInterface inside a region within the parent
+//    operation.
 class RegionBranchPoint {
 public:
   /// Returns an instance of `RegionBranchPoint` representing the parent
@@ -223,55 +248,57 @@ class RegionBranchPoint {
 
   /// Creates a `RegionBranchPoint` that branches from the given region.
   /// The pointer must not be null.
-  RegionBranchPoint(Region *region) : maybeRegion(region) {
-    assert(region && "Region must not be null");
-  }
-
-  RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
+  inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);
 
   /// Explicitly stops users from constructing with `nullptr`.
   RegionBranchPoint(std::nullptr_t) = delete;
 
-  /// Constructs a `RegionBranchPoint` from the the target of a
-  /// `RegionSuccessor` instance.
-  RegionBranchPoint(RegionSuccessor successor) {
-    if (successor.isParent())
-      maybeRegion = nullptr;
-    else
-      maybeRegion = successor.getSuccessor();
-  }
-
-  /// Assigns a region being branched from.
-  RegionBranchPoint &operator=(Region &region) {
-    maybeRegion = &region;
-    return *this;
-  }
-
   /// Returns true if branching from the parent op.
-  bool isParent() const { return maybeRegion == nullptr; }
+  bool isParent() const { return predecessor == nullptr; }
 
   /// Returns the region if branching from a region.
   /// A null pointer otherwise.
-  Region *getRegionOrNull() const { return maybeRegion; }
+  Operation *getPredecessorOrNull() const { return predecessor; }
 
   /// Returns true if the two branch points are equal.
   friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
-    return lhs.maybeRegion == rhs.maybeRegion;
+    return lhs.predecessor == rhs.predecessor;
   }
 
 private:
   // Private constructor to encourage the use of `RegionBranchPoint::parent`.
-  constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
+  constexpr RegionBranchPoint() = default;
 
   /// Internal encoding. Uses nullptr for representing branching from the parent
-  /// op and the region being branched from otherwise.
-  Region *maybeRegion;
+  /// op and the region terminator being branched from otherwise.
+  Operation *predecessor = nullptr;
 };
 
 inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
   return !(lhs == rhs);
 }
 
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                     RegionBranchPoint point) {
+  if (point.isParent())
+    return os << "<from parent>";
+  return os
+         << "<region #"
+         << point.getPredecessorOrNull()->getParentRegion()->getRegionNumber()
+         << ", terminator "
+         << OpWithFlags(point.getPredecessorOrNull(),
+                        OpPrintingFlags().skipRegions())
+         << ">";
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                     RegionSuccessor successor) {
+  if (successor.isParent())
+    return os << "<to parent>";
+  return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
+            << " with " << successor.getSuccessorInputs().size() << " inputs>";
+}
+
 /// This class represents upper and lower bounds on the number of times a region
 /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
 /// zero, but the upper bound may not be known.
@@ -348,4 +375,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
 /// Include the generated interface declarations.
 #include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
 
+namespace mlir {
+inline RegionBranchPoint::RegionBranchPoint(
+    RegionBranchTerminatorOpInterface predecessor)
+    : predecessor(predecessor.getOperation()) {}
+} // namespace mlir
+
 #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index b8d08cc553caa..c9fe3354bc6a1 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -117,7 +117,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
 
 def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
   let description = [{
-    This interface provides information for region operations that exhibit
+    This interface provides information for region-holding operations that exhibit
     branching behavior between held regions. I.e., this interface allows for
     expressing control flow information for region holding operations.
 
@@ -126,12 +126,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     be side-effect free.
 
     A "region branch point" indicates a point from which a branch originates. It
-    can indicate either a region of this op or `RegionBranchPoint::parent()`. In
-    the latter case, the branch originates from outside of the op, i.e., when
-    first executing this op.
+    can indicate either a terminator in any of the recursively nested region of
+    this op or `RegionBranchPoint::parent()`. In the latter case, the branch
+    originates from outside of the op, i.e., when first executing this op.
 
     A "region successor" indicates the target of a branch. It can indicate
-    either a region of this op or this op. In the former case, the region
+    either a region of this op or this op itself. In the former case, the region
     successor is a region pointer and a range of block arguments to which the
     "successor operands" are forwarded to. In the latter case, the control flow
     leaves this op and the region successor is a range of results of this op to
@@ -151,10 +151,25 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     }
     ```
 
-    `scf.for` has one region. The region has two region successors: the region
-    itself and the `scf.for` op. %b is an entry successor operand. %c is a
-    successor operand. %a is a successor block argument. %r is a successor
-    result.
+    `scf.for` has one region. The `scf.yield` has two region successors: the
+    region body itself and the `scf.for` op. `%b` is an entry successor
+    operand. `%c` is a successor operand. `%a` is a successor block argument.
+    `%r` is a successor result.
+
+    
+    ```
+    %r = scf.loop iter_args(%a = %b)
+        -> tensor<5xf32> {
+      ...
+      scf.yield %c : tensor<5xf32>
+    }
+    ```
+
+    `scf.for` has one region. The `scf.yield` has two region successors: the
+    region body itself and the `scf.for` op. `%b` is an entry successor
+    operand. `%c` is a successor operand. `%a` is a successor block argument.
+    `%r` is a successor result.
+
   }];
   let cppNamespace = "::mlir";
 
@@ -162,16 +177,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     InterfaceMethod<[{
         Returns the operands of this operation that are forwarded to the region
         successor's block arguments or this operation's results when branching
-        to `point`. `point` is guaranteed to be among the successors that are
+        to `successor`. `successor` is guaranteed to be among the successors that are
         returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
 
         Example: In the above example, this method returns the operand %b of the
-        `scf.for` op, regardless of the value of `point`. I.e., this op always
+        `scf.for` op, regardless of the value of `successor`. I.e., this op always
         forwards the same operands, regardless of whether the loop has 0 or more
         iterations.
       }],
       "::mlir::OperandRange", "getEntrySuccessorOperands",
-      (ins "::mlir::RegionBranchPoint":$point), [{}],
+      (ins "::mlir::RegionSuccessor":$successor), [{}],
       /*defaultImplementation=*/[{
         auto operandEnd = this->getOperation()->operand_end();
         return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -224,6 +239,81 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       (ins "::mlir::RegionBranchPoint":$point,
            "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
     >,
+    InterfaceMethod<[{
+        Returns the potential region successors when branching from any
+        terminator in `region`.
+        These are the regions that may be selected during the flow of control.
+      }],
+      "void", "getSuccessorRegions",
+      (ins "::mlir::Region&":$region,
+           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+      [{}],
+      /*defaultImplementation=*/[{
+        for (::mlir::Block &block : region) {
+          if (block.empty())
+            continue;
+          if (auto terminator =
+                  dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+            $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+                                     regions);
+        }
+      }]>,
+    InterfaceMethod<[{
+        Returns the potential branching point (predecessors) for a given successor.
+      }],
+      "void", "getPredecessors",
+      (ins "::mlir::RegionSuccessor":$successor,
+           "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
+      [{}],
+      /*defaultImplementation=*/[{
+        ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+        $_op.getSuccessorRegions(RegionBranchPoint::parent(),
+                                 successors);
+        if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+            return succ.getSuccessor() == successor.getSuccessor() ||
+              (succ.isParent() && successor.isParent());
+          }))
+          predecessors.push_back(RegionBranchPoint::parent());
+        for (Region &region : $_op->getRegions()) {
+          for (::mlir::Block &block : region) {
+            if (block.empty())
+              continue;
+            if (auto terminator =
+                    dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
+              ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+              $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+                                       successors);
+              if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+                  return succ.getSuccessor() == successor.getSuccessor() ||
+                    (succ.isParent() && successor.isParent());
+                }))
+                predecessors.push_back(terminator);
+            }
+          }
+        }
+      }]>,
+    InterfaceMethod<[{
+        Returns the potential values across all (predecessors) for a given successor
+        input, modeled by its index.
+      }],
+      "void", "getPredecessorValues",
+      (ins "::mlir::RegionSuccessor":$successor,
+           "int":$index,
+           "::llvm::SmallVectorImpl<::mlir::Value> &":$predecessorValues),
+      [{}],
+      /*defaultImplementation=*/[{
+        ::llvm::SmallVector<::mlir::RegionBranchPoint> predecessors;
+        $_op.getPredecessors(successor, predecessors);
+        for (auto predecessor : predecessors) {
+          if (predecessor.isParent()) {
+            predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]);
+            continue;
+          }
+          auto terminator = cast<RegionBranchTerminatorOpInterface>(predecessor.getPredecessorOrNull());
+          predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]);
+
+        }
+      }]>,
     InterfaceMethod<[{
         Populates `invocationBounds` with the minimum and maximum number of
         times this operation will invoke the attached regions (assuming the
@@ -298,7 +388,7 @@ def RegionBranchTerminatorOpInterface :
         passing them to the region successor indicated by `point`.
       }],
       "::mlir::MutableOperandRange", "getMutableSuccessorOperands",
-      (ins "::mlir::RegionBranchPoint":$point)
+      (ins "::mlir::RegionSuccessor":$point)
     >,
     InterfaceMethod<[{
         Returns the potential region successors that are branched to after this
@@ -317,7 +407,7 @@ def RegionBranchTerminatorOpInterface :
       /*defaultImplementation=*/[{
         ::mlir::Operation *op = $_op;
         ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
-          .getSuccessorRegions(op->getParentRegion(), regions);
+          .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
       }]
     >,
   ];
@@ -337,8 +427,8 @@ def RegionBranchTerminatorOpInterface :
     // them to the region successor given by `index`.  If `index` is None, this
     // function returns the operands that are passed as a result to the parent
     // operation.
-    ::mlir::OperandRange getSucces...
[truncated]
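
For readers following the `getPredecessors` hunk above, here is a minimal standalone sketch of the same idea: invert the forward edges reported by `getSuccessorRegions` to find who can branch to a given successor. It does not use MLIR. `BranchPoint`, `Successor`, `SuccessorMap`, and `makeForLoopEdges` are hypothetical stand-ins, and the edge set assumes an `scf.for`-like loop that may run zero times.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy stand-ins for the MLIR concepts; none of these types exist in MLIR.
// A branch point is either the parent op or a named terminator.
// A successor is either the parent op or a named region.
using BranchPoint = std::string;
using Successor = std::string;
const std::string kFromParent = "<from parent>";
const std::string kToParent = "<to parent>";

// Forward edges, as getSuccessorRegions(point) would report them.
using SuccessorMap = std::map<BranchPoint, std::vector<Successor>>;

// Collects every branch point that may branch to `target`: the parent first
// (the entry edge), then each terminator. This mirrors the order of the
// default implementation shown in the diff.
inline std::vector<BranchPoint> getPredecessors(const SuccessorMap &edges,
                                                const Successor &target) {
  auto branchesToTarget = [&](const BranchPoint &point) {
    auto it = edges.find(point);
    if (it == edges.end())
      return false;
    const std::vector<Successor> &succs = it->second;
    return std::find(succs.begin(), succs.end(), target) != succs.end();
  };

  std::vector<BranchPoint> predecessors;
  if (branchesToTarget(kFromParent))
    predecessors.push_back(kFromParent);
  for (const auto &entry : edges)
    if (entry.first != kFromParent && branchesToTarget(entry.first))
      predecessors.push_back(entry.first);
  return predecessors;
}

// Edges of an scf.for-like loop (assumed shape): the op either enters the
// body or skips it, and the yield either loops back or exits to the op.
inline SuccessorMap makeForLoopEdges() {
  return {{kFromParent, {"body", kToParent}}, {"yield", {"body", kToParent}}};
}
```

With these edges, both the parent and the `yield` terminator are predecessors of the body, and both are also predecessors of the parent op itself. In the real interface the same question is answered by querying `getSuccessorRegions` for each branch point, as the hunk does.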

@llvmbot
Member

llvmbot commented Oct 1, 2025

@llvm/pr-subscribers-mlir-sparse

Author: Mehdi Amini (joker-eph)

Changes

This is still somewhat of a WIP: we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and better aligned with their definitions:

  • A RegionBranchPoint is either the parent (RegionBranchOpInterface) op or a RegionBranchTerminatorOpInterface operation in a nested region.
  • A RegionSuccessor is either one of the nested regions or the parent RegionBranchOpInterface op.

Some new methods with reasonable default implementations are added to help resolve the flow of values across the RegionBranchOpInterface.

In its current state, it is still not trivial to walk the def-use chain backward with this interface. For example, given the 3rd block argument in the entry block of a for-loop, finding the matching operands requires knowing about the hidden loop iterator block argument and where the iter_args start. Unfortunately, the API is designed around forward-tracking of the chain.
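
To illustrate the index bookkeeping described above, here is a minimal standalone sketch, not MLIR code. It assumes the usual `scf.for` layout: the operands are lower bound, upper bound, step, then the init values, and the entry block arguments are the induction variable followed by the iter_args. The names `BackwardMatch`, `matchBlockArgument`, and the two constants are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>

// Assumed layout of an scf.for-like loop (illustrative constants only):
//   op operands:      [lowerBound, upperBound, step, init_0, init_1, ...]
//   entry block args: [inductionVar, iter_arg_0, iter_arg_1, ...]
constexpr std::size_t kNumControlOperands = 3; // lower bound, upper bound, step
constexpr std::size_t kNumInductionVars = 1;   // the hidden loop iterator

// The two places a value can flow from into a given iter_arg.
struct BackwardMatch {
  std::size_t initOperandIndex;  // index into the loop op's operands
  std::size_t yieldOperandIndex; // index into the terminator's operands
};

// Maps an entry-block argument index to the operands that may flow into it.
// Returns std::nullopt for the induction variable, which no operand feeds.
inline std::optional<BackwardMatch>
matchBlockArgument(std::size_t blockArgIndex) {
  if (blockArgIndex < kNumInductionVars)
    return std::nullopt;
  std::size_t iterArgIndex = blockArgIndex - kNumInductionVars;
  return BackwardMatch{kNumControlOperands + iterArgIndex, iterArgIndex};
}
```

Under these assumptions, the 3rd block argument (index 2) is the second iter_arg, so it matches operand 4 of the loop op and operand 1 of the yield. A generic backward walk cannot know these offsets without op-specific knowledge, which is the limitation the description points out.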


-    first executing this op.
+    can indicate either a terminator in any of the recursively nested region of
+    this op or `RegionBranchPoint::parent()`. In the latter case, the branch
+    originates from outside of the op, i.e., when first executing this op.
 
     A "region successor" indicates the target of a branch. It can indicate
-    either a region of this op or this op. In the former case, the region
+    either a region of this op or this op itself. In the former case, the region
     successor is a region pointer and a range of block arguments to which the
     "successor operands" are forwarded to. In the latter case, the control flow
     leaves this op and the region successor is a range of results of this op to
@@ -151,10 +151,25 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     }
     ```
 
-    `scf.for` has one region. The region has two region successors: the region
-    itself and the `scf.for` op. %b is an entry successor operand. %c is a
-    successor operand. %a is a successor block argument. %r is a successor
-    result.
+    `scf.for` has one region. The `scf.yield` has two region successors: the
+    region body itself and the `scf.for` op. `%b` is an entry successor
+    operand. `%c` is a successor operand. `%a` is a successor block argument.
+    `%r` is a successor result.
+
+    
+    ```
+    %r = scf.loop iter_args(%a = %b)
+        -> tensor<5xf32> {
+      ...
+      scf.yield %c : tensor<5xf32>
+    }
+    ```
+
+    `scf.for` has one region. The `scf.yield` has two region successors: the
+    region body itself and the `scf.for` op. `%b` is an entry successor
+    operand. `%c` is a successor operand. `%a` is a successor block argument.
+    `%r` is a successor result.
+
   }];
   let cppNamespace = "::mlir";
 
@@ -162,16 +177,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     InterfaceMethod<[{
         Returns the operands of this operation that are forwarded to the region
         successor's block arguments or this operation's results when branching
-        to `point`. `point` is guaranteed to be among the successors that are
+        to `successor`. `successor` is guaranteed to be among the successors that are
         returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
 
         Example: In the above example, this method returns the operand %b of the
-        `scf.for` op, regardless of the value of `point`. I.e., this op always
+        `scf.for` op, regardless of the value of `successor`. I.e., this op always
         forwards the same operands, regardless of whether the loop has 0 or more
         iterations.
       }],
       "::mlir::OperandRange", "getEntrySuccessorOperands",
-      (ins "::mlir::RegionBranchPoint":$point), [{}],
+      (ins "::mlir::RegionSuccessor":$successor), [{}],
       /*defaultImplementation=*/[{
         auto operandEnd = this->getOperation()->operand_end();
         return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -224,6 +239,81 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       (ins "::mlir::RegionBranchPoint":$point,
            "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
     >,
+    InterfaceMethod<[{
+        Returns the potential region successors when branching from any
+        terminator in `region`.
+        These are the regions that may be selected during the flow of control.
+      }],
+      "void", "getSuccessorRegions",
+      (ins "::mlir::Region&":$region,
+           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+      [{}],
+      /*defaultImplementation=*/[{
+        for (::mlir::Block &block : region) {
+          if (block.empty())
+            continue;
+          if (auto terminator =
+                  dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+            $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+                                     regions);
+        }
+      }]>,
+    InterfaceMethod<[{
+        Returns the potential branching point (predecessors) for a given successor.
+      }],
+      "void", "getPredecessors",
+      (ins "::mlir::RegionSuccessor":$successor,
+           "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
+      [{}],
+      /*defaultImplementation=*/[{
+        ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+        $_op.getSuccessorRegions(RegionBranchPoint::parent(),
+                                 successors);
+        if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+            return succ.getSuccessor() == successor.getSuccessor() ||
+              (succ.isParent() && successor.isParent());
+          }))
+          predecessors.push_back(RegionBranchPoint::parent());
+        for (Region &region : $_op->getRegions()) {
+          for (::mlir::Block &block : region) {
+            if (block.empty())
+              continue;
+            if (auto terminator =
+                    dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
+              ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+              $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+                                       successors);
+              if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+                  return succ.getSuccessor() == successor.getSuccessor() ||
+                    (succ.isParent() && successor.isParent());
+                }))
+                predecessors.push_back(terminator);
+            }
+          }
+        }
+      }]>,
+    InterfaceMethod<[{
+        Returns the potential values across all (predecessors) for a given successor
+        input, modeled by its index.
+      }],
+      "void", "getPredecessorValues",
+      (ins "::mlir::RegionSuccessor":$successor,
+           "int":$index,
+           "::llvm::SmallVectorImpl<::mlir::Value> &":$predecessorValues),
+      [{}],
+      /*defaultImplementation=*/[{
+        ::llvm::SmallVector<::mlir::RegionBranchPoint> predecessors;
+        $_op.getPredecessors(successor, predecessors);
+        for (auto predecessor : predecessors) {
+          if (predecessor.isParent()) {
+            predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]);
+            continue;
+          }
+          auto terminator = cast<RegionBranchTerminatorOpInterface>(predecessor.getPredecessorOrNull());
+          predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]);
+
+        }
+      }]>,
     InterfaceMethod<[{
         Populates `invocationBounds` with the minimum and maximum number of
         times this operation will invoke the attached regions (assuming the
@@ -298,7 +388,7 @@ def RegionBranchTerminatorOpInterface :
         passing them to the region successor indicated by `point`.
       }],
       "::mlir::MutableOperandRange", "getMutableSuccessorOperands",
-      (ins "::mlir::RegionBranchPoint":$point)
+      (ins "::mlir::RegionSuccessor":$point)
     >,
     InterfaceMethod<[{
         Returns the potential region successors that are branched to after this
@@ -317,7 +407,7 @@ def RegionBranchTerminatorOpInterface :
       /*defaultImplementation=*/[{
         ::mlir::Operation *op = $_op;
         ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
-          .getSuccessorRegions(op->getParentRegion(), regions);
+          .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
       }]
     >,
   ];
@@ -337,8 +427,8 @@ def RegionBranchTerminatorOpInterface :
     // them to the region successor given by `index`.  If `index` is None, this
     // function returns the operands that are passed as a result to the parent
     // operation.
-    ::mlir::OperandRange getSucces...
[truncated]
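
For readers following the default `getPredecessors` implementation in the diff above, here is a minimal standalone sketch of the same idea. It does not use MLIR: `ToyTerminator`, `ToyBranchPoint`, `ToySuccessor`, `SuccessorFn`, and `forLikeSuccessors` are hypothetical stand-ins invented for illustration. Only two things are taken from the patch: a null predecessor pointer encodes "branching from the parent", and predecessors are found by asking every branch point for its successors and keeping those that reach the target.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-in for a region terminator; not an MLIR class.
struct ToyTerminator {
  int regionNumber;
};

// Mirrors the encoding in the patch: nullptr means "branching from the parent".
struct ToyBranchPoint {
  const ToyTerminator *predecessor = nullptr;
  bool isParent() const { return predecessor == nullptr; }
};

// Hypothetical successor: a region index, or -1 for the parent op.
struct ToySuccessor {
  int region = -1;
  bool isParent() const { return region < 0; }
  friend bool operator==(ToySuccessor a, ToySuccessor b) {
    return a.region == b.region;
  }
};

using SuccessorFn = std::function<std::vector<ToySuccessor>(ToyBranchPoint)>;

// Collects every branch point (the parent first, then each terminator) whose
// successor list contains `target`.
std::vector<ToyBranchPoint>
getPredecessors(const std::vector<ToyTerminator> &terminators,
                const SuccessorFn &getSuccessors, ToySuccessor target) {
  std::vector<ToyBranchPoint> result;
  auto reaches = [&](ToyBranchPoint point) {
    for (ToySuccessor s : getSuccessors(point))
      if (s == target)
        return true;
    return false;
  };
  if (reaches(ToyBranchPoint{}))
    result.push_back(ToyBranchPoint{});
  for (const ToyTerminator &t : terminators)
    if (reaches(ToyBranchPoint{&t}))
      result.push_back(ToyBranchPoint{&t});
  return result;
}

// Assumed loop-like behavior: from any branch point, control may enter
// region 0 or leave to the parent op.
std::vector<ToySuccessor> forLikeSuccessors(ToyBranchPoint) {
  return {ToySuccessor{0}, ToySuccessor{-1}};
}
```

With a single terminator in region 0, the loop body has two predecessors in this model: the parent (loop entry) and the terminator (back edge).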

can indicate either a region of this op or `RegionBranchPoint::parent()`. In
the latter case, the branch originates from outside of the op, i.e., when
first executing this op.
can indicate either a terminator in any of the recursively nested region of
Member

I found it interesting that, since you explicitly talk about terminators, there could be multiple terminators in case the region has multiple blocks. I'm wondering if the interface actually supports that. There may be some code that tries to find the region terminator op, in order to understand the data flow.

Collaborator Author

> There may be some code that tries to find the region terminator op, in order to understand the data flow.

Yes, there is. In theory it should collect all the terminators and work from those.
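
As a rough illustration of collecting all the terminators of a multi-block region, the sketch below follows the same shape as the default `getSuccessorRegions(Region&)` in the diff: visit each block, skip empty ones, and look at the last operation. It is a toy model, not MLIR code: `ToyOp`, `ToyBlock`, `ToyRegion`, the `isRegionBranchTerminator` flag, and `collectRegionBranchTerminators` are hypothetical names standing in for the real `dyn_cast<RegionBranchTerminatorOpInterface>` check.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins; a real implementation would use mlir::Operation,
// mlir::Block, and mlir::Region.
struct ToyOp {
  std::string name;
  // Stands in for "implements RegionBranchTerminatorOpInterface".
  bool isRegionBranchTerminator = false;
};
using ToyBlock = std::vector<ToyOp>;
using ToyRegion = std::vector<ToyBlock>;

// Returns the last op of every non-empty block that is a region branch
// terminator. Blocks ending in an ordinary branch are skipped.
std::vector<const ToyOp *>
collectRegionBranchTerminators(const ToyRegion &region) {
  std::vector<const ToyOp *> terminators;
  for (const ToyBlock &block : region) {
    if (block.empty())
      continue;
    const ToyOp &last = block.back();
    if (last.isRegionBranchTerminator)
      terminators.push_back(&last);
  }
  return terminators;
}
```

Code that assumes a single region terminator would see only one of these; iterating over the collected list is what lets a data-flow query account for every exit from the region.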

@@ -186,35 +191,55 @@ class RegionSuccessor {
public:
/// Initialize a successor that branches to another region of the parent
/// operation.
/// TODO: the default value for the regionInputs is somehow broken.
/// A region successor should have its input correctly set.
Member

Is it possible to just drop the default value? result_range is not optional either.

Collaborator Author

I made it a TODO because it is challenging right now.
There is code like the following:

      RegionSuccessor region(blockArg.getParentRegion());
      SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
          regionBranchOp, region, blockArg.getArgNumber());

Here we don't know which block args to use to initialize the RegionSuccessor.
I think this code is actually buggy, but fixing it requires more changes, and this patch was already growing quite large, so I punted to a follow-up with a TODO.

}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Returns the operands of this operation that are forwarded to the region
successor's block arguments or this operation's results when branching
to `point`. `point` is guaranteed to be among the successors that are
Member

I'm a bit confused about the old API here. RegionBranchPoint is the point from which the branch is originating. Yet, the old documentation here says that point is a successor.

Collaborator Author

Yes, that's part of what led me to work on this patch: I have looked at the history of how we got here, but the current state is quite messy.

Contributor

@fabianmcg fabianmcg left a comment

I only skimmed through it; I'll give it another look this weekend.

Comment on lines +647 to +653

/// RegionBranchOpInterface

OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
  return getInits();
}

Contributor

I would prefer removing the interface from forall until the op and its terminator are really compatible with the interface. But that can be in a different PR.

Collaborator Author

I'm still unsure about the fundamental compatibility of trying to map traditional control-flow over operations with parallel execution... But I would tend to keep this as a separate discussion indeed.

@joker-eph joker-eph force-pushed the regionbranchopiface branch 3 times, most recently from b89ed4c to c250b29 Compare October 6, 2025 19:30
@joker-eph joker-eph force-pushed the regionbranchopiface branch from c250b29 to 6a3c143 Compare October 13, 2025 21:18
@joker-eph
Collaborator Author

Any more thoughts on this?

@fabianmcg
Contributor

In general LGTM. There are still failures in flang. Also, a PSA in discourse would be nice.

This is still somewhat of a WIP; we have some issues with this interface that
are not trivial to solve. This patch tries to make the concepts of
RegionBranchPoint and RegionSuccessor more robust and aligned with their
definitions:
- A `RegionBranchPoint` is either the parent (`RegionBranchOpInterface`)
  op or a `RegionBranchTerminatorOpInterface` operation in a nested region.
- A `RegionSuccessor` is either one of the nested regions or the parent
  `RegionBranchOpInterface` op.

Some new methods with reasonable default implementations are added to help
resolve the flow of values across the RegionBranchOpInterface.

It is still not trivial in the current state to walk the def-use chain
backward with this interface. For example, when you have the 3rd block
argument in the entry block of a for-loop, finding the matching operands
requires knowing about the hidden loop iterator block argument and where
the iter_args start. The API is unfortunately designed around
forward-tracking of the chain.
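
The index bookkeeping behind that backward walk can be sketched without MLIR. The model below is hypothetical: `ToyForLoop`, `numInductionVars`, and `predecessorValuesForBlockArg` are names invented here, values are plain strings, and the assumption of one hidden induction-variable argument ahead of the iter_args is hard-coded rather than queried from any interface. It only shows why a caller has to subtract the hidden argument count before indexing into the entry operands and the yielded operands.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy model of a for-loop op; not an MLIR class.
struct ToyForLoop {
  // Number of leading entry-block arguments that no operand feeds
  // (assumed to be 1: the induction variable).
  std::size_t numInductionVars = 1;
  // Values forwarded on loop entry (the iter_args initial values).
  std::vector<std::string> initOperands;
  // Values forwarded by the terminator on each back edge.
  std::vector<std::string> yieldOperands;
};

// Returns the values that may flow into entry-block argument `argNumber`,
// or std::nullopt when the argument is an induction variable.
std::optional<std::vector<std::string>>
predecessorValuesForBlockArg(const ToyForLoop &loop, std::size_t argNumber) {
  if (argNumber < loop.numInductionVars)
    return std::nullopt;
  // Shift past the hidden argument(s) to find where the iter_args start.
  std::size_t iterArgIndex = argNumber - loop.numInductionVars;
  assert(loop.initOperands.size() == loop.yieldOperands.size() &&
         "entry and yield operand counts must match");
  assert(iterArgIndex < loop.initOperands.size() &&
         "block argument out of range");
  return std::vector<std::string>{loop.initOperands[iterArgIndex],
                                  loop.yieldOperands[iterArgIndex]};
}
```

In this model the 3rd block argument (index 2) maps to iter_arg index 1, so its possible sources are the second initial value and the second yielded value.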
@joker-eph joker-eph force-pushed the regionbranchopiface branch from 6a3c143 to 116a4c4 Compare October 13, 2025 21:58
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Oct 13, 2025
@joker-eph
Collaborator Author

> In general LGTM. There are still failures in flang. Also, a PSA in discourse would be nice.

https://discourse.llvm.org/t/psa-regionbranchopinterface-revamping/88583

@fabianmcg fabianmcg requested review from Copilot and fabianmcg and removed request for Copilot October 13, 2025 22:07
@MaheshRavishankar
Contributor

Would appreciate waiting for some discussion on Discourse to follow through.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

This has been approved by reviewers so far, but it probably needs more discussion on the associated Discourse post, since the scope of impact is unclear to me.

@MaheshRavishankar MaheshRavishankar dismissed their stale review October 14, 2025 21:47

Removing my blocker. I see Mehdi's response on the Discourse post and am still digesting it, but I think as long as people who know this weigh in, I don't have much more to contribute here.


Labels

flang:fir-hlfir, flang (Flang issues not falling into any other category), mlir:affine, mlir:async, mlir:bufferization (Bufferization infrastructure), mlir:core (MLIR Core Infrastructure), mlir:emitc, mlir:gpu, mlir:memref, mlir:scf, mlir:shape, mlir:sparse (Sparse compiler in MLIR), mlir

5 participants