Skip to content

Commit d21cc86

Browse files
joker-ephaokblast
authored andcommitted
[MLIR] Revamp RegionBranchOpInterface (llvm#165429)
This is still somehow a WIP, we have some issues with this interface that are not trivial to solve. This patch tries to make the concepts of RegionBranchPoint and RegionSuccessor more robust and aligned with their definition: - A `RegionBranchPoint` is either the parent (`RegionBranchOpInterface`) op or a `RegionBranchTerminatorOpInterface` operation in a nested region. - A `RegionSuccessor` is either one of the nested region or the parent `RegionBranchOpInterface` Some new methods with reasonnable default implementation are added to help resolving the flow of values across the RegionBranchOpInterface. It is still not trivial in the current state to walk the def-use chain backward with this interface. For example when you have the 3rd block argument in the entry block of a for-loop, finding the matching operands requires to know about the hidden loop iterator block argument and where the iterargs start. The API is designed around forward-tracking of the chain unfortunately. Try to reland llvm#161575 ; I suspect a buildbot incremental build issue.
1 parent 20d9145 commit d21cc86

File tree

38 files changed

+828
-378
lines changed

38 files changed

+828
-378
lines changed

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4484,7 +4484,7 @@ void fir::IfOp::getSuccessorRegions(
44844484
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
44854485
// The `then` and the `else` region branch back to the parent operation.
44864486
if (!point.isParent()) {
4487-
regions.push_back(mlir::RegionSuccessor(getResults()));
4487+
regions.push_back(mlir::RegionSuccessor(getOperation(), getResults()));
44884488
return;
44894489
}
44904490

@@ -4494,7 +4494,8 @@ void fir::IfOp::getSuccessorRegions(
44944494
// Don't consider the else region if it is empty.
44954495
mlir::Region *elseRegion = &this->getElseRegion();
44964496
if (elseRegion->empty())
4497-
regions.push_back(mlir::RegionSuccessor());
4497+
regions.push_back(
4498+
mlir::RegionSuccessor(getOperation(), getOperation()->getResults()));
44984499
else
44994500
regions.push_back(mlir::RegionSuccessor(elseRegion));
45004501
}
@@ -4513,7 +4514,7 @@ void fir::IfOp::getEntrySuccessorRegions(
45134514
if (!getElseRegion().empty())
45144515
regions.emplace_back(&getElseRegion());
45154516
else
4516-
regions.emplace_back(getResults());
4517+
regions.emplace_back(getOperation(), getOperation()->getResults());
45174518
}
45184519
}
45194520

mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
397397
/// itself.
398398
virtual void visitRegionBranchControlFlowTransfer(
399399
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
400-
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
400+
RegionSuccessor regionTo, const AbstractDenseLattice &after,
401401
AbstractDenseLattice *before) {
402402
meet(before, after);
403403
}
@@ -526,7 +526,7 @@ class DenseBackwardDataFlowAnalysis
526526
/// and "to" regions.
527527
virtual void visitRegionBranchControlFlowTransfer(
528528
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
529-
RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
529+
RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) {
530530
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
531531
branch, regionFrom, regionTo, after, before);
532532
}
@@ -571,7 +571,7 @@ class DenseBackwardDataFlowAnalysis
571571
}
572572
void visitRegionBranchControlFlowTransfer(
573573
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
574-
RegionBranchPoint regionTo, const AbstractDenseLattice &after,
574+
RegionSuccessor regionTo, const AbstractDenseLattice &after,
575575
AbstractDenseLattice *before) final {
576576
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
577577
static_cast<const LatticeT &>(after),

mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
286286
/// and propagating therefrom.
287287
virtual void
288288
visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
289-
RegionBranchPoint successor,
289+
RegionSuccessor successor,
290290
ArrayRef<AbstractSparseLattice *> lattices);
291291
};
292292

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,13 @@ def ForallOp : SCF_Op<"forall", [
644644

645645
/// Returns true if the mapping specified for this forall op is linear.
646646
bool usesLinearMapping();
647+
648+
/// RegionBranchOpInterface
649+
650+
OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
651+
return getInits();
652+
}
653+
647654
}];
648655
}
649656

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
2525

