[MLIR] Revamp RegionBranchOpInterface #161575
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir-gpu

Author: Mehdi Amini (joker-eph)

Changes

This is still somewhat of a WIP: this interface has some issues that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and better aligned with their definitions:

Some new methods with reasonable default implementations are added to help resolve the flow of values across the RegionBranchOpInterface. In its current state, walking the def-use chain backward with this interface is still not trivial. For example, given the 3rd block argument in the entry block of a for-loop, finding the matching operands requires knowing about the hidden loop-iterator block argument and where the iter_args start. Unfortunately, the API is designed around forward-tracking of the chain.

Patch is 197.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/161575.diff

38 Files Affected:
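The backward-walking difficulty described above can be made concrete with a small sketch. It assumes an `scf.for`-like layout in which the entry block arguments are `[induction var, iter_args...]` and the loop operands are `[lb, ub, step, inits...]`; `ToyForLoop` and `initOperandForBlockArg` are hypothetical names for illustration only, not part of the MLIR API. The point is that the caller has to know both offsets to map a block argument back to the operand that feeds it.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>

// Hypothetical toy model (not the MLIR API) of an scf.for-like loop.
// Entry block arguments: [induction var, iter_args...]
// Loop operands:         [lb, ub, step, inits...]
struct ToyForLoop {
  std::size_t numInductionVars = 1;   // the hidden leading block argument
  std::size_t numControlOperands = 3; // lb, ub, step
  std::size_t numIterArgs = 0;
};

// Walking backward: given the index of an entry-block argument, return the
// index of the loop operand that initializes it on the first iteration, or
// std::nullopt for the induction variable (no operand is forwarded to it).
std::optional<std::size_t> initOperandForBlockArg(const ToyForLoop &loop,
                                                  std::size_t argIndex) {
  assert(argIndex < loop.numInductionVars + loop.numIterArgs &&
         "block argument index out of range");
  if (argIndex < loop.numInductionVars)
    return std::nullopt;
  // Skip the induction variable to find the iter_arg position, then skip
  // the control operands to find the matching init operand.
  std::size_t iterArgIndex = argIndex - loop.numInductionVars;
  return loop.numControlOperands + iterArgIndex;
}
```

With two iter_args, the 3rd block argument (index 2) is iter_arg #1, which is initialized by loop operand #4. Neither offset is exposed by a generic, forward-oriented interface, which is why the backward direction needs extra help.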
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 8bcfe51ad7cd1..3c87c453a4cf0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
@@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
+ RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
@@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis
}
void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 3f8874d02afad..72f717e163fb6 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// and propagating therefrom.
virtual void
visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
- RegionBranchPoint successor,
+ RegionSuccessor successor,
ArrayRef<AbstractSparseLattice *> lattices);
};
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index fadd3fc10bfc4..48690151caf01 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -644,6 +644,13 @@ def ForallOp : SCF_Op<"forall", [
/// Returns true if the mapping specified for this forall op is linear.
bool usesLinearMapping();
+
+ /// RegionBranchOpInterface
+
+ OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
+ return getInits();
+ }
+
}];
}
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 7ff718ad7f241..a0a99f4953822 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -29,6 +29,7 @@ class MLIRContext;
class Operation;
class OperationName;
class OpPrintingFlags;
+class OpWithFlags;
class Type;
class Value;
@@ -199,6 +200,7 @@ class Diagnostic {
/// Stream in an Operation.
Diagnostic &operator<<(Operation &op);
+ Diagnostic &operator<<(OpWithFlags op);
Diagnostic &operator<<(Operation *op) { return *this << *op; }
/// Append an operation with the given printing flags.
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 5569392cf0b41..b2019574a820d 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -1114,6 +1114,7 @@ class OpWithFlags {
: op(op), theFlags(flags) {}
OpPrintingFlags &flags() { return theFlags; }
const OpPrintingFlags &flags() const { return theFlags; }
+ Operation *getOperation() const { return op; }
private:
Operation *op;
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 1fcb316750230..53d461df98710 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -379,6 +379,8 @@ class RegionRange
friend RangeBaseT;
};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region &region);
+
} // namespace mlir
#endif // MLIR_IR_REGION_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index d63800c12d132..c8304829a4df8 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -15,10 +15,15 @@
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
class BranchOpInterface;
class RegionBranchOpInterface;
+class RegionBranchTerminatorOpInterface;
/// This class models how operands are forwarded to block arguments in control
/// flow. It consists of a number, denoting how many of the successors block
@@ -186,27 +191,46 @@ class RegionSuccessor {
public:
/// Initialize a successor that branches to another region of the parent
/// operation.
+ /// TODO: the default value for the regionInputs is somehow broken.
+ /// A region successor should have its input correctly set.
RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
- : region(region), inputs(regionInputs) {}
+ : region(region), inputs(regionInputs) {
+ assert(region && "Region must not be null");
+ }
/// Initialize a successor that branches back to/out of the parent operation.
- RegionSuccessor(Operation::result_range results)
- : inputs(ValueRange(results)) {}
- /// Constructor with no arguments.
- RegionSuccessor() : inputs(ValueRange()) {}
+ /// The target must be one of the recursive parent operations.
+ RegionSuccessor(Operation *successorOop, Operation::result_range results)
+ : successorOp(successorOop), inputs(ValueRange(results)) {
+ assert(successorOop && "Successor op must not be null");
+ }
/// Return the given region successor. Returns nullptr if the successor is the
/// parent operation.
Region *getSuccessor() const { return region; }
/// Return true if the successor is the parent operation.
- bool isParent() const { return region == nullptr; }
+ bool isParent() const {
+ assert((region != nullptr || successorOp != nullptr) &&
+ "Region and successor op must not be null");
+ return region == nullptr;
+ }
/// Return the inputs to the successor that are remapped by the exit values of
/// the current region.
ValueRange getSuccessorInputs() const { return inputs; }
+ bool operator==(RegionSuccessor rhs) const {
+ return region == rhs.region && successorOp == rhs.successorOp &&
+ inputs == rhs.inputs;
+ }
+
+ friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
+ return !(lhs == rhs);
+ }
+
private:
Region *region{nullptr};
+ Operation *successorOp{nullptr};
ValueRange inputs;
};
@@ -214,7 +238,8 @@ class RegionSuccessor {
/// `RegionBranchOpInterface`.
/// One can branch from one of two kinds of places:
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
-/// * A region within the parent operation.
+/// * A RegionBranchTerminatorOpInterface inside a region within the parent
+// operation.
class RegionBranchPoint {
public:
/// Returns an instance of `RegionBranchPoint` representing the parent
@@ -223,55 +248,57 @@ class RegionBranchPoint {
/// Creates a `RegionBranchPoint` that branches from the given region.
/// The pointer must not be null.
- RegionBranchPoint(Region *region) : maybeRegion(region) {
- assert(region && "Region must not be null");
- }
-
- RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
+ inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);
/// Explicitly stops users from constructing with `nullptr`.
RegionBranchPoint(std::nullptr_t) = delete;
- /// Constructs a `RegionBranchPoint` from the the target of a
- /// `RegionSuccessor` instance.
- RegionBranchPoint(RegionSuccessor successor) {
- if (successor.isParent())
- maybeRegion = nullptr;
- else
- maybeRegion = successor.getSuccessor();
- }
-
- /// Assigns a region being branched from.
- RegionBranchPoint &operator=(Region &region) {
- maybeRegion = &region;
- return *this;
- }
-
/// Returns true if branching from the parent op.
- bool isParent() const { return maybeRegion == nullptr; }
+ bool isParent() const { return predecessor == nullptr; }
/// Returns the region if branching from a region.
/// A null pointer otherwise.
- Region *getRegionOrNull() const { return maybeRegion; }
+ Operation *getPredecessorOrNull() const { return predecessor; }
/// Returns true if the two branch points are equal.
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
- return lhs.maybeRegion == rhs.maybeRegion;
+ return lhs.predecessor == rhs.predecessor;
}
private:
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
- constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
+ constexpr RegionBranchPoint() = default;
/// Internal encoding. Uses nullptr for representing branching from the parent
- /// op and the region being branched from otherwise.
- Region *maybeRegion;
+ /// op and the region terminator being branched from otherwise.
+ Operation *predecessor = nullptr;
};
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return !(lhs == rhs);
}
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ RegionBranchPoint point) {
+ if (point.isParent())
+ return os << "<from parent>";
+ return os
+ << "<region #"
+ << point.getPredecessorOrNull()->getParentRegion()->getRegionNumber()
+ << ", terminator "
+ << OpWithFlags(point.getPredecessorOrNull(),
+ OpPrintingFlags().skipRegions())
+ << ">";
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ RegionSuccessor successor) {
+ if (successor.isParent())
+ return os << "<to parent>";
+ return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
+ << " with " << successor.getSuccessorInputs().size() << " inputs>";
+}
+
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
@@ -348,4 +375,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
/// Include the generated interface declarations.
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
+namespace mlir {
+inline RegionBranchPoint::RegionBranchPoint(
+ RegionBranchTerminatorOpInterface predecessor)
+ : predecessor(predecessor.getOperation()) {}
+} // namespace mlir
+
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
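The revised encoding in `ControlFlowInterfaces.h` can be illustrated with a simplified, self-contained model. This is a sketch under stated assumptions, not the real classes: `ToyRegion`, `ToyTerminator`, `ToyOp`, `ToyBranchPoint`, `ToySuccessor`, and `toString` are hypothetical stand-ins. It shows two consequences of the change: a branch point now identifies a specific terminator (so two terminators in the same region are distinct branch points), and a successor can no longer be default-constructed without naming its target. The string formats follow the `operator<<` overloads added in the patch, except that the toy prints only the terminator's name rather than the full op.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical stand-ins for mlir::Region, a terminator op, and an Operation.
struct ToyRegion { unsigned number; };
struct ToyTerminator { const ToyRegion *parentRegion; std::string name; };
struct ToyOp { std::string name; };

// Like the revised RegionBranchPoint: nullptr encodes "from parent";
// otherwise the branch originates at a specific terminator, not a region.
class ToyBranchPoint {
public:
  static ToyBranchPoint parent() { return ToyBranchPoint(); }
  explicit ToyBranchPoint(const ToyTerminator &terminator)
      : predecessor(&terminator) {}
  bool isParent() const { return predecessor == nullptr; }
  const ToyTerminator *getPredecessorOrNull() const { return predecessor; }
  friend bool operator==(ToyBranchPoint lhs, ToyBranchPoint rhs) {
    return lhs.predecessor == rhs.predecessor;
  }

private:
  ToyBranchPoint() = default; // only reachable through parent()
  const ToyTerminator *predecessor = nullptr;
};

// Like the revised RegionSuccessor: no default constructor; every successor
// names either a region or the operation that control returns to.
class ToySuccessor {
public:
  ToySuccessor(const ToyRegion *region, unsigned numInputs)
      : region(region), numInputs(numInputs) {
    assert(region && "Region must not be null");
  }
  ToySuccessor(const ToyOp *successorOp, unsigned numResults)
      : successorOp(successorOp), numInputs(numResults) {
    assert(successorOp && "Successor op must not be null");
  }
  bool isParent() const { return region == nullptr; }
  const ToyRegion *getSuccessor() const { return region; }
  const ToyOp *getSuccessorOp() const { return successorOp; }
  unsigned getNumInputs() const { return numInputs; }

private:
  const ToyRegion *region = nullptr;
  const ToyOp *successorOp = nullptr;
  unsigned numInputs = 0;
};

// Formats modeled on the operator<< overloads in the patch.
std::string toString(ToyBranchPoint point) {
  if (point.isParent())
    return "<from parent>";
  const ToyTerminator *term = point.getPredecessorOrNull();
  std::ostringstream os;
  os << "<region #" << term->parentRegion->number << ", terminator "
     << term->name << ">";
  return os.str();
}

std::string toString(const ToySuccessor &successor) {
  if (successor.isParent())
    return "<to parent>";
  std::ostringstream os;
  os << "<to region #" << successor.getSuccessor()->number << " with "
     << successor.getNumInputs() << " inputs>";
  return os.str();
}
```

Under the old region-based encoding, the two `scf.yield` terminators in the test below would have compared equal because they share a region; keying on the terminator keeps them apart.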
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index b8d08cc553caa..c9fe3354bc6a1 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -117,7 +117,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let description = [{
- This interface provides information for region operations that exhibit
+ This interface provides information for region-holding operations that exhibit
branching behavior between held regions. I.e., this interface allows for
expressing control flow information for region holding operations.
@@ -126,12 +126,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
be side-effect free.
A "region branch point" indicates a point from which a branch originates. It
- can indicate either a region of this op or `RegionBranchPoint::parent()`. In
- the latter case, the branch originates from outside of the op, i.e., when
- first executing this op.
+ can indicate either a terminator in any of the recursively nested region of
+ this op or `RegionBranchPoint::parent()`. In the latter case, the branch
+ originates from outside of the op, i.e., when first executing this op.
A "region successor" indicates the target of a branch. It can indicate
- either a region of this op or this op. In the former case, the region
+ either a region of this op or this op itself. In the former case, the region
successor is a region pointer and a range of block arguments to which the
"successor operands" are forwarded to. In the latter case, the control flow
leaves this op and the region successor is a range of results of this op to
@@ -151,10 +151,25 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
}
```
- `scf.for` has one region. The region has two region successors: the region
- itself and the `scf.for` op. %b is an entry successor operand. %c is a
- successor operand. %a is a successor block argument. %r is a successor
- result.
+ `scf.for` has one region. The `scf.yield` has two region successors: the
+ region body itself and the `scf.for` op. `%b` is an entry successor
+ operand. `%c` is a successor operand. `%a` is a successor block argument.
+ `%r` is a successor result.
+
+
+ ```
+ %r = scf.loop iter_args(%a = %b)
+ -> tensor<5xf32> {
+ ...
+ scf.yield %c : tensor<5xf32>
+ }
+ ```
+
+ `scf.for` has one region. The `scf.yield` has two region successors: the
+ region body itself and the `scf.for` op. `%b` is an entry successor
+ operand. `%c` is a successor operand. `%a` is a successor block argument.
+ `%r` is a successor result.
+
}];
let cppNamespace = "::mlir";
@@ -162,16 +177,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
InterfaceMethod<[{
Returns the operands of this operation that are forwarded to the region
successor's block arguments or this operation's results when branching
- to `point`. `point` is guaranteed to be among the successors that are
+ to `successor`. `successor` is guaranteed to be among the successors that are
returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
Example: In the above example, this method returns the operand %b of the
- `scf.for` op, regardless of the value of `point`. I.e., this op always
+ `scf.for` op, regardless of the value of `successor`. I.e., this op always
forwards the same operands, regardless of whether the loop has 0 or more
iterations.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point), [{}],
+ (ins "::mlir::RegionSuccessor":$successor), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -224,6 +239,81 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::mlir::RegionBranchPoint":$point,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
+ InterfaceMethod<[{
+ Returns the potential region successors when branching from any
+ terminator in `region`.
+ These are the regions that may be selected during the flow of control.
+ }],
+ "void", "getSuccessorRegions",
+ (ins "::mlir::Region&":$region,
+ "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+ [{}],
+ /*defaultImplementation=*/[{
+ for (::mlir::Block &block : region) {
+ if (block.empty())
+ continue;
+ if (auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+ $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+ regions);
+ }
+ }]>,
+ InterfaceMethod<[{
+ Returns the potential branching point (predecessors) for a given successor.
+ }],
+ "void", "getPredecessors",
+ (ins "::mlir::RegionSuccessor":$successor,
+ "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
+ [{}],
+ /*defaultImplementation=*/[{
+ ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+ $_op.getSuccessorRegions(RegionBranchPoint::parent(),
+ successors);
+ if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+ return succ.getSuccessor() == successor.getSuccessor() ||
+ (succ.isParent() && successor.isParent());
+ }))
+ predecessors.push_back(RegionBranchPoint::parent());
+ for (Region &region : $_op->getRegions()) {
+ for (::mlir::Block &block : region) {
+ if (block.empty())
+ continue;
+ if (auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
+ ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+ $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+ successors);
+ if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+ return succ.getSuccessor() == successor.getSuccessor() ||
+ (succ.isParent() && successor.isParent());
+ }))
+ predecessors.push_back(terminator);
+ }
+ }
+ }
+ }]>,
+ InterfaceMethod<[{
+ Returns the potential values across all (predecessors) for a given successor
+ input, modeled by its index.
+ }],
+ "void", "getPredecessorValues",
+ (ins "::mlir::RegionSuccessor":$successor,
+ "int":$index,
+ "::llvm::SmallVectorImpl<::mlir::Value> &":$predecessorValues),
+ [{}],
+ /*defaultImplementation=*/[{
+ ::llvm::SmallVector<::mlir::RegionBranchPoint> predecessors;
+ $_op.getPredecessors(successor, predecessors);
+ for (auto predecessor : predecessors) {
+ if (predecessor.isParent()) {
+ predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]);
+ continue;
+ }
+ auto terminator = cast<RegionBranchTerminatorOpInterface>(predecessor.getPredecessorOrNull());
+ predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]);
+
+ }
+ }]>,
InterfaceMethod<[{
Populates `invocationBounds` with the minimum and maximum number of
times this operation will invoke the attached regions (assuming the
@@ -298,7 +388,7 @@ def RegionBranchTerminatorOpInterface :
passing them to the region successor indicated by `point`.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point)
+ (ins "::mlir::RegionSuccessor":$point)
>,
InterfaceMethod<[{
Returns the potential region successors that are branched to after this
@@ -317,7 +407,7 @@ def RegionBranchTerminatorOpInterface :
/*defaultImplementation=*/[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
- .getSuccessorRegions(op->getParentRegion(), regions);
+ .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
}]
>,
];
@@ -337,8 +427,8 @@ def RegionBranchTerminatorOpInterface :
// them to the region successor given by `index`. If `index` is None, this
// function returns the operands that are passed as a result to the parent
// operation.
- ::mlir::OperandRange getSucces...
[truncated]
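The default implementations of `getPredecessors` and `getPredecessorValues` added to `ControlFlowInterfaces.td` work by asking the parent and every block terminator for its successors, keeping those that reach the requested successor, and then reading the forwarded value at `index` from each one. The sketch below models that logic outside MLIR. Its assumptions: values are plain strings, a successor is a region index or `kParent`, terminators are listed per region (the real code inspects `block.back()` of each block), and each terminator forwards the same operands to every successor, which holds for `scf.yield` but not for every terminator. All names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy model: a successor is a region index, or kParent.
constexpr int kParent = -1;

struct ToyTerminator {
  int id;
  std::vector<std::string> operands; // forwarded to every successor (assumed)
};

struct ToyRegionOp {
  std::vector<std::string> entryOperands;          // e.g. scf.for inits
  std::vector<std::vector<ToyTerminator>> regions; // terminators per region
  // Successors of a branch point; nullptr means "from parent".
  std::function<std::vector<int>(const ToyTerminator *)> successorsOf;
};

// nullptr encodes the parent branch point.
using ToyBranchPoint = const ToyTerminator *;

static bool reaches(const std::vector<int> &succs, int target) {
  for (int s : succs)
    if (s == target)
      return true;
  return false;
}

// Mirrors the default getPredecessors: check the parent, then each terminator.
std::vector<ToyBranchPoint> getPredecessors(const ToyRegionOp &op,
                                            int successor) {
  std::vector<ToyBranchPoint> preds;
  if (reaches(op.successorsOf(nullptr), successor))
    preds.push_back(nullptr);
  for (const auto &region : op.regions)
    for (const ToyTerminator &term : region)
      if (reaches(op.successorsOf(&term), successor))
        preds.push_back(&term);
  return preds;
}

// Mirrors the default getPredecessorValues: one value per predecessor.
std::vector<std::string> getPredecessorValues(const ToyRegionOp &op,
                                              int successor,
                                              std::size_t index) {
  std::vector<std::string> values;
  for (ToyBranchPoint pred : getPredecessors(op, successor))
    values.push_back(pred == nullptr ? op.entryOperands.at(index)
                                     : pred->operands.at(index));
  return values;
}

// The scf.for example from the interface description: %b is the entry
// successor operand and %c the operand of the single scf.yield; both the
// parent and the yield can branch to the body (0) or out of the loop.
ToyRegionOp makeToyScfFor() {
  return ToyRegionOp{
      {"%b"},
      {{ToyTerminator{0, {"%c"}}}},
      [](const ToyTerminator *) { return std::vector<int>{0, kParent}; }};
}
```

For the `scf.for` model, both the body argument `%a` and the result `%r` receive `{%b, %c}`. As the PR description notes, this still presumes the caller already knows the successor-input `index`, which is the forward-oriented part of the API.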
+
+
+ ```
+ %r = scf.loop iter_args(%a = %b)
+ -> tensor<5xf32> {
+ ...
+ scf.yield %c : tensor<5xf32>
+ }
+ ```
+
+ `scf.for` has one region. The `scf.yield` has two region successors: the
+ region body itself and the `scf.for` op. `%b` is an entry successor
+ operand. `%c` is a successor operand. `%a` is a successor block argument.
+ `%r` is a successor result.
+
}];
let cppNamespace = "::mlir";
@@ -162,16 +177,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
InterfaceMethod<[{
Returns the operands of this operation that are forwarded to the region
successor's block arguments or this operation's results when branching
- to `point`. `point` is guaranteed to be among the successors that are
+ to `successor`. `successor` is guaranteed to be among the successors that are
returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
Example: In the above example, this method returns the operand %b of the
- `scf.for` op, regardless of the value of `point`. I.e., this op always
+ `scf.for` op, regardless of the value of `successor`. I.e., this op always
forwards the same operands, regardless of whether the loop has 0 or more
iterations.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point), [{}],
+ (ins "::mlir::RegionSuccessor":$successor), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -224,6 +239,81 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
(ins "::mlir::RegionBranchPoint":$point,
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
+ InterfaceMethod<[{
+ Returns the potential region successors when branching from any
+ terminator in `region`.
+ These are the regions that may be selected during the flow of control.
+ }],
+ "void", "getSuccessorRegions",
+ (ins "::mlir::Region&":$region,
+ "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+ [{}],
+ /*defaultImplementation=*/[{
+ for (::mlir::Block &block : region) {
+ if (block.empty())
+ continue;
+ if (auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+ $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+ regions);
+ }
+ }]>,
+ InterfaceMethod<[{
+ Returns the potential branching point (predecessors) for a given successor.
+ }],
+ "void", "getPredecessors",
+ (ins "::mlir::RegionSuccessor":$successor,
+ "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
+ [{}],
+ /*defaultImplementation=*/[{
+ ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+ $_op.getSuccessorRegions(RegionBranchPoint::parent(),
+ successors);
+ if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+ return succ.getSuccessor() == successor.getSuccessor() ||
+ (succ.isParent() && successor.isParent());
+ }))
+ predecessors.push_back(RegionBranchPoint::parent());
+ for (Region &region : $_op->getRegions()) {
+ for (::mlir::Block &block : region) {
+ if (block.empty())
+ continue;
+ if (auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
+ ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+ $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+ successors);
+ if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+ return succ.getSuccessor() == successor.getSuccessor() ||
+ (succ.isParent() && successor.isParent());
+ }))
+ predecessors.push_back(terminator);
+ }
+ }
+ }
+ }]>,
+ InterfaceMethod<[{
+ Returns the potential values across all (predecessors) for a given successor
+ input, modeled by its index.
+ }],
+ "void", "getPredecessorValues",
+ (ins "::mlir::RegionSuccessor":$successor,
+ "int":$index,
+ "::llvm::SmallVectorImpl<::mlir::Value> &":$predecessorValues),
+ [{}],
+ /*defaultImplementation=*/[{
+ ::llvm::SmallVector<::mlir::RegionBranchPoint> predecessors;
+ $_op.getPredecessors(successor, predecessors);
+ for (auto predecessor : predecessors) {
+ if (predecessor.isParent()) {
+ predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]);
+ continue;
+ }
+ auto terminator = cast<RegionBranchTerminatorOpInterface>(predecessor.getPredecessorOrNull());
+ predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]);
+
+ }
+ }]>,
InterfaceMethod<[{
Populates `invocationBounds` with the minimum and maximum number of
times this operation will invoke the attached regions (assuming the
@@ -298,7 +388,7 @@ def RegionBranchTerminatorOpInterface :
passing them to the region successor indicated by `point`.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point)
+ (ins "::mlir::RegionSuccessor":$point)
>,
InterfaceMethod<[{
Returns the potential region successors that are branched to after this
@@ -317,7 +407,7 @@ def RegionBranchTerminatorOpInterface :
/*defaultImplementation=*/[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
- .getSuccessorRegions(op->getParentRegion(), regions);
+ .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
}]
>,
];
@@ -337,8 +427,8 @@ def RegionBranchTerminatorOpInterface :
// them to the region successor given by `index`. If `index` is None, this
// function returns the operands that are passed as a result to the parent
// operation.
- ::mlir::OperandRange getSucces...
[truncated]
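The default `getPredecessors` and `getPredecessorValues` implementations in the `.td` hunk above invert `getSuccessorRegions`. They ask the parent and every terminator where they can branch. They keep the points whose successor list contains the requested successor, then read the forwarded value at the given input index from each one.

The following is a minimal standalone sketch of that algorithm on a hypothetical single-region, `scf.for`-like toy op. It does not use MLIR. Values are plain strings named after the `%b`/`%c` example in the interface description, and the loop's control flow is hard-coded. It also ignores the induction-variable offset in the body's block arguments.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy model of an scf.for-like op with a single region (#0).
// SSA values are represented by their names, e.g. "%b".
struct ToyTerminator {
  std::vector<std::string> operands; // successor operands, e.g. {"%c"}
};
using ToyBranchPoint = const ToyTerminator *; // nullptr == from the parent op
using ToySuccessor = int;                     // region index, or kParent
constexpr ToySuccessor kParent = -1;

struct ToyLoopOp {
  std::vector<std::string> inits;         // entry successor operands
  std::vector<ToyTerminator> terminators; // all terminators in the body

  // Loop-specific: entering the op and every yield may go to the body
  // (another iteration) or to the parent (loop exit).
  std::vector<ToySuccessor> getSuccessorRegions(ToyBranchPoint) const {
    return {0, kParent};
  }
  // Same operands regardless of the successor, as for scf.for.
  const std::vector<std::string> &
  getEntrySuccessorOperands(ToySuccessor) const {
    return inits;
  }
};

// Like the default getPredecessors: collect every branch point (the parent
// and each terminator) whose successor list contains `successor`.
std::vector<ToyBranchPoint> getPredecessors(const ToyLoopOp &op,
                                            ToySuccessor successor) {
  auto reaches = [&](ToyBranchPoint point) {
    for (ToySuccessor s : op.getSuccessorRegions(point))
      if (s == successor)
        return true;
    return false;
  };
  std::vector<ToyBranchPoint> preds;
  if (reaches(nullptr))
    preds.push_back(nullptr);
  for (const ToyTerminator &term : op.terminators)
    if (reaches(&term))
      preds.push_back(&term);
  return preds;
}

// Like the default getPredecessorValues: for successor input `index`, gather
// the value each predecessor forwards into it.
std::vector<std::string> getPredecessorValues(const ToyLoopOp &op,
                                              ToySuccessor successor,
                                              std::size_t index) {
  std::vector<std::string> values;
  for (ToyBranchPoint pred : getPredecessors(op, successor)) {
    if (pred == nullptr)
      values.push_back(op.getEntrySuccessorOperands(successor).at(index));
    else
      values.push_back(pred->operands.at(index));
  }
  return values;
}
```

For a loop built as `ToyLoopOp{{"%b"}, {ToyTerminator{{"%c"}}}}`, both the body and the loop results have two predecessors: the parent, which forwards `%b`, and the yield, which forwards `%c`.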
- can indicate either a region of this op or `RegionBranchPoint::parent()`. In
- the latter case, the branch originates from outside of the op, i.e., when
- first executing this op.
+ can indicate either a terminator in any of the recursively nested region of
I found it interesting that, since you explicitly talk about terminators, there could be multiple terminators in case the region has multiple blocks. I'm wondering if the interface actually supports that. There may be some code that tries to find the region terminator op, in order to understand the data flow.
> There may be some code that tries to find the region terminator op, in order to understand the data flow.

Yes, there is. In theory it should collect all the terminators and work from them.
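A multi-block region can end in several terminators, so code should visit every block instead of looking only at the last block's terminator. The default `getSuccessorRegions(Region&)` in the diff does this: it walks each block, skips empty ones, and checks whether the last op is a `RegionBranchTerminatorOpInterface`.

The following is a minimal standalone sketch of that traversal over a hypothetical toy IR. Ops are plain strings, and the set of terminator names is an assumption made for illustration.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy IR: an op is its name, a block is a list of ops, and a
// region is a list of blocks.
using ToyOp = std::string;
using ToyBlock = std::vector<ToyOp>;
using ToyRegion = std::vector<ToyBlock>;

// Assumption: these names stand in for ops implementing
// RegionBranchTerminatorOpInterface; "cf.br" (a block-to-block branch) does not.
bool isRegionBranchTerminator(const ToyOp &op) {
  return op == "scf.yield" || op == "scf.condition";
}

// Visit every block rather than only the last one: skip empty blocks and
// blocks whose last op is not a region-branch terminator.
std::vector<const ToyOp *> collectRegionTerminators(const ToyRegion &region) {
  std::vector<const ToyOp *> terminators;
  for (const ToyBlock &block : region) {
    if (block.empty())
      continue;
    if (isRegionBranchTerminator(block.back()))
      terminators.push_back(&block.back());
  }
  return terminators;
}
```

Given a region whose third and fourth blocks both end in `scf.yield`, this returns both. A lookup of the form "the terminator of the last block" would only ever find the second one.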
@@ -186,35 +191,55 @@ class RegionSuccessor {
public:
/// Initialize a successor that branches to another region of the parent
/// operation.
+ /// TODO: the default value for the regionInputs is somehow broken.
+ /// A region successor should have its input correctly set.
Is it possible to just drop the default value? `result_range` is not optional either.
I made it a TODO because it is challenging right now.
There is code like the following:
RegionSuccessor region(blockArg.getParentRegion());
SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
regionBranchOp, region, blockArg.getArgNumber());
Here we don't know which block args to use to initialize the `RegionSuccessor`.
I think this code is actually buggy, but fixing it requires more changes, and this patch was already growing quite large, so I punted to a follow-up with a TODO.
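The snippet above starts from `blockArg.getArgNumber()`. However, a block argument's position is not necessarily its index among the successor inputs. In an `scf.for` body, block argument 0 is the induction variable, which no operand forwards into, so the "3rd block argument" from the PR description is really the second successor input. Translating a position therefore requires knowing where the successor's inputs start, which a `RegionSuccessor` built with the default empty inputs cannot provide.

The following is a minimal standalone sketch of that translation. It assumes, as in `scf.for`, that the inputs are a contiguous run of block arguments. `ToySuccessorInputs` and `successorInputIndex` are hypothetical names, not MLIR API.

```cpp
#include <cassert>
#include <optional>

// Hypothetical summary of a region successor's inputs: a contiguous run of
// the entry block's arguments, described by where it starts and its length.
struct ToySuccessorInputs {
  unsigned firstArg; // position of the first forwarded block argument
  unsigned count;    // 0 models a successor built with the default `{}`
};

// Translate a block-argument position into an index into the successor
// inputs (and thus into each predecessor's forwarded operands). Returns
// std::nullopt when the argument is not a successor input, e.g. the scf.for
// induction variable, or when the inputs were never set.
std::optional<unsigned> successorInputIndex(const ToySuccessorInputs &inputs,
                                            unsigned argNumber) {
  if (argNumber < inputs.firstArg ||
      argNumber - inputs.firstArg >= inputs.count)
    return std::nullopt;
  return argNumber - inputs.firstArg;
}
```

With inputs `{1, 2}`, modeling `^bb0(%iv, %a0, %a1)`, position 2 maps to input index 1. With inputs `{0, 0}` every lookup fails, which matches the situation the TODO describes.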
}];
let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Returns the operands of this operation that are forwarded to the region
successor's block arguments or this operation's results when branching
- to `point`. `point` is guaranteed to be among the successors that are
I'm a bit confused about the old API here. `RegionBranchPoint` is the point from which the branch originates. Yet the old documentation here says that `point` is a successor.
Yes, that's more or less what led me to work on this patch: I have looked at the history of how we got there, but the current state is quite messy.
I only skimmed through it; I'll give it another look this weekend.
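Part of the confusion in the old API comes from the `RegionBranchPoint(RegionSuccessor)` converting constructor, which this patch removes. Because of that constructor, a successor could be passed silently wherever a branch point was expected. Keeping the two concepts as unrelated types turns such mix-ups into compile errors.

The following is a minimal standalone sketch of that effect. The `Old*` and `New*` types are hypothetical and only model the relevant constructor; they are not the MLIR classes.

```cpp
#include <cassert>
#include <type_traits>

struct ToyRegion {};
struct ToyTerminator {};

// Hypothetical "before": like the removed RegionBranchPoint(RegionSuccessor)
// constructor, a successor converts implicitly into a branch point.
struct OldSuccessor {
  const ToyRegion *region;
};
struct OldBranchPoint {
  OldBranchPoint(OldSuccessor s) : region(s.region) {} // implicit on purpose
  const ToyRegion *region;
};

// Hypothetical "after": a branch point names a terminator, a successor names
// a region, and neither converts into the other.
struct NewSuccessor {
  const ToyRegion *region;
};
struct NewBranchPoint {
  const ToyTerminator *terminator;
};

static_assert(std::is_convertible<OldSuccessor, OldBranchPoint>::value,
              "before: a successor silently becomes a branch point");
static_assert(!std::is_convertible<NewSuccessor, NewBranchPoint>::value,
              "after: passing a successor as a branch point does not compile");

// An API that only makes sense for a branch target now says so in its type.
const ToyRegion *targetRegion(NewSuccessor successor) {
  return successor.region;
}
```

This is why signatures such as `getEntrySuccessorOperands` and `getMutableSuccessorOperands` now take a `RegionSuccessor`. Once the conversion is gone, a call site that still passes a branch point is flagged by the compiler instead of compiling with the wrong meaning.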
+
+ /// RegionBranchOpInterface
+
+ OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
+ return getInits();
+ }
+
I would prefer removing the interface from `forall` until the op and its terminator are really compatible with the interface. But that can be in a different PR.
I'm still unsure about the fundamental compatibility of trying to map traditional control-flow over operations with parallel execution... But I would tend to keep this as a separate discussion indeed.
b89ed4c to c250b29
c250b29 to 6a3c143
Any more thoughts on this?
In general LGTM. There are still failures in flang. Also, a PSA on Discourse would be nice.
This is still somewhat of a WIP; we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and aligned with their definition:
- A `RegionBranchPoint` is either the parent (`RegionBranchOpInterface`) op or a `RegionBranchTerminatorOpInterface` operation in a nested region.
- A `RegionSuccessor` is either one of the nested regions or the parent `RegionBranchOpInterface`.

Some new methods with reasonable default implementations are added to help resolve the flow of values across the RegionBranchOpInterface.

It is still not trivial in the current state to walk the def-use chain backward with this interface. For example, when you have the 3rd block argument in the entry block of a for-loop, finding the matching operands requires knowing about the hidden loop iterator block argument and where the iter_args start. Unfortunately, the API is designed around forward-tracking of the chain.
6a3c143 to 116a4c4
https://discourse.llvm.org/t/psa-regionbranchopinterface-revamping/88583
Would appreciate waiting for some discussion on Discourse to follow through.
This has been approved by reviewers so far, but it probably needs more discussion on the associated Discourse post, since the scope of impact is unclear to me.
Removing my blocker. I see Mehdi's response on the Discourse post and am still digesting it, but I think as long as people who know this area weigh in, I don't have much more to contribute here.