Skip to content

Commit 3235864

Browse files
[CIR][ThroughMLIR] Lower For to While when it contains break/continue (#1716)
We lower `cir::ForOp` into `cir::WhileOp` (rather than `scf::WhileOp`) when it contains break and continue. This is to reuse the rewriting functions already implemented for while loops. Co-authored-by: Yue Huang <[email protected]>
1 parent 9642d52 commit 3235864

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1515
#include "mlir/Dialect/SCF/IR/SCF.h"
1616
#include "mlir/IR/Builders.h"
17-
#include "mlir/IR/BuiltinOps.h"
1817
#include "mlir/IR/Location.h"
1918
#include "mlir/IR/ValueRange.h"
2019
#include "mlir/Pass/PassManager.h"
@@ -40,6 +39,7 @@ class SCFLoop {
4039
mlir::Value getLowerBound() { return lowerBound; }
4140
mlir::Value getUpperBound() { return upperBound; }
4241
bool isCanonical() { return canonical; }
42+
bool hasBreakOrContinue() { return hasBreakContinue; }
4343

4444
// Returns true if successfully finds both step and induction variable.
4545
mlir::LogicalResult findStepAndIV();
@@ -50,13 +50,15 @@ class SCFLoop {
5050
mlir::Value plusConstant(mlir::Value v, mlir::Location loc, int addend);
5151
void transferToSCFForOp();
5252
void transformToSCFWhileOp();
53+
void transformToCIRWhileOp(); // TODO
5354

5455
private:
5556
cir::ForOp forOp;
5657
cir::CmpOp cmpOp;
5758
mlir::Value ivAddr, lowerBound = nullptr, upperBound = nullptr;
5859
mlir::ConversionPatternRewriter *rewriter;
5960
int64_t step = 0;
61+
bool hasBreakContinue = false;
6062
bool canonical = true;
6163
};
6264

@@ -251,6 +253,16 @@ mlir::Value SCFLoop::findIVInitValue() {
251253
}
252254

253255
void SCFLoop::analysis() {
256+
// Check whether this ForOp contains break or continue.
257+
forOp.walk([&](mlir::Operation *op) {
258+
if (isa<BreakOp, ContinueOp>(op))
259+
hasBreakContinue = true;
260+
});
261+
if (hasBreakContinue) {
262+
canonical = false;
263+
return;
264+
}
265+
254266
canonical = mlir::succeeded(findStepAndIV());
255267
if (!canonical)
256268
return;
@@ -356,6 +368,30 @@ void SCFLoop::transformToSCFWhileOp() {
356368
scfWhileOp.getAfterBody()->end());
357369
}
358370

371+
void SCFLoop::transformToCIRWhileOp() {
372+
auto cirWhileOp = rewriter->create<cir::WhileOp>(
373+
forOp->getLoc(), forOp->getResultTypes(), mlir::ValueRange());
374+
rewriter->createBlock(&cirWhileOp.getCond());
375+
rewriter->createBlock(&cirWhileOp.getBody());
376+
377+
mlir::Block &condFront = cirWhileOp.getCond().front();
378+
rewriter->inlineBlockBefore(&forOp.getCond().front(), &condFront,
379+
condFront.end());
380+
381+
mlir::Block &bodyFront = cirWhileOp.getBody().front();
382+
rewriter->inlineBlockBefore(&forOp.getBody().front(), &bodyFront,
383+
bodyFront.end());
384+
385+
// The last operation of `bodyFront` must be a terminator.
386+
// We need to place the step just before it.
387+
rewriter->inlineBlockBefore(&forOp.getStep().front(), &bodyFront,
388+
--bodyFront.end());
389+
// This also introduces another terminator before the one of the body.
390+
// We need to erase it.
391+
auto &stepTerminator = *--bodyFront.end();
392+
rewriter->eraseOp(&stepTerminator);
393+
}
394+
359395
mlir::scf::WhileOp SCFWhileLoop::transferToSCFWhileOp() {
360396
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
361397
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
@@ -400,6 +436,14 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
400436
mlir::ConversionPatternRewriter &rewriter) const override {
401437
SCFLoop loop(op, &rewriter);
402438
loop.analysis();
439+
// Breaks and continues are handled in lowering of cir::WhileOp.
440+
// We can reuse the code by transforming this ForOp into WhileOp.
441+
if (loop.hasBreakOrContinue()) {
442+
loop.transformToCIRWhileOp();
443+
rewriter.eraseOp(op);
444+
return mlir::success();
445+
}
446+
403447
if (!loop.isCanonical()) {
404448
loop.transformToSCFWhileOp();
405449
rewriter.eraseOp(op);
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void for_continue() {
5+
for (int i = 0; i < 100; i++)
6+
continue;
7+
8+
// CHECK: scf.while : () -> () {
9+
// CHECK: %[[IV:.+]] = memref.load %alloca[]
10+
// CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[IV]], %c100_i32
11+
// CHECK: scf.condition(%[[CMP]])
12+
// CHECK: } do {
13+
// CHECK: %[[IV2:.+]] = memref.load %alloca[]
14+
// CHECK: %[[ONE:.+]] = arith.constant 1
15+
// CHECK: %[[CMP2:.+]] = arith.addi %[[IV2]], %[[ONE]]
16+
// CHECK: memref.store %[[CMP2]], %alloca[]
17+
// CHECK: scf.yield
18+
// CHECK: }
19+
}

0 commit comments

Comments
 (0)