@@ -66,11 +66,12 @@ class SCFWhileLoop {
6666 SCFWhileLoop (cir::WhileOp op, cir::WhileOp::Adaptor adaptor,
6767 mlir::ConversionPatternRewriter *rewriter)
6868 : whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
69- void transferToSCFWhileOp ();
69+ mlir::scf::WhileOp transferToSCFWhileOp ();
7070
7171private:
7272 cir::WhileOp whileOp;
7373 cir::WhileOp::Adaptor adaptor;
74+ mlir::scf::WhileOp scfWhileOp;
7475 mlir::ConversionPatternRewriter *rewriter;
7576};
7677
@@ -356,7 +357,7 @@ void SCFLoop::transformToSCFWhileOp() {
356357 scfWhileOp.getAfterBody ()->end ());
357358}
358359
359- void SCFWhileLoop::transferToSCFWhileOp () {
360+ mlir::scf::WhileOp SCFWhileLoop::transferToSCFWhileOp () {
360361 auto scfWhileOp = rewriter->create <mlir::scf::WhileOp>(
361362 whileOp->getLoc (), whileOp->getResultTypes (), adaptor.getOperands ());
362363 rewriter->createBlock (&scfWhileOp.getBefore ());
@@ -367,6 +368,7 @@ void SCFWhileLoop::transferToSCFWhileOp() {
367368 rewriter->inlineBlockBefore (&whileOp.getBody ().front (),
368369 scfWhileOp.getAfterBody (),
369370 scfWhileOp.getAfterBody ()->end ());
371+ return scfWhileOp;
370372}
371373
372374void SCFDoLoop::transferToSCFWhileOp () {
@@ -412,14 +414,53 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
412414};
413415
414416class CIRWhileOpLowering : public mlir ::OpConversionPattern<cir::WhileOp> {
417+ void rewriteContinue (mlir::scf::WhileOp whileOp,
418+ mlir::ConversionPatternRewriter &rewriter) const {
419+ // Collect all ContinueOp inside this while.
420+ llvm::SmallVector<cir::ContinueOp> continues;
421+ whileOp->walk ([&](mlir::Operation *op) {
422+ if (auto continueOp = dyn_cast<ContinueOp>(op))
423+ continues.push_back (continueOp);
424+ });
425+
426+ if (continues.empty ())
427+ return ;
428+
429+ for (auto continueOp : continues) {
430+ // When the break is under an IfOp, a direct replacement of `scf.yield`
431+ // won't work: the yield would jump out of that IfOp instead. We might
432+ // need to change the whileOp itself to achieve the same effect.
433+ for (mlir::Operation *parent = continueOp->getParentOp ();
434+ parent != whileOp; parent = parent->getParentOp ()) {
435+ if (isa<mlir::scf::IfOp>(parent) || isa<cir::IfOp>(parent))
436+ llvm_unreachable (" NYI" );
437+ }
438+
439+ // Operations after this break has to be removed.
440+ for (mlir::Operation *runner = continueOp->getNextNode (); runner;) {
441+ mlir::Operation *next = runner->getNextNode ();
442+ runner->erase ();
443+ runner = next;
444+ }
445+
446+ // Blocks after this break also has to be removed.
447+ for (mlir::Block *block = continueOp->getBlock ()->getNextNode (); block;) {
448+ mlir::Block *next = block->getNextNode ();
449+ block->erase ();
450+ block = next;
451+ }
452+ }
453+ }
454+
415455public:
416456 using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
417457
418458 mlir::LogicalResult
419459 matchAndRewrite (cir::WhileOp op, OpAdaptor adaptor,
420460 mlir::ConversionPatternRewriter &rewriter) const override {
421461 SCFWhileLoop loop (op, adaptor, &rewriter);
422- loop.transferToSCFWhileOp ();
462+ auto whileOp = loop.transferToSCFWhileOp ();
463+ rewriteContinue (whileOp, rewriter);
423464 rewriter.eraseOp (op);
424465 return mlir::success ();
425466 }
0 commit comments