Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4484,7 +4484,7 @@ void fir::IfOp::getSuccessorRegions(
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
regions.push_back(mlir::RegionSuccessor(getResults()));
regions.push_back(mlir::RegionSuccessor(getOperation(), getResults()));
return;
}

Expand All @@ -4494,7 +4494,8 @@ void fir::IfOp::getSuccessorRegions(
// Don't consider the else region if it is empty.
mlir::Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
regions.push_back(mlir::RegionSuccessor());
regions.push_back(
mlir::RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(mlir::RegionSuccessor(elseRegion));
}
Expand All @@ -4513,7 +4514,7 @@ void fir::IfOp::getEntrySuccessorRegions(
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
regions.emplace_back(getResults());
regions.emplace_back(getOperation(), getOperation()->getResults());
}
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
Expand Down Expand Up @@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
Expand Down Expand Up @@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis
}
void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// and propagating therefrom.
virtual void
visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint successor,
RegionSuccessor successor,
ArrayRef<AbstractSparseLattice *> lattices);
};

Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,13 @@ def ForallOp : SCF_Op<"forall", [

/// Returns true if the mapping specified for this forall op is linear.
bool usesLinearMapping();

/// RegionBranchOpInterface

OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
return getInits();
}

Comment on lines +647 to +653
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer removing the interface from forall, untill the op and its terminator are really compatible with the interface. But that can be in a different PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still unsure about the fundamental compatibility of trying to map traditional control-flow over operations with parallel execution... But I would tend to keep this as a separate discussion indeed.

}];
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"

def AlternativesOp : TransformDialectOp<"alternatives",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorRegions",
["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
Expand Down Expand Up @@ -624,7 +624,7 @@ def ForeachOp : TransformDialectOp<"foreach",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getSuccessorRegions", "getEntrySuccessorOperands"]>,
"getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
]> {
let summary = "Executes the body for each element of the payload";
Expand Down Expand Up @@ -1237,7 +1237,7 @@ def SelectOp : TransformDialectOp<"select",

def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorRegions",
["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [

def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorRegions",
["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/Diagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class MLIRContext;
class Operation;
class OperationName;
class OpPrintingFlags;
class OpWithFlags;
class Type;
class Value;

Expand Down Expand Up @@ -199,6 +200,7 @@ class Diagnostic {

/// Stream in an Operation.
Diagnostic &operator<<(Operation &op);
Diagnostic &operator<<(OpWithFlags op);
Diagnostic &operator<<(Operation *op) { return *this << *op; }
/// Append an operation with the given printing flags.
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,7 @@ class OpWithFlags {
: op(op), theFlags(flags) {}
OpPrintingFlags &flags() { return theFlags; }
const OpPrintingFlags &flags() const { return theFlags; }
Operation *getOperation() const { return op; }

private:
Operation *op;
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/Region.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ class RegionRange
friend RangeBaseT;
};

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region &region);

} // namespace mlir

#endif // MLIR_IR_REGION_H
104 changes: 66 additions & 38 deletions mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H

#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
class BranchOpInterface;
class RegionBranchOpInterface;
class RegionBranchTerminatorOpInterface;

/// This class models how operands are forwarded to block arguments in control
/// flow. It consists of a number, denoting how many of the successors block
Expand Down Expand Up @@ -186,92 +192,108 @@ class RegionSuccessor {
public:
/// Initialize a successor that branches to another region of the parent
/// operation.
/// TODO: the default value for the regionInputs is somehow broken.
/// A region successor should have its input correctly set.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to just drop the default value? result_range is not optional either.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it a TODO because it is challenging right now.
There is code like the following:

      RegionSuccessor region(blockArg.getParentRegion());
      SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
          regionBranchOp, region, blockArg.getArgNumber());

Here we don't know which block args to use to initialize the RegionSuccessor.
I think this code is actually buggy, but fixing it requires more changes I think, and this patch felt it was growing quite large already so I punted to a follow-up with a TODO.

RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
: region(region), inputs(regionInputs) {}
: successor(region), inputs(regionInputs) {
assert(region && "Region must not be null");
}
/// Initialize a successor that branches back to/out of the parent operation.
RegionSuccessor(Operation::result_range results)
: inputs(ValueRange(results)) {}
/// Constructor with no arguments.
RegionSuccessor() : inputs(ValueRange()) {}
/// The target must be one of the recursive parent operations.
RegionSuccessor(Operation *successorOp, Operation::result_range results)
: successor(successorOp), inputs(ValueRange(results)) {
assert(successorOp && "Successor op must not be null");
}

/// Return the given region successor. Returns nullptr if the successor is the
/// parent operation.
Region *getSuccessor() const { return region; }
Region *getSuccessor() const { return dyn_cast<Region *>(successor); }

/// Return true if the successor is the parent operation.
bool isParent() const { return region == nullptr; }
bool isParent() const { return isa<Operation *>(successor); }

/// Return the inputs to the successor that are remapped by the exit values of
/// the current region.
ValueRange getSuccessorInputs() const { return inputs; }

bool operator==(RegionSuccessor rhs) const {
return successor == rhs.successor && inputs == rhs.inputs;
}

friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
return !(lhs == rhs);
}

private:
Region *region{nullptr};
llvm::PointerUnion<Region *, Operation *> successor{nullptr};
ValueRange inputs;
};

