Skip to content

Commit 5c2209a

Browse files
terapines open source contributor 2AdUhTkJm
authored andcommitted
[CIR][ThroughMLIR] Lower uncanonicalized fors to whiles (llvm#1644)
The transformation functions are all named `transferToXXXOp`. Are those typos? Co-authored-by: Yue Huang <[email protected]>
1 parent c80903e commit 5c2209a

File tree

3 files changed

+61
-29
lines changed

3 files changed

+61
-29
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,15 @@
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1515
#include "mlir/Dialect/SCF/IR/SCF.h"
16-
#include "mlir/Dialect/SCF/Transforms/Passes.h"
1716
#include "mlir/IR/Builders.h"
18-
#include "mlir/IR/BuiltinDialect.h"
19-
#include "mlir/IR/BuiltinTypes.h"
2017
#include "mlir/IR/Location.h"
2118
#include "mlir/IR/ValueRange.h"
22-
#include "mlir/Pass/Pass.h"
2319
#include "mlir/Pass/PassManager.h"
2420
#include "mlir/Support/LogicalResult.h"
2521
#include "mlir/Transforms/DialectConversion.h"
2622
#include "clang/CIR/Dialect/IR/CIRDialect.h"
2723
#include "clang/CIR/Dialect/IR/CIRTypes.h"
2824
#include "clang/CIR/LowerToMLIR.h"
29-
#include "clang/CIR/Passes.h"
3025
#include "llvm/ADT/TypeSwitch.h"
3126

3227
using namespace cir;
@@ -52,6 +47,7 @@ class SCFLoop {
5247

5348
mlir::Value plusConstant(mlir::Value v, mlir::Location loc, int addend);
5449
void transferToSCFForOp();
50+
void transformToSCFWhileOp();
5551

5652
private:
5753
cir::ForOp forOp;
@@ -209,21 +205,21 @@ cir::CmpOp SCFLoop::findCmpOp() {
209205
}
210206
}
211207
if (!cmpOp)
212-
llvm_unreachable("Can't find loop CmpOp");
208+
return nullptr;
213209

214210
auto type = cmpOp.getLhs().getType();
215211
if (!mlir::isa<cir::IntType>(type))
216-
llvm_unreachable("Non-integer type IV is not supported");
212+
return nullptr;
217213

218214
auto *lhsDefOp = cmpOp.getLhs().getDefiningOp();
219215
if (!lhsDefOp)
220-
llvm_unreachable("Can't find IV load");
216+
return nullptr;
221217
if (!isIVLoad(lhsDefOp, ivAddr))
222-
llvm_unreachable("cmpOp LHS is not IV");
218+
return nullptr;
223219

224220
if (cmpOp.getKind() != cir::CmpOpKind::le &&
225221
cmpOp.getKind() != cir::CmpOpKind::lt)
226-
llvm_unreachable("Not support lowering other than le or lt comparison");
222+
return nullptr;
227223

228224
return cmpOp;
229225
}
@@ -253,30 +249,40 @@ mlir::Value SCFLoop::findIVInitValue() {
253249

254250
void SCFLoop::analysis() {
255251
canonical = mlir::succeeded(findStepAndIV());
256-
if (!canonical) {
257-
mlir::emitError(forOp.getLoc(),
258-
"cannot handle non-constant step for induction variable");
252+
if (!canonical)
259253
return;
260-
}
261254

262255
cmpOp = findCmpOp();
263-
auto IVInit = findIVInitValue();
256+
if (!cmpOp) {
257+
canonical = false;
258+
return;
259+
}
260+
261+
auto ivInit = findIVInitValue();
262+
if (!ivInit) {
263+
canonical = false;
264+
return;
265+
}
266+
264267
// The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
265268
// So we could get the value by getRemappedValue.
266-
auto IVEndBound = rewriter->getRemappedValue(cmpOp.getRhs());
267-
// If the loop end bound is not loop invariant and can't be hoisted.
268-
// The following assertion will be triggerred.
269-
assert(IVEndBound && "can't find IV end boundary");
269+
auto ivEndBound = rewriter->getRemappedValue(cmpOp.getRhs());
270+
// If the loop end bound is not loop invariant and can't be hoisted,
271+
// then this is not a canonical loop.
272+
if (!ivEndBound) {
273+
canonical = false;
274+
return;
275+
}
270276

271277
if (step > 0) {
272-
lowerBound = IVInit;
278+
lowerBound = ivInit;
273279
if (cmpOp.getKind() == cir::CmpOpKind::lt)
274-
upperBound = IVEndBound;
280+
upperBound = ivEndBound;
275281
else if (cmpOp.getKind() == cir::CmpOpKind::le)
276-
upperBound = plusConstant(IVEndBound, cmpOp.getLoc(), 1);
282+
upperBound = plusConstant(ivEndBound, cmpOp.getLoc(), 1);
277283
}
278-
assert(lowerBound && "can't find loop lower bound");
279-
assert(upperBound && "can't find loop upper bound");
284+
if (!lowerBound || !upperBound)
285+
canonical = false;
280286
}
281287

282288
void SCFLoop::transferToSCFForOp() {
@@ -309,6 +315,28 @@ void SCFLoop::transferToSCFForOp() {
309315
});
310316
}
311317

318+
void SCFLoop::transformToSCFWhileOp() {
319+
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
320+
forOp->getLoc(), forOp->getResultTypes(), mlir::ValueRange());
321+
rewriter->createBlock(&scfWhileOp.getBefore());
322+
rewriter->createBlock(&scfWhileOp.getAfter());
323+
324+
rewriter->inlineBlockBefore(&forOp.getCond().front(),
325+
scfWhileOp.getBeforeBody(),
326+
scfWhileOp.getBeforeBody()->end());
327+
rewriter->inlineBlockBefore(&forOp.getBody().front(),
328+
scfWhileOp.getAfterBody(),
329+
scfWhileOp.getAfterBody()->end());
330+
// There will be a yield after the `for` body.
331+
// We should delete it.
332+
auto yield = mlir::cast<YieldOp>(scfWhileOp.getAfterBody()->back());
333+
rewriter->eraseOp(yield);
334+
335+
rewriter->inlineBlockBefore(&forOp.getStep().front(),
336+
scfWhileOp.getAfterBody(),
337+
scfWhileOp.getAfterBody()->end());
338+
}
339+
312340
void SCFWhileLoop::transferToSCFWhileOp() {
313341
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
314342
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
@@ -352,9 +380,11 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
352380
mlir::ConversionPatternRewriter &rewriter) const override {
353381
SCFLoop loop(op, &rewriter);
354382
loop.analysis();
355-
if (!loop.isCanonical())
356-
return mlir::emitError(op.getLoc(),
357-
"cannot handle non-canonicalized loop");
383+
if (!loop.isCanonical()) {
384+
loop.transformToSCFWhileOp();
385+
rewriter.eraseOp(op);
386+
return mlir::success();
387+
}
358388

359389
loop.transferToSCFForOp();
360390
rewriter.eraseOp(op);
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
// RUN: not %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o - 2>&1 | FileCheck %s
2+
// XFAIL: *
23

3-
void f();
4+
void f() {}
45

56
void reject() {
67
for (int i = 0; i < 100; i++, f());
7-
// CHECK: failed to legalize operation 'cir.for'
8+
// CHECK: failed to legalize operation 'cir.scope'
89
}

clang/test/CIR/Lowering/ThroughMLIR/for-reject-2.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: not %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o - 2>&1 | FileCheck %s
2+
// XFAIL: *
23

34
void reject() {
45
for (int i = 0; i < 100; i++, i++);

0 commit comments

Comments
 (0)