2626
def AlternativesOp : TransformDialectOp<"alternatives",
2727
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
28-
["getEntrySuccessorOperands", "getSuccessorRegions",
28+
["getEntrySuccessorOperands",
2929
"getRegionInvocationBounds"]>,
3030
DeclareOpInterfaceMethods<TransformOpInterface>,
3131
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -624,7 +624,7 @@ def ForeachOp : TransformDialectOp<"foreach",
624624
[DeclareOpInterfaceMethods<TransformOpInterface>,
625625
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
626626
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
627-
"getSuccessorRegions", "getEntrySuccessorOperands"]>,
627+
"getEntrySuccessorOperands"]>,
628628
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
629629
]> {
630630
let summary = "Executes the body for each element of the payload";
@@ -1237,7 +1237,7 @@ def SelectOp : TransformDialectOp<"select",
12371237

12381238
def SequenceOp : TransformDialectOp<"sequence",
12391239
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
1240-
["getEntrySuccessorOperands", "getSuccessorRegions",
1240+
["getEntrySuccessorOperands",
12411241
"getRegionInvocationBounds"]>,
12421242
MatchOpInterface,
12431243
DeclareOpInterfaceMethods<TransformOpInterface>,

mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
6363

6464
def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
6565
DeclareOpInterfaceMethods<RegionBranchOpInterface,
66-
["getEntrySuccessorOperands", "getSuccessorRegions",
66+
["getEntrySuccessorOperands",
6767
"getRegionInvocationBounds"]>,
6868
DeclareOpInterfaceMethods<TransformOpInterface>,
6969
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

mlir/include/mlir/IR/Diagnostics.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class MLIRContext;
2929
class Operation;
3030
class OperationName;
3131
class OpPrintingFlags;
32+
class OpWithFlags;
3233
class Type;
3334
class Value;
3435

@@ -199,6 +200,7 @@ class Diagnostic {
199200

200201
/// Stream in an Operation.
201202
Diagnostic &operator<<(Operation &op);
203+
Diagnostic &operator<<(OpWithFlags op);
202204
Diagnostic &operator<<(Operation *op) { return *this << *op; }
203205
/// Append an operation with the given printing flags.
204206
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);

mlir/include/mlir/IR/Operation.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,7 @@ class OpWithFlags {
11141114
: op(op), theFlags(flags) {}
11151115
OpPrintingFlags &flags() { return theFlags; }
11161116
const OpPrintingFlags &flags() const { return theFlags; }
1117+
Operation *getOperation() const { return op; }
11171118

11181119
private:
11191120
Operation *op;

mlir/include/mlir/IR/Region.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ class RegionRange
379379
friend RangeBaseT;
380380
};
381381

382+
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region &region);
383+
382384
} // namespace mlir
383385

384386
#endif // MLIR_IR_REGION_H

mlir/include/mlir/Interfaces/ControlFlowInterfaces.h

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@
1515
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
1616

1717
#include "mlir/IR/OpDefinition.h"
18+
#include "mlir/IR/Operation.h"
19+
#include "llvm/ADT/PointerUnion.h"
20+
#include "llvm/ADT/STLExtras.h"
21+
#include "llvm/Support/DebugLog.h"
22+
#include "llvm/Support/raw_ostream.h"
1823

1924
namespace mlir {
2025
class BranchOpInterface;
2126
class RegionBranchOpInterface;
27+
class RegionBranchTerminatorOpInterface;
2228

2329
/// This class models how operands are forwarded to block arguments in control
2430
/// flow. It consists of a number, denoting how many of the successors block
@@ -186,92 +192,108 @@ class RegionSuccessor {
186192
public:
187193
/// Initialize a successor that branches to another region of the parent
188194
/// operation.
195+
/// TODO: the default value for the regionInputs is somehow broken.
196+
/// A region successor should have its input correctly set.
189197
RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
190-
: region(region), inputs(regionInputs) {}
198+
: successor(region), inputs(regionInputs) {
199+
assert(region && "Region must not be null");
200+
}
191201
/// Initialize a successor that branches back to/out of the parent operation.
192-
RegionSuccessor(Operation::result_range results)
193-
: inputs(ValueRange(results)) {}
194-
/// Constructor with no arguments.
195-
RegionSuccessor() : inputs(ValueRange()) {}
202+
/// The target must be one of the recursive parent operations.
203+
RegionSuccessor(Operation *successorOp, Operation::result_range results)
204+
: successor(successorOp), inputs(ValueRange(results)) {
205+
assert(successorOp && "Successor op must not be null");
206+
}
196207

197208
/// Return the given region successor. Returns nullptr if the successor is the
198209
/// parent operation.
199-
Region *getSuccessor() const { return region; }
210+
Region *getSuccessor() const { return dyn_cast<Region *>(successor); }
200211

201212
/// Return true if the successor is the parent operation.
202-
bool isParent() const { return region == nullptr; }
213+
bool isParent() const { return isa<Operation *>(successor); }
203214

204215
/// Return the inputs to the successor that are remapped by the exit values of
205216
/// the current region.
206217
ValueRange getSuccessorInputs() const { return inputs; }
207218

219+
bool operator==(RegionSuccessor rhs) const {
220+
return successor == rhs.successor && inputs == rhs.inputs;
221+
}
222+
223+
friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
224+
return !(lhs == rhs);
225+
}
226+
208227
private:
209-
Region *region{nullptr};
228+
llvm::PointerUnion<Region *, Operation *> successor{nullptr};
210229
ValueRange inputs;
211230
};
212231

213232
/// This class represents a point being branched from in the methods of the
214233
/// `RegionBranchOpInterface`.
215234
/// One can branch from one of two kinds of places:
216235
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
217-
/// * A region within the parent operation.
236+
/// * A RegionBranchTerminatorOpInterface inside a region within the parent
237+
// operation.
218238
class RegionBranchPoint {
219239
public:
220240
/// Returns an instance of `RegionBranchPoint` representing the parent
221241
/// operation.
222242
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
223243

224-
/// Creates a `RegionBranchPoint` that branches from the given region.
225-
/// The pointer must not be null.
226-
RegionBranchPoint(Region *region) : maybeRegion(region) {
227-
assert(region && "Region must not be null");
228-
}
229-
230-
RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
244+
/// Creates a `RegionBranchPoint` that branches from the given terminator.
245+
inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);
231246

232247
/// Explicitly stops users from constructing with `nullptr`.
233248
RegionBranchPoint(std::nullptr_t) = delete;
234249

235-
/// Constructs a `RegionBranchPoint` from the the target of a
236-
/// `RegionSuccessor` instance.
237-
RegionBranchPoint(RegionSuccessor successor) {
238-
if (successor.isParent())
239-
maybeRegion = nullptr;
240-
else
241-
maybeRegion = successor.getSuccessor();
242-
}
243-
244-
/// Assigns a region being branched from.
245-
RegionBranchPoint &operator=(Region &region) {
246-
maybeRegion = &region;
247-
return *this;
248-
}
249-
250250
/// Returns true if branching from the parent op.
251-
bool isParent() const { return maybeRegion == nullptr; }
251+
bool isParent() const { return predecessor == nullptr; }
252252

253-
/// Returns the region if branching from a region.
253+
/// Returns the terminator if branching from a region.
254254
/// A null pointer otherwise.
255-
Region *getRegionOrNull() const { return maybeRegion; }
255+
Operation *getTerminatorPredecessorOrNull() const { return predecessor; }
256256

257257
/// Returns true if the two branch points are equal.
258258
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
259-
return lhs.maybeRegion == rhs.maybeRegion;
259+
return lhs.predecessor == rhs.predecessor;
260260
}
261261

