| 
15 | 15 | #define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H  | 
16 | 16 | 
 
  | 
17 | 17 | #include "mlir/IR/OpDefinition.h"  | 
 | 18 | +#include "mlir/IR/Operation.h"  | 
 | 19 | +#include "llvm/ADT/PointerUnion.h"  | 
 | 20 | +#include "llvm/ADT/STLExtras.h"  | 
 | 21 | +#include "llvm/Support/DebugLog.h"  | 
 | 22 | +#include "llvm/Support/raw_ostream.h"  | 
18 | 23 | 
 
  | 
19 | 24 | namespace mlir {  | 
20 | 25 | class BranchOpInterface;  | 
21 | 26 | class RegionBranchOpInterface;  | 
 | 27 | +class RegionBranchTerminatorOpInterface;  | 
22 | 28 | 
 
  | 
23 | 29 | /// This class models how operands are forwarded to block arguments in control  | 
24 | 30 | /// flow. It consists of a number, denoting how many of the successors block  | 
@@ -186,92 +192,108 @@ class RegionSuccessor {  | 
186 | 192 | public:  | 
187 | 193 |   /// Initialize a successor that branches to another region of the parent  | 
188 | 194 |   /// operation.  | 
 | 195 | +  /// TODO: the default value for the regionInputs is somehow broken.  | 
 | 196 | +  /// A region successor should have its input correctly set.  | 
189 | 197 |   RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})  | 
190 |  | -      : region(region), inputs(regionInputs) {}  | 
 | 198 | +      : successor(region), inputs(regionInputs) {  | 
 | 199 | +    assert(region && "Region must not be null");  | 
 | 200 | +  }  | 
191 | 201 |   /// Initialize a successor that branches back to/out of the parent operation.  | 
192 |  | -  RegionSuccessor(Operation::result_range results)  | 
193 |  | -      : inputs(ValueRange(results)) {}  | 
194 |  | -  /// Constructor with no arguments.  | 
195 |  | -  RegionSuccessor() : inputs(ValueRange()) {}  | 
 | 202 | +  /// The target must be one of the recursive parent operations.  | 
 | 203 | +  RegionSuccessor(Operation *successorOp, Operation::result_range results)  | 
 | 204 | +      : successor(successorOp), inputs(ValueRange(results)) {  | 
 | 205 | +    assert(successorOp && "Successor op must not be null");  | 
 | 206 | +  }  | 
196 | 207 | 
 
  | 
197 | 208 |   /// Return the given region successor. Returns nullptr if the successor is the  | 
198 | 209 |   /// parent operation.  | 
199 |  | -  Region *getSuccessor() const { return region; }  | 
 | 210 | +  Region *getSuccessor() const { return dyn_cast<Region *>(successor); }  | 
200 | 211 | 
 
  | 
201 | 212 |   /// Return true if the successor is the parent operation.  | 
202 |  | -  bool isParent() const { return region == nullptr; }  | 
 | 213 | +  bool isParent() const { return isa<Operation *>(successor); }  | 
203 | 214 | 
 
  | 
204 | 215 |   /// Return the inputs to the successor that are remapped by the exit values of  | 
205 | 216 |   /// the current region.  | 
206 | 217 |   ValueRange getSuccessorInputs() const { return inputs; }  | 
207 | 218 | 
 
  | 
 | 219 | +  bool operator==(RegionSuccessor rhs) const {  | 
 | 220 | +    return successor == rhs.successor && inputs == rhs.inputs;  | 
 | 221 | +  }  | 
 | 222 | + | 
 | 223 | +  friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {  | 
 | 224 | +    return !(lhs == rhs);  | 
 | 225 | +  }  | 
 | 226 | + | 
208 | 227 | private:  | 
209 |  | -  Region *region{nullptr};  | 
 | 228 | +  llvm::PointerUnion<Region *, Operation *> successor{nullptr};  | 
210 | 229 |   ValueRange inputs;  | 
211 | 230 | };  | 
212 | 231 | 
 
  | 
213 | 232 | /// This class represents a point being branched from in the methods of the  | 
214 | 233 | /// `RegionBranchOpInterface`.  | 
215 | 234 | /// One can branch from one of two kinds of places:  | 
216 | 235 | /// * The parent operation (aka the `RegionBranchOpInterface` implementation)  | 
217 |  | -/// * A region within the parent operation.  | 
 | 236 | +/// * A RegionBranchTerminatorOpInterface inside a region within the parent  | 
 | 237 | +//    operation.  | 
218 | 238 | class RegionBranchPoint {  | 
219 | 239 | public:  | 
220 | 240 |   /// Returns an instance of `RegionBranchPoint` representing the parent  | 
221 | 241 |   /// operation.  | 
222 | 242 |   static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }  | 
223 | 243 | 
 
  | 
224 |  | -  /// Creates a `RegionBranchPoint` that branches from the given region.  | 
225 |  | -  /// The pointer must not be null.  | 
226 |  | -  RegionBranchPoint(Region *region) : maybeRegion(region) {  | 
227 |  | -    assert(region && "Region must not be null");  | 
228 |  | -  }  | 
229 |  | - | 
230 |  | -  RegionBranchPoint(Region ®ion) : RegionBranchPoint(®ion) {}  | 
 | 244 | +  /// Creates a `RegionBranchPoint` that branches from the given terminator.  | 
 | 245 | +  inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);  | 
