@@ -528,7 +528,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
528
528
for (auto continueOp : continues) {
529
529
bool nested = false ;
530
530
// When there is another loop between this WhileOp and the ContinueOp,
531
- // we shouldn't change that loop instead.
531
+ // we should change that loop instead.
532
532
for (mlir::Operation *parent = continueOp->getParentOp ();
533
533
parent != whileOp; parent = parent->getParentOp ()) {
534
534
if (isa<WhileOp>(parent)) {
@@ -570,6 +570,81 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
570
570
}
571
571
}
572
572
573
+ void rewriteBreak (mlir::scf::WhileOp whileOp,
574
+ mlir::ConversionPatternRewriter &rewriter) const {
575
+ // Collect all BreakOp inside this while.
576
+ llvm::SmallVector<cir::BreakOp> breaks;
577
+ whileOp->walk ([&](mlir::Operation *op) {
578
+ if (auto breakOp = dyn_cast<BreakOp>(op))
579
+ breaks.push_back (breakOp);
580
+ });
581
+
582
+ if (breaks.empty ())
583
+ return ;
584
+
585
+ for (auto breakOp : breaks) {
586
+ bool nested = false ;
587
+ // When there is another loop between this WhileOp and the BreakOp,
588
+ // we should change that loop instead.
589
+ for (mlir::Operation *parent = breakOp->getParentOp (); parent != whileOp;
590
+ parent = parent->getParentOp ()) {
591
+ if (isa<WhileOp>(parent)) {
592
+ nested = true ;
593
+ break ;
594
+ }
595
+ }
596
+ if (nested)
597
+ continue ;
598
+
599
+ // Similar to the case of ContinueOp, when there is an `IfOp`,
600
+ // we need to take special care.
601
+ for (mlir::Operation *parent = breakOp->getParentOp (); parent != whileOp;
602
+ parent = parent->getParentOp ()) {
603
+ if (auto ifOp = dyn_cast<cir::IfOp>(parent))
604
+ llvm_unreachable (" NYI" );
605
+ }
606
+
607
+ // Operations after this BreakOp has to be removed.
608
+ for (mlir::Operation *runner = breakOp->getNextNode (); runner;) {
609
+ mlir::Operation *next = runner->getNextNode ();
610
+ runner->erase ();
611
+ runner = next;
612
+ }
613
+
614
+ // Blocks after this BreakOp also has to be removed.
615
+ for (mlir::Block *block = breakOp->getBlock ()->getNextNode (); block;) {
616
+ mlir::Block *next = block->getNextNode ();
617
+ block->erase ();
618
+ block = next;
619
+ }
620
+
621
+ // We know this BreakOp isn't nested in any IfOp.
622
+ // Therefore, the loop is executed only once.
623
+ // We pull everything out of the loop.
624
+
625
+ auto &beforeOps = whileOp.getBeforeBody ()->getOperations ();
626
+ for (mlir::Operation *op = &*beforeOps.begin (); op;) {
627
+ if (isa<ConditionOp>(op))
628
+ break ;
629
+ auto *next = op->getNextNode ();
630
+ op->moveBefore (whileOp);
631
+ op = next;
632
+ }
633
+
634
+ auto &afterOps = whileOp.getAfterBody ()->getOperations ();
635
+ for (mlir::Operation *op = &*afterOps.begin (); op;) {
636
+ if (isa<YieldOp>(op))
637
+ break ;
638
+ auto *next = op->getNextNode ();
639
+ op->moveBefore (whileOp);
640
+ op = next;
641
+ }
642
+
643
+ // The loop itself should now be removed.
644
+ rewriter.eraseOp (whileOp);
645
+ }
646
+ }
647
+
573
648
public:
574
649
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
575
650
@@ -579,6 +654,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
579
654
SCFWhileLoop loop (op, adaptor, &rewriter);
580
655
auto whileOp = loop.transferToSCFWhileOp ();
581
656
rewriteContinue (whileOp, rewriter);
657
+ rewriteBreak (whileOp, rewriter);
582
658
rewriter.eraseOp (op);
583
659
return mlir::success ();
584
660
}
0 commit comments