Skip to content

Commit 8c31280

Browse files
committed
[CIR][ThroughMLIR] Lower WhileOp with break
1 parent 33c739f commit 8c31280

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
528528
for (auto continueOp : continues) {
529529
bool nested = false;
530530
// When there is another loop between this WhileOp and the ContinueOp,
531-
// we shouldn't change that loop instead.
531+
// we should change that loop instead.
532532
for (mlir::Operation *parent = continueOp->getParentOp();
533533
parent != whileOp; parent = parent->getParentOp()) {
534534
if (isa<WhileOp>(parent)) {
@@ -570,6 +570,81 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
570570
}
571571
}
572572

573+
void rewriteBreak(mlir::scf::WhileOp whileOp,
574+
mlir::ConversionPatternRewriter &rewriter) const {
575+
// Collect all BreakOp inside this while.
576+
llvm::SmallVector<cir::BreakOp> breaks;
577+
whileOp->walk([&](mlir::Operation *op) {
578+
if (auto breakOp = dyn_cast<BreakOp>(op))
579+
breaks.push_back(breakOp);
580+
});
581+
582+
if (breaks.empty())
583+
return;
584+
585+
for (auto breakOp : breaks) {
586+
bool nested = false;
587+
// When there is another loop between this WhileOp and the BreakOp,
588+
// we should change that loop instead.
589+
for (mlir::Operation *parent = breakOp->getParentOp(); parent != whileOp;
590+
parent = parent->getParentOp()) {
591+
if (isa<WhileOp>(parent)) {
592+
nested = true;
593+
break;
594+
}
595+
}
596+
if (nested)
597+
continue;
598+
599+
// Similar to the case of ContinueOp, when there is an `IfOp`,
600+
// we need to take special care.
601+
for (mlir::Operation *parent = breakOp->getParentOp(); parent != whileOp;
602+
parent = parent->getParentOp()) {
603+
if (auto ifOp = dyn_cast<cir::IfOp>(parent))
604+
llvm_unreachable("NYI");
605+
}
606+
607+
// Operations after this BreakOp has to be removed.
608+
for (mlir::Operation *runner = breakOp->getNextNode(); runner;) {
609+
mlir::Operation *next = runner->getNextNode();
610+
runner->erase();
611+
runner = next;
612+
}
613+
614+
// Blocks after this BreakOp also has to be removed.
615+
for (mlir::Block *block = breakOp->getBlock()->getNextNode(); block;) {
616+
mlir::Block *next = block->getNextNode();
617+
block->erase();
618+
block = next;
619+
}
620+
621+
// We know this BreakOp isn't nested in any IfOp.
622+
// Therefore, the loop is executed only once.
623+
// We pull everything out of the loop.
624+
625+
auto &beforeOps = whileOp.getBeforeBody()->getOperations();
626+
for (mlir::Operation *op = &*beforeOps.begin(); op;) {
627+
if (isa<ConditionOp>(op))
628+
break;
629+
auto *next = op->getNextNode();
630+
op->moveBefore(whileOp);
631+
op = next;
632+
}
633+
634+
auto &afterOps = whileOp.getAfterBody()->getOperations();
635+
for (mlir::Operation *op = &*afterOps.begin(); op;) {
636+
if (isa<YieldOp>(op))
637+
break;
638+
auto *next = op->getNextNode();
639+
op->moveBefore(whileOp);
640+
op = next;
641+
}
642+
643+
// The loop itself should now be removed.
644+
rewriter.eraseOp(whileOp);
645+
}
646+
}
647+
573648
public:
574649
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
575650

@@ -579,6 +654,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
579654
SCFWhileLoop loop(op, adaptor, &rewriter);
580655
auto whileOp = loop.transferToSCFWhileOp();
581656
rewriteContinue(whileOp, rewriter);
657+
rewriteBreak(whileOp, rewriter);
582658
rewriter.eraseOp(op);
583659
return mlir::success();
584660
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void while_break() {
5+
int i = 0;
6+
while (i < 100) {
7+
i++;
8+
break;
9+
i++;
10+
}
11+
// This should be compiled into the condition `i < 100` and a single `i++`,
12+
// without the while-loop.
13+
14+
// CHECK: memref.alloca_scope {
15+
// CHECK: %[[IV:.+]] = memref.load %alloca[]
16+
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
17+
// CHECK: %[[_:.+]] = arith.cmpi slt, %[[IV]], %[[HUNDRED]]
18+
// CHECK: memref.alloca_scope {
19+
// CHECK: %[[IV2:.+]] = memref.load %alloca[]
20+
// CHECK: %[[ONE:.+]] = arith.constant 1
21+
// CHECK: %[[INCR:.+]] = arith.addi %[[IV2]], %[[ONE]]
22+
// CHECK: memref.store %[[INCR]], %alloca[]
23+
// CHECK: }
24+
// CHECK: }
25+
}

0 commit comments

Comments
 (0)