13
13
#include " mlir/Dialect/Arith/IR/Arith.h"
14
14
#include " mlir/Dialect/MemRef/IR/MemRef.h"
15
15
#include " mlir/Dialect/SCF/IR/SCF.h"
16
- #include " mlir/Dialect/SCF/Transforms/Passes.h"
17
16
#include " mlir/IR/Builders.h"
18
- #include " mlir/IR/BuiltinDialect.h"
19
- #include " mlir/IR/BuiltinTypes.h"
20
17
#include " mlir/IR/Location.h"
21
18
#include " mlir/IR/ValueRange.h"
22
- #include " mlir/Pass/Pass.h"
23
19
#include " mlir/Pass/PassManager.h"
24
20
#include " mlir/Support/LogicalResult.h"
25
21
#include " mlir/Transforms/DialectConversion.h"
26
22
#include " clang/CIR/Dialect/IR/CIRDialect.h"
27
23
#include " clang/CIR/Dialect/IR/CIRTypes.h"
28
24
#include " clang/CIR/LowerToMLIR.h"
29
- #include " clang/CIR/Passes.h"
30
25
#include " llvm/ADT/TypeSwitch.h"
31
26
32
27
using namespace cir ;
@@ -52,6 +47,7 @@ class SCFLoop {
52
47
53
48
mlir::Value plusConstant (mlir::Value v, mlir::Location loc, int addend);
54
49
void transferToSCFForOp ();
50
+ void transformToSCFWhileOp ();
55
51
56
52
private:
57
53
cir::ForOp forOp;
@@ -209,21 +205,21 @@ cir::CmpOp SCFLoop::findCmpOp() {
209
205
}
210
206
}
211
207
if (!cmpOp)
212
- llvm_unreachable ( " Can't find loop CmpOp " ) ;
208
+ return nullptr ;
213
209
214
210
auto type = cmpOp.getLhs ().getType ();
215
211
if (!mlir::isa<cir::IntType>(type))
216
- llvm_unreachable ( " Non-integer type IV is not supported " ) ;
212
+ return nullptr ;
217
213
218
214
auto *lhsDefOp = cmpOp.getLhs ().getDefiningOp ();
219
215
if (!lhsDefOp)
220
- llvm_unreachable ( " Can't find IV load " ) ;
216
+ return nullptr ;
221
217
if (!isIVLoad (lhsDefOp, ivAddr))
222
- llvm_unreachable ( " cmpOp LHS is not IV " ) ;
218
+ return nullptr ;
223
219
224
220
if (cmpOp.getKind () != cir::CmpOpKind::le &&
225
221
cmpOp.getKind () != cir::CmpOpKind::lt)
226
- llvm_unreachable ( " Not support lowering other than le or lt comparison " ) ;
222
+ return nullptr ;
227
223
228
224
return cmpOp;
229
225
}
@@ -253,30 +249,40 @@ mlir::Value SCFLoop::findIVInitValue() {
253
249
254
250
void SCFLoop::analysis () {
255
251
canonical = mlir::succeeded (findStepAndIV ());
256
- if (!canonical) {
257
- mlir::emitError (forOp.getLoc (),
258
- " cannot handle non-constant step for induction variable" );
252
+ if (!canonical)
259
253
return ;
260
- }
261
254
262
255
cmpOp = findCmpOp ();
263
- auto IVInit = findIVInitValue ();
256
+ if (!cmpOp) {
257
+ canonical = false ;
258
+ return ;
259
+ }
260
+
261
+ auto ivInit = findIVInitValue ();
262
+ if (!ivInit) {
263
+ canonical = false ;
264
+ return ;
265
+ }
266
+
264
267
// The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
265
268
// So we could get the value by getRemappedValue.
266
- auto IVEndBound = rewriter->getRemappedValue (cmpOp.getRhs ());
267
- // If the loop end bound is not loop invariant and can't be hoisted.
268
- // The following assertion will be triggerred.
269
- assert (IVEndBound && " can't find IV end boundary" );
269
+ auto ivEndBound = rewriter->getRemappedValue (cmpOp.getRhs ());
270
+ // If the loop end bound is not loop invariant and can't be hoisted,
271
+ // then this is not a canonical loop.
272
+ if (!ivEndBound) {
273
+ canonical = false ;
274
+ return ;
275
+ }
270
276
271
277
if (step > 0 ) {
272
- lowerBound = IVInit ;
278
+ lowerBound = ivInit ;
273
279
if (cmpOp.getKind () == cir::CmpOpKind::lt)
274
- upperBound = IVEndBound ;
280
+ upperBound = ivEndBound ;
275
281
else if (cmpOp.getKind () == cir::CmpOpKind::le)
276
- upperBound = plusConstant (IVEndBound , cmpOp.getLoc (), 1 );
282
+ upperBound = plusConstant (ivEndBound , cmpOp.getLoc (), 1 );
277
283
}
278
- assert ( lowerBound && " can't find loop lower bound " );
279
- assert (upperBound && " can't find loop upper bound " ) ;
284
+ if (! lowerBound || !upperBound)
285
+ canonical = false ;
280
286
}
281
287
282
288
void SCFLoop::transferToSCFForOp () {
@@ -309,6 +315,28 @@ void SCFLoop::transferToSCFForOp() {
309
315
});
310
316
}
311
317
318
+ void SCFLoop::transformToSCFWhileOp () {
319
+ auto scfWhileOp = rewriter->create <mlir::scf::WhileOp>(
320
+ forOp->getLoc (), forOp->getResultTypes (), mlir::ValueRange ());
321
+ rewriter->createBlock (&scfWhileOp.getBefore ());
322
+ rewriter->createBlock (&scfWhileOp.getAfter ());
323
+
324
+ rewriter->inlineBlockBefore (&forOp.getCond ().front (),
325
+ scfWhileOp.getBeforeBody (),
326
+ scfWhileOp.getBeforeBody ()->end ());
327
+ rewriter->inlineBlockBefore (&forOp.getBody ().front (),
328
+ scfWhileOp.getAfterBody (),
329
+ scfWhileOp.getAfterBody ()->end ());
330
+ // There will be a yield after the `for` body.
331
+ // We should delete it.
332
+ auto yield = mlir::cast<YieldOp>(scfWhileOp.getAfterBody ()->back ());
333
+ rewriter->eraseOp (yield);
334
+
335
+ rewriter->inlineBlockBefore (&forOp.getStep ().front (),
336
+ scfWhileOp.getAfterBody (),
337
+ scfWhileOp.getAfterBody ()->end ());
338
+ }
339
+
312
340
void SCFWhileLoop::transferToSCFWhileOp () {
313
341
auto scfWhileOp = rewriter->create <mlir::scf::WhileOp>(
314
342
whileOp->getLoc (), whileOp->getResultTypes (), adaptor.getOperands ());
@@ -352,9 +380,11 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
352
380
mlir::ConversionPatternRewriter &rewriter) const override {
353
381
SCFLoop loop (op, &rewriter);
354
382
loop.analysis ();
355
- if (!loop.isCanonical ())
356
- return mlir::emitError (op.getLoc (),
357
- " cannot handle non-canonicalized loop" );
383
+ if (!loop.isCanonical ()) {
384
+ loop.transformToSCFWhileOp ();
385
+ rewriter.eraseOp (op);
386
+ return mlir::success ();
387
+ }
358
388
359
389
loop.transferToSCFForOp ();
360
390
rewriter.eraseOp (op);
0 commit comments