@@ -414,6 +414,62 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
414414};
415415
416416class CIRWhileOpLowering : public mlir ::OpConversionPattern<cir::WhileOp> {
417+ void rewriteContinueInIf (cir::IfOp ifOp, cir::ContinueOp continueOp,
418+ mlir::scf::WhileOp whileOp,
419+ mlir::ConversionPatternRewriter &rewriter) const {
420+ auto loc = ifOp->getLoc ();
421+
422+ rewriter.setInsertionPointToStart (whileOp.getAfterBody ());
423+ auto boolTy = rewriter.getType <BoolType>();
424+ auto boolPtrTy = rewriter.getType <PointerType>(boolTy);
425+ auto alignment = rewriter.getI64IntegerAttr (4 );
426+ auto condAlloca = rewriter.create <AllocaOp>(loc, boolPtrTy, boolTy,
427+ " condition" , alignment);
428+
429+ rewriter.setInsertionPoint (ifOp);
430+ auto negated = rewriter.create <UnaryOp>(loc, boolTy, UnaryOpKind::Not,
431+ ifOp.getCondition ());
432+ rewriter.create <StoreOp>(loc, negated, condAlloca);
433+
434+ // On each layer, surround everything after runner in its parent with a
435+ // guard: `if (!condAlloca)`.
436+ for (mlir::Operation *runner = ifOp; runner != whileOp;
437+ runner = runner->getParentOp ()) {
438+ rewriter.setInsertionPointAfter (runner);
439+ auto cond = rewriter.create <LoadOp>(
440+ loc, boolTy, condAlloca, /* isDeref=*/ false ,
441+ /* volatile=*/ false , /* nontemporal=*/ false , alignment,
442+ /* memorder=*/ cir::MemOrderAttr{}, /* tbaa=*/ cir::TBAAAttr{});
443+ auto ifnot =
444+ rewriter.create <IfOp>(loc, cond, /* withElseRegion=*/ false ,
445+ [&](mlir::OpBuilder &, mlir::Location) {
446+ /* Intentionally left empty */
447+ });
448+
449+ auto ®ion = ifnot.getThenRegion ();
450+ rewriter.setInsertionPointToEnd (®ion.back ());
451+ auto terminator = rewriter.create <YieldOp>(loc);
452+
453+ bool inserted = false ;
454+ for (mlir::Operation *op = ifnot->getNextNode (); op;) {
455+ // Don't move terminators in.
456+ if (isa<YieldOp>(op) || isa<ReturnOp>(op))
457+ break ;
458+
459+ mlir::Operation *next = op->getNextNode ();
460+ op->moveBefore (terminator);
461+ op = next;
462+ inserted = true ;
463+ }
464+ // Don't retain `if (!condAlloca)` when it's empty.
465+ if (!inserted)
466+ rewriter.eraseOp (ifnot);
467+ }
468+ rewriter.setInsertionPoint (continueOp);
469+ rewriter.create <mlir::scf::YieldOp>(continueOp->getLoc ());
470+ rewriter.eraseOp (continueOp);
471+ }
472+
417473 void rewriteContinue (mlir::scf::WhileOp whileOp,
418474 mlir::ConversionPatternRewriter &rewriter) const {
419475 // Collect all ContinueOp inside this while.
@@ -427,23 +483,29 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
427483 return ;
428484
429485 for (auto continueOp : continues) {
430- // When the break is under an IfOp, a direct replacement of `scf.yield`
431- // won't work: the yield would jump out of that IfOp instead. We might
432- // need to change the whileOp itself to achieve the same effect.
486+ // When the ContinueOp is under an IfOp, a direct replacement of
487+ // `scf.yield` won't work: the yield would jump out of that IfOp instead.
488+ // We might need to change the WhileOp itself to achieve the same effect.
489+ bool rewritten = false ;
433490 for (mlir::Operation *parent = continueOp->getParentOp ();
434491 parent != whileOp; parent = parent->getParentOp ()) {
435- if (isa<mlir::scf::IfOp>(parent) || isa<cir::IfOp>(parent))
436- llvm_unreachable (" NYI" );
492+ if (auto ifOp = dyn_cast<cir::IfOp>(parent)) {
493+ rewriteContinueInIf (ifOp, continueOp, whileOp, rewriter);
494+ rewritten = true ;
495+ break ;
496+ }
437497 }
498+ if (rewritten)
499+ continue ;
438500
439- // Operations after this break has to be removed.
501+ // Operations after this ContinueOp has to be removed.
440502 for (mlir::Operation *runner = continueOp->getNextNode (); runner;) {
441503 mlir::Operation *next = runner->getNextNode ();
442504 runner->erase ();
443505 runner = next;
444506 }
445507
446- // Blocks after this break also has to be removed.
508+ // Blocks after this ContinueOp also has to be removed.
447509 for (mlir::Block *block = continueOp->getBlock ()->getNextNode (); block;) {
448510 mlir::Block *next = block->getNextNode ();
449511 block->erase ();
0 commit comments