@@ -36,7 +36,8 @@ namespace flangomp {
3636#include " flang/Optimizer/OpenMP/Passes.h.inc"
3737} // namespace flangomp
3838
39- #define DEBUG_TYPE " fopenmp-do-concurrent-conversion"
39+ #define DEBUG_TYPE " do-concurrent-conversion"
40+ #define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE << " ]: " )
4041
4142namespace Fortran {
4243namespace lower {
@@ -45,14 +46,12 @@ namespace internal {
4546// TODO The following 2 functions are copied from "flang/Lower/OpenMP/Utils.h".
4647// This duplication is temporary until we find a solution for a shared location
4748// for these utils that does not introduce circular CMake deps.
48- mlir::omp::MapInfoOp
49- createMapInfoOp (mlir::OpBuilder &builder, mlir::Location loc,
50- mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
51- llvm::ArrayRef<mlir::Value> bounds,
52- llvm::ArrayRef<mlir::Value> members,
53- mlir::ArrayAttr membersIndex, uint64_t mapType,
54- mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
55- bool partialMap = false ) {
49+ mlir::omp::MapInfoOp createMapInfoOp (
50+ mlir::OpBuilder &builder, mlir::Location loc, mlir::Value baseAddr,
51+ mlir::Value varPtrPtr, std::string name, llvm::ArrayRef<mlir::Value> bounds,
52+ llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
53+ uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
54+ mlir::Type retTy, bool partialMap = false ) {
5655 if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType ())) {
5756 baseAddr = builder.create <fir::BoxAddrOp>(loc, baseAddr);
5857 retTy = baseAddr.getType ();
@@ -255,9 +254,24 @@ bool isIndVarUltimateOperand(mlir::Operation *op, fir::DoLoopOp doLoop) {
255254 return false ;
256255}
257256
257+ // / For the \p doLoop parameter, find the operation that declares its induction
258+ // / variable or allocates memory for it.
259+ mlir::Operation *findLoopIndVarMemDecl (fir::DoLoopOp doLoop) {
260+ mlir::Value result = nullptr ;
261+ mlir::visitUsedValuesDefinedAbove (
262+ doLoop.getRegion (), [&](mlir::OpOperand *operand) {
263+ if (isIndVarUltimateOperand (operand->getOwner (), doLoop)) {
264+ assert (result == nullptr &&
265+ " loop can have only one induction variable" );
266+ result = operand->get ();
267+ }
268+ });
269+
270+ assert (result != nullptr && result.getDefiningOp () != nullptr );
271+ return result.getDefiningOp ();
272+ }
273+
258274// / Collect the list of values used inside the loop but defined outside of it.
259- // / The first item in the returned list is always the loop's induction
260- // / variable.
261275void collectLoopLiveIns (fir::DoLoopOp doLoop,
262276 llvm::SmallVectorImpl<mlir::Value> &liveIns) {
263277 llvm::SmallDenseSet<mlir::Value> seenValues;
@@ -274,9 +288,6 @@ void collectLoopLiveIns(fir::DoLoopOp doLoop,
274288 return ;
275289
276290 liveIns.push_back (operand->get ());
277-
278- if (isIndVarUltimateOperand (operand->getOwner (), doLoop))
279- std::swap (*liveIns.begin (), *liveIns.rbegin ());
280291 });
281292}
282293
@@ -366,24 +377,78 @@ void collectIndirectConstOpChain(mlir::Operation *link,
366377 opChain.insert (link);
367378}
368379
380+ // / Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
381+ // / there are no operations in \p outerLoop's body other than:
382+ // /
383+ // / 1. the operations needed to assign/update \p outerLoop's induction variable.
384+ // / 2. \p innerLoop itself.
385+ // /
386+ // / \return true if \p innerLoop is perfectly nested inside \p outerLoop
387+ // / according to the above definition.
388+ bool isPerfectlyNested (fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
389+ mlir::BackwardSliceOptions backwardSliceOptions;
390+ backwardSliceOptions.inclusive = true ;
391+ // We will collect the backward slices for innerLoop's LB, UB, and step.
392+ // However, we want to limit the scope of these slices to the scope of
393+ // outerLoop's region.
394+ backwardSliceOptions.filter = [&](mlir::Operation *op) {
395+ return !mlir::areValuesDefinedAbove (op->getResults (),
396+ outerLoop.getRegion ());
397+ };
398+
399+ mlir::ForwardSliceOptions forwardSliceOptions;
400+ forwardSliceOptions.inclusive = true ;
401+ // We don't care about the outer-loop's induction variable's uses within the
402+ // inner-loop, so we filter out these uses.
403+ forwardSliceOptions.filter = [&](mlir::Operation *op) {
404+ return mlir::areValuesDefinedAbove (op->getResults (), innerLoop.getRegion ());
405+ };
406+
407+ llvm::SetVector<mlir::Operation *> indVarSlice;
408+ mlir::getForwardSlice (outerLoop.getInductionVar (), &indVarSlice,
409+ forwardSliceOptions);
410+ llvm::DenseSet<mlir::Operation *> innerLoopSetupOpsSet (indVarSlice.begin (),
411+ indVarSlice.end ());
412+
413+ llvm::DenseSet<mlir::Operation *> loopBodySet;
414+ outerLoop.walk <mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
415+ if (op == outerLoop)
416+ return mlir::WalkResult::advance ();
417+
418+ if (op == innerLoop)
419+ return mlir::WalkResult::skip ();
420+
421+ if (mlir::isa<fir::ResultOp>(op))
422+ return mlir::WalkResult::advance ();
423+
424+ loopBodySet.insert (op);
425+ return mlir::WalkResult::advance ();
426+ });
427+
428+ bool result = (loopBodySet == innerLoopSetupOpsSet);
429+ mlir::Location loc = outerLoop.getLoc ();
430+ LLVM_DEBUG (DBGS () << " Loop pair starting at location " << loc << " is"
431+ << (result ? " " : " not" ) << " perfectly nested\n " );
432+
433+ return result;
434+ }
435+
369436// / Starting with `outerLoop` collect a perfectly nested loop nest, if any. This
370437// / function collects as much as possible loops in the nest; it case it fails to
371438// / recognize a certain nested loop as part of the nest it just returns the
372439// / parent loops it discovered before.
373- mlir::LogicalResult collectLoopNest (fir::DoLoopOp outerLoop ,
440+ mlir::LogicalResult collectLoopNest (fir::DoLoopOp currentLoop ,
374441 LoopNestToIndVarMap &loopNest) {
375- assert (outerLoop.getUnordered ());
376- llvm::SmallVector<mlir::Value> outerLoopLiveIns;
377- collectLoopLiveIns (outerLoop, outerLoopLiveIns);
442+ assert (currentLoop.getUnordered ());
378443
379444 while (true ) {
380445 loopNest.try_emplace (
381- outerLoop ,
446+ currentLoop ,
382447 InductionVariableInfo{
383- outerLoopLiveIns. front (). getDefiningOp ( ),
384- std::move (looputils::extractIndVarUpdateOps (outerLoop ))});
448+ findLoopIndVarMemDecl (currentLoop ),
449+ std::move (looputils::extractIndVarUpdateOps (currentLoop ))});
385450
386- auto directlyNestedLoops = outerLoop .getRegion ().getOps <fir::DoLoopOp>();
451+ auto directlyNestedLoops = currentLoop .getRegion ().getOps <fir::DoLoopOp>();
387452 llvm::SmallVector<fir::DoLoopOp> unorderedLoops;
388453
389454 for (auto nestedLoop : directlyNestedLoops)
@@ -398,69 +463,10 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop,
398463
399464 fir::DoLoopOp nestedUnorderedLoop = unorderedLoops.front ();
400465
401- if ((nestedUnorderedLoop.getLowerBound ().getDefiningOp () == nullptr ) ||
402- (nestedUnorderedLoop.getUpperBound ().getDefiningOp () == nullptr ) ||
403- (nestedUnorderedLoop.getStep ().getDefiningOp () == nullptr ))
466+ if (!isPerfectlyNested (currentLoop, nestedUnorderedLoop))
404467 return mlir::failure ();
405468
406- llvm::SmallVector<mlir::Value> nestedLiveIns;
407- collectLoopLiveIns (nestedUnorderedLoop, nestedLiveIns);
408-
409- llvm::DenseSet<mlir::Value> outerLiveInsSet;
410- llvm::DenseSet<mlir::Value> nestedLiveInsSet;
411-
412- // Returns a "unified" view of an mlir::Value. This utility checks if the
413- // value is defined by an op, and if so, return the first value defined by
414- // that op (if there are many), otherwise just returns the value.
415- //
416- // This serves the purpose that if, for example, `%op_res#0` is used in the
417- // outer loop and `%op_res#1` is used in the nested loop (or vice versa),
418- // that we detect both as the same value. If we did not do so, we might
419- // falesely detect that the 2 loops are not perfectly nested since they use
420- // "different" sets of values.
421- auto getUnifiedLiveInView = [](mlir::Value liveIn) {
422- return liveIn.getDefiningOp () != nullptr
423- ? liveIn.getDefiningOp ()->getResult (0 )
424- : liveIn;
425- };
426-
427- // Re-package both lists of live-ins into sets so that we can use set
428- // equality to compare the values used in the outerloop vs. the nestd one.
429-
430- for (auto liveIn : nestedLiveIns)
431- nestedLiveInsSet.insert (getUnifiedLiveInView (liveIn));
432-
433- mlir::Value outerLoopIV;
434- for (auto liveIn : outerLoopLiveIns) {
435- outerLiveInsSet.insert (getUnifiedLiveInView (liveIn));
436-
437- // Keep track of the IV of the outerloop. See `isPerfectlyNested` for more
438- // info on the reason.
439- if (outerLoopIV == nullptr )
440- outerLoopIV = getUnifiedLiveInView (liveIn);
441- }
442-
443- // For the 2 loops to be perfectly nested, either:
444- // * both would have exactly the same set of live-in values or,
445- // * the outer loop would have exactly 1 extra live-in value: the outer
446- // loop's induction variable; this happens when the outer loop's IV is
447- // *not* referenced in the nested loop.
448- bool isPerfectlyNested = [&]() {
449- if (outerLiveInsSet == nestedLiveInsSet)
450- return true ;
451-
452- if ((outerLiveInsSet.size () == nestedLiveIns.size () + 1 ) &&
453- !nestedLiveInsSet.contains (outerLoopIV))
454- return true ;
455-
456- return false ;
457- }();
458-
459- if (!isPerfectlyNested)
460- return mlir::failure ();
461-
462- outerLoop = nestedUnorderedLoop;
463- outerLoopLiveIns = std::move (nestedLiveIns);
469+ currentLoop = nestedUnorderedLoop;
464470 }
465471
466472 return mlir::success ();
@@ -634,10 +640,6 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
634640 " defining operation." );
635641 }
636642
637- llvm::SmallVector<mlir::Value> outermostLoopLiveIns;
638- looputils::collectLoopLiveIns (doLoop, outermostLoopLiveIns);
639- assert (!outermostLoopLiveIns.empty ());
640-
641643 looputils::LoopNestToIndVarMap loopNest;
642644 bool hasRemainingNestedLoops =
643645 failed (looputils::collectLoopNest (doLoop, loopNest));
@@ -646,15 +648,19 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
646648 " Some `do concurent` loops are not perfectly-nested. "
647649 " These will be serialzied." );
648650
651+ llvm::SmallVector<mlir::Value> loopNestLiveIns;
652+ looputils::collectLoopLiveIns (loopNest.back ().first , loopNestLiveIns);
653+ assert (!loopNestLiveIns.empty ());
654+
649655 llvm::SetVector<mlir::Value> locals;
650656 looputils::collectLoopLocalValues (loopNest.back ().first , locals);
651657 // We do not want to map "loop-local" values to the device through
652658 // `omp.map.info` ops. Therefore, we remove them from the list of live-ins.
653- outermostLoopLiveIns .erase (llvm::remove_if (outermostLoopLiveIns ,
654- [&](mlir::Value liveIn) {
655- return locals.contains (liveIn);
656- }),
657- outermostLoopLiveIns .end ());
659+ loopNestLiveIns .erase (llvm::remove_if (loopNestLiveIns ,
660+ [&](mlir::Value liveIn) {
661+ return locals.contains (liveIn);
662+ }),
663+ loopNestLiveIns .end ());
658664
659665 looputils::sinkLoopIVArgs (rewriter, loopNest);
660666
@@ -669,12 +675,12 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
669675 // The outermost loop will contain all the live-in values in all nested
670676 // loops since live-in values are collected recursively for all nested
671677 // ops.
672- for (mlir::Value liveIn : outermostLoopLiveIns )
678+ for (mlir::Value liveIn : loopNestLiveIns )
673679 targetClauseOps.mapVars .push_back (
674680 genMapInfoOpForLiveIn (rewriter, liveIn));
675681
676- targetOp = genTargetOp (doLoop.getLoc (), rewriter, mapper,
677- outermostLoopLiveIns, targetClauseOps);
682+ targetOp = genTargetOp (doLoop.getLoc (), rewriter, mapper, loopNestLiveIns,
683+ targetClauseOps);
678684 genTeamsOp (doLoop.getLoc (), rewriter);
679685 }
680686
@@ -1062,10 +1068,11 @@ class DoConcurrentConversionPass
10621068 context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
10631069 concurrentLoopsToSkip);
10641070 mlir::ConversionTarget target (*context);
1065- target.addLegalDialect <
1066- fir::FIROpsDialect, hlfir::hlfirDialect, mlir::arith::ArithDialect,
1067- mlir::func::FuncDialect, mlir::omp::OpenMPDialect,
1068- mlir::cf::ControlFlowDialect, mlir::math::MathDialect>();
1071+ target
1072+ .addLegalDialect <fir::FIROpsDialect, hlfir::hlfirDialect,
1073+ mlir::arith::ArithDialect, mlir::func::FuncDialect,
1074+ mlir::omp::OpenMPDialect, mlir::cf::ControlFlowDialect,
1075+ mlir::math::MathDialect, mlir::LLVM::LLVMDialect>();
10691076
10701077 target.addDynamicallyLegalOp <fir::DoLoopOp>([&](fir::DoLoopOp op) {
10711078 return !op.getUnordered () || concurrentLoopsToSkip.contains (op);
0 commit comments