@@ -54,10 +54,6 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
5454 if (op->getParentOfType <mlir::omp::SimdOp>() &&
5555 (mlir::isa<mlir::omp::YieldOp>(op) ||
5656 mlir::isa<mlir::omp::LoopNestOp>(op) ||
57- mlir::isa<mlir::omp::WsloopOp>(op) ||
58- mlir::isa<mlir::omp::WorkshareLoopWrapperOp>(op) ||
59- mlir::isa<mlir::omp::DistributeOp>(op) ||
60- mlir::isa<mlir::omp::TaskloopOp>(op) ||
6157 mlir::isa<mlir::omp::TerminatorOp>(op)))
6258 return rewriter.notifyMatchFailure (op, " Op is part of a simd construct" );
6359
@@ -67,6 +63,10 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
6763 return rewriter.notifyMatchFailure (op,
6864 " Non top-level yield or terminator" );
6965
66+ if (mlir::isa<mlir::omp::UnrollHeuristicOp>(op))
67+ return rewriter.notifyMatchFailure (
68+ op, " UnrollHeuristic has special handling" );
69+
7070 // SectionOp overrides its BlockArgInterface based on the parent SectionsOp.
7171 // We need to make sure we only rewrite omp.sections once all omp.section
7272 // ops inside it have been rewritten, otherwise the individual omp.section
@@ -291,6 +291,129 @@ class SimdOnlyConversionPattern : public mlir::RewritePattern {
291291 return mlir::success ();
292292 }
293293
294+ if (auto cLoopOp = mlir::dyn_cast<mlir::omp::CanonicalLoopOp>(op)) {
295+ assert (" CanonicalLoopOp has one region" && cLoopOp->getNumRegions () == 1 );
296+ auto cli = cLoopOp.getCli ();
297+ auto tripCount = cLoopOp.getTripCount ();
298+
299+ builder.setInsertionPoint (cLoopOp);
300+ mlir::Type indexType = builder.getIndexType ();
301+ mlir::Type oldIndexType = tripCount.getType ();
302+ auto one = mlir::arith::ConstantIndexOp::create (builder, loc, 1 );
303+ auto ub = builder.createConvert (loc, indexType, tripCount);
304+
305+ llvm::SmallVector<mlir::Value> loopIndArgs;
306+ auto doLoop = fir::DoLoopOp::create (builder, loc, one, ub, one, false );
307+ builder.setInsertionPointToStart (doLoop.getBody ());
308+ if (oldIndexType != indexType) {
309+ auto convertedIndVar =
310+ builder.createConvert (loc, oldIndexType, doLoop.getInductionVar ());
311+ loopIndArgs.push_back (convertedIndVar);
312+ } else {
313+ loopIndArgs.push_back (doLoop.getInductionVar ());
314+ }
315+
316+ if (cLoopOp.getRegion ().getBlocks ().size () == 1 ) {
317+ auto &block = *cLoopOp.getRegion ().getBlocks ().begin ();
318+ // DoLoopOp will handle incrementing the induction variable
319+ if (auto addIOp = mlir::dyn_cast<mlir::arith::AddIOp>(block.front ())) {
320+ rewriter.replaceOpUsesWithinBlock (addIOp, addIOp.getLhs (), &block);
321+ rewriter.eraseOp (addIOp);
322+ }
323+
324+ rewriter.mergeBlocks (&block, doLoop.getBody (), loopIndArgs);
325+
326+ // Find the fir.result terminator of the new loop body and move it just
327+ // before the last operation in the block
328+ for (auto &loopBodyOp : doLoop.getBody ()->getOperations ()) {
329+ if (auto resultOp = mlir::dyn_cast<fir::ResultOp>(loopBodyOp)) {
330+ rewriter.moveOpBefore (resultOp.getOperation (),
331+ &doLoop.getBody ()->back ());
332+ break ;
333+ }
334+ }
335+
336+ // Remove omp.terminator at the end of the loop body
337+ if (auto terminatorOp = mlir::dyn_cast<mlir::omp::TerminatorOp>(
338+ doLoop.getBody ()->back ())) {
339+ rewriter.eraseOp (terminatorOp);
340+ }
341+ } else {
342+ rewriter.inlineRegionBefore (cLoopOp->getRegion (0 ), doLoop.getBody ());
343+ auto indVarArg = doLoop.getBody ()->getArgument (0 );
344+ // fir::convertDoLoopToCFG expects the induction variable to be of type
345+ // index while the OpenMP CanonicalLoopOp can have indices of different
346+ // types. We need to work around it.
347+ if (indVarArg.getType () != indexType)
348+ indVarArg.setType (indexType);
349+
350+ // fir.do_loop, unlike omp.canonical_loop, does not support multi-block
351+ // regions. If the omp.canonical_loop body has multiple blocks, we need
352+ // to convert it into basic control-flow operations instead.
353+ auto loopBlocks =
354+ fir::convertDoLoopToCFG (doLoop, rewriter, false , false );
355+ auto *conditionalBlock = loopBlocks.first ;
356+ auto *firstBlock =
357+ conditionalBlock->getNextNode (); // Start of the loop body
358+ auto *lastBlock = loopBlocks.second ; // Incrementing induction variables
359+
360+ // Incrementing the induction variable is handled elsewhere
361+ if (auto addIOp =
362+ mlir::dyn_cast<mlir::arith::AddIOp>(firstBlock->front ())) {
363+ rewriter.replaceOpUsesWithinBlock (addIOp, addIOp.getLhs (),
364+ firstBlock);
365+ rewriter.eraseOp (addIOp);
366+ }
367+
368+ // If the induction variable is used within the loop and was originally
369+ // not of type index, then we need to add a convert to the original type
370+ // and replace its uses inside the loop body.
371+ if (oldIndexType != indexType) {
372+ indVarArg = conditionalBlock->getArgument (0 );
373+ builder.setInsertionPointToStart (firstBlock);
374+ auto convertedIndVar =
375+ builder.createConvert (loc, oldIndexType, indVarArg);
376+ rewriter.replaceUsesWithIf (
377+ indVarArg, convertedIndVar, [&](auto &use) -> bool {
378+ return use.getOwner () != convertedIndVar.getDefiningOp () &&
379+ use.getOwner ()->getBlock () != lastBlock;
380+ });
381+ }
382+
383+ // There might be an unused convert and an unused argument to the block.
384+ // If so, remove them.
385+ if (lastBlock->front ().getUses ().empty ())
386+ lastBlock->front ().erase ();
387+ for (auto arg : lastBlock->getArguments ()) {
388+ if (arg.getUses ().empty ())
389+ lastBlock->eraseArgument (arg.getArgNumber ());
390+ }
391+
392+ // Any loop blocks that end in omp.terminator should just branch to
393+ // lastBlock.
394+ for (auto *loopBlock = conditionalBlock; loopBlock != lastBlock;
395+ loopBlock = loopBlock->getNextNode ()) {
396+ if (auto terminatorOp =
397+ mlir::dyn_cast<mlir::omp::TerminatorOp>(loopBlock->back ())) {
398+ builder.setInsertionPointToEnd (loopBlock);
399+ mlir::cf::BranchOp::create (builder, loc, lastBlock);
400+ rewriter.eraseOp (terminatorOp);
401+ }
402+ }
403+ }
404+
405+ rewriter.eraseOp (cLoopOp);
406+ // Handle the optional omp.new_cli op
407+ if (cli) {
408+ // cli will be used by omp.unroll_heuristic ops
409+ for (auto *user : cli.getUsers ())
410+ rewriter.eraseOp (user);
411+ rewriter.eraseOp (cli.getDefiningOp ());
412+ }
413+
414+ return mlir::success ();
415+ }
416+
294417 auto inlineSimpleOp = [&](mlir::Operation *ompOp) -> bool {
295418 if (!ompOp)
296419 return false ;
0 commit comments