1313#include " mlir/Dialect/Arith/IR/Arith.h"
1414#include " mlir/Dialect/MemRef/IR/MemRef.h"
1515#include " mlir/Dialect/SCF/IR/SCF.h"
16- #include " mlir/Dialect/SCF/Transforms/Passes.h"
1716#include " mlir/IR/Builders.h"
18- #include " mlir/IR/BuiltinDialect.h"
19- #include " mlir/IR/BuiltinTypes.h"
2017#include " mlir/IR/Location.h"
2118#include " mlir/IR/ValueRange.h"
22- #include " mlir/Pass/Pass.h"
2319#include " mlir/Pass/PassManager.h"
2420#include " mlir/Support/LogicalResult.h"
2521#include " mlir/Transforms/DialectConversion.h"
2622#include " clang/CIR/Dialect/IR/CIRDialect.h"
2723#include " clang/CIR/Dialect/IR/CIRTypes.h"
2824#include " clang/CIR/LowerToMLIR.h"
29- #include " clang/CIR/Passes.h"
3025#include " llvm/ADT/TypeSwitch.h"
3126
3227using namespace cir ;
@@ -52,6 +47,7 @@ class SCFLoop {
5247
5348 mlir::Value plusConstant (mlir::Value v, mlir::Location loc, int addend);
5449 void transferToSCFForOp ();
50+ void transformToSCFWhileOp ();
5551
5652private:
5753 cir::ForOp forOp;
@@ -209,21 +205,21 @@ cir::CmpOp SCFLoop::findCmpOp() {
209205 }
210206 }
211207 if (!cmpOp)
212- llvm_unreachable ( " Can't find loop CmpOp " ) ;
208+ return nullptr ;
213209
214210 auto type = cmpOp.getLhs ().getType ();
215211 if (!mlir::isa<cir::IntType>(type))
216- llvm_unreachable ( " Non-integer type IV is not supported " ) ;
212+ return nullptr ;
217213
218214 auto *lhsDefOp = cmpOp.getLhs ().getDefiningOp ();
219215 if (!lhsDefOp)
220- llvm_unreachable ( " Can't find IV load " ) ;
216+ return nullptr ;
221217 if (!isIVLoad (lhsDefOp, ivAddr))
222- llvm_unreachable ( " cmpOp LHS is not IV " ) ;
218+ return nullptr ;
223219
224220 if (cmpOp.getKind () != cir::CmpOpKind::le &&
225221 cmpOp.getKind () != cir::CmpOpKind::lt)
226- llvm_unreachable ( " Not support lowering other than le or lt comparison " ) ;
222+ return nullptr ;
227223
228224 return cmpOp;
229225}
@@ -253,30 +249,40 @@ mlir::Value SCFLoop::findIVInitValue() {
253249
254250void SCFLoop::analysis () {
255251 canonical = mlir::succeeded (findStepAndIV ());
256- if (!canonical) {
257- mlir::emitError (forOp.getLoc (),
258- " cannot handle non-constant step for induction variable" );
252+ if (!canonical)
259253 return ;
260- }
261254
262255 cmpOp = findCmpOp ();
263- auto IVInit = findIVInitValue ();
256+ if (!cmpOp) {
257+ canonical = false ;
258+ return ;
259+ }
260+
261+ auto ivInit = findIVInitValue ();
262+ if (!ivInit) {
263+ canonical = false ;
264+ return ;
265+ }
266+
264267 // The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
265268 // So we could get the value by getRemappedValue.
266- auto IVEndBound = rewriter->getRemappedValue (cmpOp.getRhs ());
267- // If the loop end bound is not loop invariant and can't be hoisted.
268- // The following assertion will be triggerred.
269- assert (IVEndBound && " can't find IV end boundary" );
269+ auto ivEndBound = rewriter->getRemappedValue (cmpOp.getRhs ());
270+ // If the loop end bound is not loop invariant and can't be hoisted,
271+ // then this is not a canonical loop.
272+ if (!ivEndBound) {
273+ canonical = false ;
274+ return ;
275+ }
270276
271277 if (step > 0 ) {
272- lowerBound = IVInit ;
278+ lowerBound = ivInit ;
273279 if (cmpOp.getKind () == cir::CmpOpKind::lt)
274- upperBound = IVEndBound ;
280+ upperBound = ivEndBound ;
275281 else if (cmpOp.getKind () == cir::CmpOpKind::le)
276- upperBound = plusConstant (IVEndBound , cmpOp.getLoc (), 1 );
282+ upperBound = plusConstant (ivEndBound , cmpOp.getLoc (), 1 );
277283 }
278- assert ( lowerBound && " can't find loop lower bound " );
279- assert (upperBound && " can't find loop upper bound " ) ;
284+ if (! lowerBound || !upperBound)
285+ canonical = false ;
280286}
281287
282288void SCFLoop::transferToSCFForOp () {
@@ -309,6 +315,28 @@ void SCFLoop::transferToSCFForOp() {
309315 });
310316}
311317
318+ void SCFLoop::transformToSCFWhileOp () {
319+ auto scfWhileOp = rewriter->create <mlir::scf::WhileOp>(
320+ forOp->getLoc (), forOp->getResultTypes (), mlir::ValueRange ());
321+ rewriter->createBlock (&scfWhileOp.getBefore ());
322+ rewriter->createBlock (&scfWhileOp.getAfter ());
323+
324+ rewriter->inlineBlockBefore (&forOp.getCond ().front (),
325+ scfWhileOp.getBeforeBody (),
326+ scfWhileOp.getBeforeBody ()->end ());
327+ rewriter->inlineBlockBefore (&forOp.getBody ().front (),
328+ scfWhileOp.getAfterBody (),
329+ scfWhileOp.getAfterBody ()->end ());
330+ // There will be a yield after the `for` body.
331+ // We should delete it.
332+ auto yield = mlir::cast<YieldOp>(scfWhileOp.getAfterBody ()->back ());
333+ rewriter->eraseOp (yield);
334+
335+ rewriter->inlineBlockBefore (&forOp.getStep ().front (),
336+ scfWhileOp.getAfterBody (),
337+ scfWhileOp.getAfterBody ()->end ());
338+ }
339+
312340void SCFWhileLoop::transferToSCFWhileOp () {
313341 auto scfWhileOp = rewriter->create <mlir::scf::WhileOp>(
314342 whileOp->getLoc (), whileOp->getResultTypes (), adaptor.getOperands ());
@@ -352,9 +380,11 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
352380 mlir::ConversionPatternRewriter &rewriter) const override {
353381 SCFLoop loop (op, &rewriter);
354382 loop.analysis ();
355- if (!loop.isCanonical ())
356- return mlir::emitError (op.getLoc (),
357- " cannot handle non-canonicalized loop" );
383+ if (!loop.isCanonical ()) {
384+ loop.transformToSCFWhileOp ();
385+ rewriter.eraseOp (op);
386+ return mlir::success ();
387+ }
358388
359389 loop.transferToSCFForOp ();
360390 rewriter.eraseOp (op);
0 commit comments