@@ -23,51 +23,40 @@ namespace mlir {
2323
2424using namespace mlir ;
2525
26- LogicalResult mlir:: scf::parallelForToNestedFors (RewriterBase &rewriter,
27- scf::ParallelOp parallelOp ,
28- scf::ForOp *result ) {
26+ FailureOr< scf::LoopNest>
27+ mlir:: scf::parallelForToNestedFors (RewriterBase &rewriter ,
28+ scf::ParallelOp parallelOp ) {
2929
3030 if (!parallelOp.getResults ().empty ()) {
31- parallelOp->emitError (" Currently ScfParallel to ScfFor conversion "
32- " doesn't support ScfParallel with results." );
31+ parallelOp->emitError (" Currently scf.parallel to scf.for conversion "
32+ " doesn't support scf.parallel with results." );
3333 return failure ();
3434 }
3535
3636 rewriter.setInsertionPoint (parallelOp);
3737
3838 Location loc = parallelOp.getLoc ();
39- auto lowerBounds = parallelOp.getLowerBound ();
40- auto upperBounds = parallelOp.getUpperBound ();
41- auto steps = parallelOp.getStep ();
39+ SmallVector<Value> lowerBounds = parallelOp.getLowerBound ();
40+ SmallVector<Value> upperBounds = parallelOp.getUpperBound ();
41+ SmallVector<Value> steps = parallelOp.getStep ();
4242
4343 assert (lowerBounds.size () == upperBounds.size () &&
4444 lowerBounds.size () == steps.size () &&
4545 " Mismatched parallel loop bounds" );
4646
4747 SmallVector<Value> ivs;
48- auto loopNest =
48+ scf::LoopNest loopNest =
4949 scf::buildLoopNest (rewriter, loc, lowerBounds, upperBounds, steps);
5050
51- auto oldInductionVars = parallelOp.getInductionVars ();
52- auto newInductionVars = llvm::map_to_vector (
51+ SmallVector<Value> newInductionVars = llvm::map_to_vector (
5352 loopNest.loops , [](scf::ForOp forOp) { return forOp.getInductionVar (); });
54- assert (oldInductionVars.size () == newInductionVars.size () &&
55- " Mismatched induction variables" );
56- for (auto [oldIV, newIV] : llvm::zip (oldInductionVars, newInductionVars))
57- oldIV.replaceAllUsesWith (newIV);
58-
59- auto *linearizedBody = loopNest.loops .back ().getBody ();
60- Block ¶llelBody = *parallelOp.getBody ();
61- for (Operation &op : llvm::make_early_inc_range (parallelBody)) {
62- // Skip the terminator of the parallelOp body.
63- if (&op == parallelBody.getTerminator ())
64- continue ;
65- op.moveBefore (linearizedBody->getTerminator ());
66- }
53+ Block *linearizedBody = loopNest.loops .back ().getBody ();
54+ Block *parallelBody = parallelOp.getBody ();
55+ rewriter.eraseOp (parallelBody->getTerminator ());
56+ rewriter.inlineBlockBefore (parallelBody, linearizedBody->getTerminator (),
57+ newInductionVars);
6758 rewriter.eraseOp (parallelOp);
68- if (result)
69- *result = loopNest.loops .front ();
70- return success ();
59+ return loopNest;
7160}
7261
7362namespace {
0 commit comments