/// This class represents a point being branched from in the methods of the
/// `RegionBranchOpInterface`.
/// One can branch from one of two kinds of places:
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
/// * A region within the parent operation.
/// * A RegionBranchTerminatorOpInterface inside a region within the parent
// operation.
class RegionBranchPoint {
public:
/// Returns an instance of `RegionBranchPoint` representing the parent
/// operation.
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }

/// Creates a `RegionBranchPoint` that branches from the given region.
/// The pointer must not be null.
RegionBranchPoint(Region *region) : maybeRegion(region) {
assert(region && "Region must not be null");
}

RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
/// Creates a `RegionBranchPoint` that branches from the given terminator.
inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);

/// Explicitly stops users from constructing with `nullptr`.
RegionBranchPoint(std::nullptr_t) = delete;

/// Constructs a `RegionBranchPoint` from the the target of a
/// `RegionSuccessor` instance.
RegionBranchPoint(RegionSuccessor successor) {
if (successor.isParent())
maybeRegion = nullptr;
else
maybeRegion = successor.getSuccessor();
}

/// Assigns a region being branched from.
RegionBranchPoint &operator=(Region &region) {
maybeRegion = &region;
return *this;
}

/// Returns true if branching from the parent op.
bool isParent() const { return maybeRegion == nullptr; }
bool isParent() const { return predecessor == nullptr; }

/// Returns the region if branching from a region.
/// Returns the terminator if branching from a region.
/// A null pointer otherwise.
Region *getRegionOrNull() const { return maybeRegion; }
Operation *getTerminatorPredecessorOrNull() const { return predecessor; }

/// Returns true if the two branch points are equal.
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return lhs.maybeRegion == rhs.maybeRegion;
return lhs.predecessor == rhs.predecessor;
}

private:
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
constexpr RegionBranchPoint() = default;

/// Internal encoding. Uses nullptr for representing branching from the parent
/// op and the region being branched from otherwise.
Region *maybeRegion;
/// op and the region terminator being branched from otherwise.
Operation *predecessor = nullptr;
};

inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return !(lhs == rhs);
}

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
RegionBranchPoint point) {
if (point.isParent())
return os << "<from parent>";
return os << "<region #"
<< point.getTerminatorPredecessorOrNull()
->getParentRegion()
->getRegionNumber()
<< ", terminator "
<< OpWithFlags(point.getTerminatorPredecessorOrNull(),
OpPrintingFlags().skipRegions())
<< ">";
}

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
RegionSuccessor successor) {
if (successor.isParent())
return os << "<to parent>";
return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
<< " with " << successor.getSuccessorInputs().size() << " inputs>";
}

/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
Expand Down Expand Up @@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
/// Include the generated interface declarations.
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"

namespace mlir {
inline RegionBranchPoint::RegionBranchPoint(
RegionBranchTerminatorOpInterface predecessor)
: predecessor(predecessor.getOperation()) {}
} // namespace mlir

#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
Loading