231 | 246 | 
 
  | 
232 | 247 |   /// Explicitly stops users from constructing with `nullptr`.  | 
233 | 248 |   RegionBranchPoint(std::nullptr_t) = delete;  | 
234 | 249 | 
 
  | 
235 |  | -  /// Constructs a `RegionBranchPoint` from the the target of a  | 
236 |  | -  /// `RegionSuccessor` instance.  | 
237 |  | -  RegionBranchPoint(RegionSuccessor successor) {  | 
238 |  | -    if (successor.isParent())  | 
239 |  | -      maybeRegion = nullptr;  | 
240 |  | -    else  | 
241 |  | -      maybeRegion = successor.getSuccessor();  | 
242 |  | -  }  | 
243 |  | - | 
244 |  | -  /// Assigns a region being branched from.  | 
245 |  | -  RegionBranchPoint &operator=(Region ®ion) {  | 
246 |  | -    maybeRegion = ®ion;  | 
247 |  | -    return *this;  | 
248 |  | -  }  | 
249 |  | - | 
250 | 250 |   /// Returns true if branching from the parent op.  | 
251 |  | -  bool isParent() const { return maybeRegion == nullptr; }  | 
 | 251 | +  bool isParent() const { return predecessor == nullptr; }  | 
252 | 252 | 
 
  | 
253 |  | -  /// Returns the region if branching from a region.  | 
 | 253 | +  /// Returns the terminator if branching from a region.  | 
254 | 254 |   /// A null pointer otherwise.  | 
255 |  | -  Region *getRegionOrNull() const { return maybeRegion; }  | 
 | 255 | +  Operation *getTerminatorPredecessorOrNull() const { return predecessor; }  | 
256 | 256 | 
 
  | 
257 | 257 |   /// Returns true if the two branch points are equal.  | 
258 | 258 |   friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {  | 
259 |  | -    return lhs.maybeRegion == rhs.maybeRegion;  | 
 | 259 | +    return lhs.predecessor == rhs.predecessor;  | 
260 | 260 |   }  | 
261 | 261 | 
 
  | 
262 | 262 | private:  | 
263 | 263 |   // Private constructor to encourage the use of `RegionBranchPoint::parent`.  | 
264 |  | -  constexpr RegionBranchPoint() : maybeRegion(nullptr) {}  | 
 | 264 | +  constexpr RegionBranchPoint() = default;  | 
265 | 265 | 
 
  | 
266 | 266 |   /// Internal encoding. Uses nullptr for representing branching from the parent  | 
267 |  | -  /// op and the region being branched from otherwise.  | 
268 |  | -  Region *maybeRegion;  | 
 | 267 | +  /// op and the region terminator being branched from otherwise.  | 
 | 268 | +  Operation *predecessor = nullptr;  | 
269 | 269 | };  | 
270 | 270 | 
 
  | 
271 | 271 | inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {  | 
272 | 272 |   return !(lhs == rhs);  | 
273 | 273 | }  | 
274 | 274 | 
 
  | 
 | 275 | +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,  | 
 | 276 | +                                     RegionBranchPoint point) {  | 
 | 277 | +  if (point.isParent())  | 
 | 278 | +    return os << "<from parent>";  | 
 | 279 | +  return os << "<region #"  | 
 | 280 | +            << point.getTerminatorPredecessorOrNull()  | 
 | 281 | +                   ->getParentRegion()  | 
 | 282 | +                   ->getRegionNumber()  | 
 | 283 | +            << ", terminator "  | 
 | 284 | +            << OpWithFlags(point.getTerminatorPredecessorOrNull(),  | 
 | 285 | +                           OpPrintingFlags().skipRegions())  | 
 | 286 | +            << ">";  | 
 | 287 | +}  | 
 | 288 | + | 
 | 289 | +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,  | 
 | 290 | +                                     RegionSuccessor successor) {  | 
 | 291 | +  if (successor.isParent())  | 
 | 292 | +    return os << "<to parent>";  | 
 | 293 | +  return os << "<to region #" << successor.getSuccessor()->getRegionNumber()  | 
 | 294 | +            << " with " << successor.getSuccessorInputs().size() << " inputs>";  | 
 | 295 | +}  | 
 | 296 | + | 
275 | 297 | /// This class represents upper and lower bounds on the number of times a region  | 
276 | 298 | /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least  | 
277 | 299 | /// zero, but the upper bound may not be known.  | 
@@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {  | 
348 | 370 | /// Include the generated interface declarations.  | 
349 | 371 | #include "mlir/Interfaces/ControlFlowInterfaces.h.inc"  | 
350 | 372 | 
 
  | 
 | 373 | +namespace mlir {  | 
 | 374 | +inline RegionBranchPoint::RegionBranchPoint(  | 
 | 375 | +    RegionBranchTerminatorOpInterface predecessor)  | 
 | 376 | +    : predecessor(predecessor.getOperation()) {}  | 
 | 377 | +} // namespace mlir  | 
 | 378 | + | 
351 | 379 | #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H  | 
0 commit comments