262262
private:
263263
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
264-
constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
264+
constexpr RegionBranchPoint() = default;
265265

266266
/// Internal encoding. Uses nullptr for representing branching from the parent
267-
/// op and the region being branched from otherwise.
268-
Region *maybeRegion;
267+
/// op and the region terminator being branched from otherwise.
268+
Operation *predecessor = nullptr;
269269
};
270270

271271
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
272272
return !(lhs == rhs);
273273
}
274274

275+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
276+
RegionBranchPoint point) {
277+
if (point.isParent())
278+
return os << "<from parent>";
279+
return os << "<region #"
280+
<< point.getTerminatorPredecessorOrNull()
281+
->getParentRegion()
282+
->getRegionNumber()
283+
<< ", terminator "
284+
<< OpWithFlags(point.getTerminatorPredecessorOrNull(),
285+
OpPrintingFlags().skipRegions())
286+
<< ">";
287+
}
288+
289+
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
290+
RegionSuccessor successor) {
291+
if (successor.isParent())
292+
return os << "<to parent>";
293+
return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
294+
<< " with " << successor.getSuccessorInputs().size() << " inputs>";
295+
}
296+
275297
/// This class represents upper and lower bounds on the number of times a region
276298
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
277299
/// zero, but the upper bound may not be known.
@@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
348370
/// Include the generated interface declarations.
349371
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
350372

373+
namespace mlir {
374+
inline RegionBranchPoint::RegionBranchPoint(
375+
RegionBranchTerminatorOpInterface predecessor)
376+
: predecessor(predecessor.getOperation()) {}
377+
} // namespace mlir
378+
351379
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H

0 commit comments

Comments
 (0)