-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR] Revamp RegionBranchOpInterface #161575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,10 +15,16 @@ | |
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H | ||
|
||
#include "mlir/IR/OpDefinition.h" | ||
#include "mlir/IR/Operation.h" | ||
#include "llvm/ADT/PointerUnion.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/Support/DebugLog.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
|
||
namespace mlir { | ||
class BranchOpInterface; | ||
class RegionBranchOpInterface; | ||
class RegionBranchTerminatorOpInterface; | ||
|
||
/// This class models how operands are forwarded to block arguments in control | ||
/// flow. It consists of a number, denoting how many of the successors block | ||
|
@@ -186,92 +192,108 @@ class RegionSuccessor { | |
public: | ||
/// Initialize a successor that branches to another region of the parent | ||
/// operation. | ||
/// TODO: the default value for the regionInputs is somehow broken. | ||
/// A region successor should have its input correctly set. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to just drop the default value? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made it a TODO because it is challenging right now.
Here we don't know which block args to use to initialize the |
||
RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {}) | ||
: region(region), inputs(regionInputs) {} | ||
: successor(region), inputs(regionInputs) { | ||
assert(region && "Region must not be null"); | ||
} | ||
/// Initialize a successor that branches back to/out of the parent operation. | ||
RegionSuccessor(Operation::result_range results) | ||
: inputs(ValueRange(results)) {} | ||
/// Constructor with no arguments. | ||
RegionSuccessor() : inputs(ValueRange()) {} | ||
/// The target must be one of the recursive parent operations. | ||
RegionSuccessor(Operation *successorOp, Operation::result_range results) | ||
: successor(successorOp), inputs(ValueRange(results)) { | ||
assert(successorOp && "Successor op must not be null"); | ||
} | ||
|
||
/// Return the given region successor. Returns nullptr if the successor is the | ||
/// parent operation. | ||
Region *getSuccessor() const { return region; } | ||
Region *getSuccessor() const { return dyn_cast<Region *>(successor); } | ||
|
||
/// Return true if the successor is the parent operation. | ||
bool isParent() const { return region == nullptr; } | ||
bool isParent() const { return isa<Operation *>(successor); } | ||
|
||
/// Return the inputs to the successor that are remapped by the exit values of | ||
/// the current region. | ||
ValueRange getSuccessorInputs() const { return inputs; } | ||
|
||
bool operator==(RegionSuccessor rhs) const { | ||
return successor == rhs.successor && inputs == rhs.inputs; | ||
} | ||
|
||
friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) { | ||
return !(lhs == rhs); | ||
} | ||
|
||
private: | ||
Region *region{nullptr}; | ||
llvm::PointerUnion<Region *, Operation *> successor{nullptr}; | ||
ValueRange inputs; | ||
}; | ||
|
||
/// This class represents a point being branched from in the methods of the | ||
/// `RegionBranchOpInterface`. | ||
/// One can branch from one of two kinds of places: | ||
/// * The parent operation (aka the `RegionBranchOpInterface` implementation) | ||
/// * A region within the parent operation. | ||
/// * A RegionBranchTerminatorOpInterface inside a region within the parent | ||
// operation. | ||
class RegionBranchPoint { | ||
public: | ||
/// Returns an instance of `RegionBranchPoint` representing the parent | ||
/// operation. | ||
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); } | ||
|
||
/// Creates a `RegionBranchPoint` that branches from the given region. | ||
/// The pointer must not be null. | ||
RegionBranchPoint(Region *region) : maybeRegion(region) { | ||
assert(region && "Region must not be null"); | ||
} | ||
|
||
RegionBranchPoint(Region ®ion) : RegionBranchPoint(®ion) {} | ||
/// Creates a `RegionBranchPoint` that branches from the given terminator. | ||
inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor); | ||
|
||
/// Explicitly stops users from constructing with `nullptr`. | ||
RegionBranchPoint(std::nullptr_t) = delete; | ||
|
||
/// Constructs a `RegionBranchPoint` from the the target of a | ||
/// `RegionSuccessor` instance. | ||
RegionBranchPoint(RegionSuccessor successor) { | ||
if (successor.isParent()) | ||
maybeRegion = nullptr; | ||
else | ||
maybeRegion = successor.getSuccessor(); | ||
} | ||
|
||
/// Assigns a region being branched from. | ||
RegionBranchPoint &operator=(Region ®ion) { | ||
maybeRegion = ®ion; | ||
return *this; | ||
} | ||
|
||
/// Returns true if branching from the parent op. | ||
bool isParent() const { return maybeRegion == nullptr; } | ||
bool isParent() const { return predecessor == nullptr; } | ||
|
||
/// Returns the region if branching from a region. | ||
/// Returns the terminator if branching from a region. | ||
/// A null pointer otherwise. | ||
Region *getRegionOrNull() const { return maybeRegion; } | ||
Operation *getTerminatorPredecessorOrNull() const { return predecessor; } | ||
|
||
/// Returns true if the two branch points are equal. | ||
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) { | ||
return lhs.maybeRegion == rhs.maybeRegion; | ||
return lhs.predecessor == rhs.predecessor; | ||
} | ||
|
||
private: | ||
// Private constructor to encourage the use of `RegionBranchPoint::parent`. | ||
constexpr RegionBranchPoint() : maybeRegion(nullptr) {} | ||
constexpr RegionBranchPoint() = default; | ||
|
||
/// Internal encoding. Uses nullptr for representing branching from the parent | ||
/// op and the region being branched from otherwise. | ||
Region *maybeRegion; | ||
/// op and the region terminator being branched from otherwise. | ||
Operation *predecessor = nullptr; | ||
}; | ||
|
||
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) { | ||
return !(lhs == rhs); | ||
} | ||
|
||
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, | ||
RegionBranchPoint point) { | ||
if (point.isParent()) | ||
return os << "<from parent>"; | ||
return os << "<region #" | ||
<< point.getTerminatorPredecessorOrNull() | ||
->getParentRegion() | ||
->getRegionNumber() | ||
<< ", terminator " | ||
<< OpWithFlags(point.getTerminatorPredecessorOrNull(), | ||
OpPrintingFlags().skipRegions()) | ||
<< ">"; | ||
} | ||
|
||
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, | ||
RegionSuccessor successor) { | ||
if (successor.isParent()) | ||
return os << "<to parent>"; | ||
return os << "<to region #" << successor.getSuccessor()->getRegionNumber() | ||
<< " with " << successor.getSuccessorInputs().size() << " inputs>"; | ||
} | ||
|
||
/// This class represents upper and lower bounds on the number of times a region | ||
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least | ||
/// zero, but the upper bound may not be known. | ||
|
@@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> { | |
/// Include the generated interface declarations. | ||
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc" | ||
|
||
namespace mlir { | ||
inline RegionBranchPoint::RegionBranchPoint( | ||
RegionBranchTerminatorOpInterface predecessor) | ||
: predecessor(predecessor.getOperation()) {} | ||
} // namespace mlir | ||
|
||
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer removing the interface from
forall
, untill the op and its terminator are really compatible with the interface. But that can be in a different PR.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still unsure about the fundamental compatibility of trying to map traditional control-flow over operations with parallel execution... But I would tend to keep this as a separate discussion indeed.