14
14
#include " mlir/Dialect/MemRef/IR/MemRef.h"
15
15
#include " mlir/Dialect/SCF/IR/SCF.h"
16
16
#include " mlir/IR/Builders.h"
17
- #include " mlir/IR/BuiltinOps.h"
18
17
#include " mlir/IR/Location.h"
19
18
#include " mlir/IR/ValueRange.h"
20
19
#include " mlir/Pass/PassManager.h"
@@ -40,6 +39,7 @@ class SCFLoop {
40
39
mlir::Value getLowerBound () { return lowerBound; }
41
40
mlir::Value getUpperBound () { return upperBound; }
42
41
bool isCanonical () { return canonical; }
42
+ bool hasBreakOrContinue () { return hasBreakContinue; }
43
43
44
44
// Returns true if successfully finds both step and induction variable.
45
45
mlir::LogicalResult findStepAndIV ();
@@ -50,13 +50,15 @@ class SCFLoop {
50
50
mlir::Value plusConstant (mlir::Value v, mlir::Location loc, int addend);
51
51
void transferToSCFForOp ();
52
52
void transformToSCFWhileOp ();
53
+ void transformToCIRWhileOp (); // TODO
53
54
54
55
private:
55
56
cir::ForOp forOp;
56
57
cir::CmpOp cmpOp;
57
58
mlir::Value ivAddr, lowerBound = nullptr , upperBound = nullptr ;
58
59
mlir::ConversionPatternRewriter *rewriter;
59
60
int64_t step = 0 ;
61
+ bool hasBreakContinue = false ;
60
62
bool canonical = true ;
61
63
};
62
64
@@ -251,6 +253,16 @@ mlir::Value SCFLoop::findIVInitValue() {
251
253
}
252
254
253
255
void SCFLoop::analysis () {
256
+ // Check whether this ForOp contains break or continue.
257
+ forOp.walk ([&](mlir::Operation *op) {
258
+ if (isa<BreakOp, ContinueOp>(op))
259
+ hasBreakContinue = true ;
260
+ });
261
+ if (hasBreakContinue) {
262
+ canonical = false ;
263
+ return ;
264
+ }
265
+
254
266
canonical = mlir::succeeded (findStepAndIV ());
255
267
if (!canonical)
256
268
return ;
@@ -356,6 +368,30 @@ void SCFLoop::transformToSCFWhileOp() {
356
368
scfWhileOp.getAfterBody ()->end ());
357
369
}
358
370
371
+ void SCFLoop::transformToCIRWhileOp () {
372
+ auto cirWhileOp = rewriter->create <cir::WhileOp>(
373
+ forOp->getLoc (), forOp->getResultTypes (), mlir::ValueRange ());
374
+ rewriter->createBlock (&cirWhileOp.getCond ());
375
+ rewriter->createBlock (&cirWhileOp.getBody ());
376
+
377
+ mlir::Block &condFront = cirWhileOp.getCond ().front ();
378
+ rewriter->inlineBlockBefore (&forOp.getCond ().front (), &condFront,
379
+ condFront.end ());
380
+
381
+ mlir::Block &bodyFront = cirWhileOp.getBody ().front ();
382
+ rewriter->inlineBlockBefore (&forOp.getBody ().front (), &bodyFront,
383
+ bodyFront.end ());
384
+
385
+ // The last operation of `bodyFront` must be a terminator.
386
+ // We need to place the step just before it.
387
+ rewriter->inlineBlockBefore (&forOp.getStep ().front (), &bodyFront,
388
+ --bodyFront.end ());
389
+ // This also introduces another terminator before the one of the body.
390
+ // We need to erase it.
391
+ auto &stepTerminator = *--bodyFront.end ();
392
+ rewriter->eraseOp (&stepTerminator);
393
+ }
394
+
359
395
mlir::scf::WhileOp SCFWhileLoop::transferToSCFWhileOp () {
360
396
auto scfWhileOp = rewriter->create <mlir::scf::WhileOp>(
361
397
whileOp->getLoc (), whileOp->getResultTypes (), adaptor.getOperands ());
@@ -400,6 +436,14 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
400
436
mlir::ConversionPatternRewriter &rewriter) const override {
401
437
SCFLoop loop (op, &rewriter);
402
438
loop.analysis ();
439
+ // Breaks and continues are handled in lowering of cir::WhileOp.
440
+ // We can reuse the code by transforming this ForOp into WhileOp.
441
+ if (loop.hasBreakOrContinue ()) {
442
+ loop.transformToCIRWhileOp ();
443
+ rewriter.eraseOp (op);
444
+ return mlir::success ();
445
+ }
446
+
403
447
if (!loop.isCanonical ()) {
404
448
loop.transformToSCFWhileOp ();
405
449
rewriter.eraseOp (op);
0 commit comments