@@ -36,7 +36,8 @@ namespace flangomp {
3636#include " flang/Optimizer/OpenMP/Passes.h.inc"
3737} // namespace flangomp
3838
39- #define DEBUG_TYPE " fopenmp-do-concurrent-conversion"
39+ #define DEBUG_TYPE " do-concurrent-conversion"
40+ #define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE << " ]: " )
4041
4142namespace Fortran {
4243namespace lower {
@@ -45,14 +46,12 @@ namespace internal {
4546// TODO The following 2 functions are copied from "flang/Lower/OpenMP/Utils.h".
4647// This duplication is temporary until we find a solution for a shared location
4748// for these utils that does not introduce circular CMake deps.
48- mlir::omp::MapInfoOp
49- createMapInfoOp (mlir::OpBuilder &builder, mlir::Location loc,
50- mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
51- llvm::ArrayRef<mlir::Value> bounds,
52- llvm::ArrayRef<mlir::Value> members,
53- mlir::ArrayAttr membersIndex, uint64_t mapType,
54- mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
55- bool partialMap = false ) {
49+ mlir::omp::MapInfoOp createMapInfoOp (
50+ mlir::OpBuilder &builder, mlir::Location loc, mlir::Value baseAddr,
51+ mlir::Value varPtrPtr, std::string name, llvm::ArrayRef<mlir::Value> bounds,
52+ llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
53+ uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
54+ mlir::Type retTy, bool partialMap = false ) {
5655 if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType ())) {
5756 baseAddr = builder.create <fir::BoxAddrOp>(loc, baseAddr);
5857 retTy = baseAddr.getType ();
@@ -255,9 +254,21 @@ bool isIndVarUltimateOperand(mlir::Operation *op, fir::DoLoopOp doLoop) {
255254 return false ;
256255}
257256
257+ // / For the \p doLoop parameter, find the operations that declares its induction
258+ // / variable or allocates memory for it.
259+ mlir::Operation *findLoopIndVarMemDecl (fir::DoLoopOp doLoop) {
260+ mlir::Value result = nullptr ;
261+ mlir::visitUsedValuesDefinedAbove (
262+ doLoop.getRegion (), [&](mlir::OpOperand *operand) {
263+ if (isIndVarUltimateOperand (operand->getOwner (), doLoop))
264+ result = operand->get ();
265+ });
266+
267+ assert (result.getDefiningOp () != nullptr );
268+ return result.getDefiningOp ();
269+ }
270+
258271// / Collect the list of values used inside the loop but defined outside of it.
259- // / The first item in the returned list is always the loop's induction
260- // / variable.
261272void collectLoopLiveIns (fir::DoLoopOp doLoop,
262273 llvm::SmallVectorImpl<mlir::Value> &liveIns) {
263274 llvm::SmallDenseSet<mlir::Value> seenValues;
@@ -274,9 +285,6 @@ void collectLoopLiveIns(fir::DoLoopOp doLoop,
274285 return ;
275286
276287 liveIns.push_back (operand->get ());
277-
278- if (isIndVarUltimateOperand (operand->getOwner (), doLoop))
279- std::swap (*liveIns.begin (), *liveIns.rbegin ());
280288 });
281289}
282290
@@ -366,24 +374,78 @@ void collectIndirectConstOpChain(mlir::Operation *link,
366374 opChain.insert (link);
367375}
368376
377+ // / Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
378+ // / there are no operations in \p outerloop's other than:
379+ // /
380+ // / 1. the operations needed to assing/update \p outerLoop's induction variable.
381+ // / 2. \p innerLoop itself.
382+ // /
383+ // / \p return true if \p innerLoop is perfectly nested inside \p outerLoop
384+ // / according to the above definition.
385+ bool isPerfectlyNested (fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
386+ mlir::BackwardSliceOptions backwardSliceOptions;
387+ backwardSliceOptions.inclusive = true ;
388+ // We will collect the backward slices for innerLoop's LB, UB, and step.
389+ // However, we want to limit the scope of these slices to the scope of
390+ // outerLoop's region.
391+ backwardSliceOptions.filter = [&](mlir::Operation *op) {
392+ return !mlir::areValuesDefinedAbove (op->getResults (),
393+ outerLoop.getRegion ());
394+ };
395+
396+ mlir::ForwardSliceOptions forwardSliceOptions;
397+ forwardSliceOptions.inclusive = true ;
398+ // We don't care about the outer-loop's induction variable's uses within the
399+ // inner-loop, so we filter out these uses.
400+ forwardSliceOptions.filter = [&](mlir::Operation *op) {
401+ return mlir::areValuesDefinedAbove (op->getResults (), innerLoop.getRegion ());
402+ };
403+
404+ llvm::SetVector<mlir::Operation *> indVarSlice;
405+ mlir::getForwardSlice (outerLoop.getInductionVar (), &indVarSlice,
406+ forwardSliceOptions);
407+ llvm::DenseSet<mlir::Operation *> innerLoopSetupOpsSet (indVarSlice.begin (),
408+ indVarSlice.end ());
409+
410+ llvm::DenseSet<mlir::Operation *> loopBodySet;
411+ outerLoop.walk <mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
412+ if (op == outerLoop)
413+ return mlir::WalkResult::advance ();
414+
415+ if (op == innerLoop)
416+ return mlir::WalkResult::skip ();
417+
418+ if (op->hasTrait <mlir::OpTrait::IsTerminator>())
419+ return mlir::WalkResult::advance ();
420+
421+ loopBodySet.insert (op);
422+ return mlir::WalkResult::advance ();
423+ });
424+
425+ bool result = (loopBodySet == innerLoopSetupOpsSet);
426+ mlir::Location loc = outerLoop.getLoc ();
427+ LLVM_DEBUG (DBGS () << " Loop pair starting at location " << loc << " is"
428+ << (result ? " " : " not" ) << " perfectly nested\n " );
429+
430+ return result;
431+ }
432+
369433// / Starting with `outerLoop` collect a perfectly nested loop nest, if any. This
370434// / function collects as many loops as possible in the nest; in case it fails to
371435// / recognize a certain nested loop as part of the nest it just returns the
372436// / parent loops it discovered before.
373- mlir::LogicalResult collectLoopNest (fir::DoLoopOp outerLoop ,
437+ mlir::LogicalResult collectLoopNest (fir::DoLoopOp currentLoop ,
374438 LoopNestToIndVarMap &loopNest) {
375- assert (outerLoop.getUnordered ());
376- llvm::SmallVector<mlir::Value> outerLoopLiveIns;
377- collectLoopLiveIns (outerLoop, outerLoopLiveIns);
439+ assert (currentLoop.getUnordered ());
378440
379441 while (true ) {
380442 loopNest.try_emplace (
381- outerLoop ,
443+ currentLoop ,
382444 InductionVariableInfo{
383- outerLoopLiveIns. front (). getDefiningOp ( ),
384- std::move (looputils::extractIndVarUpdateOps (outerLoop ))});
445+ findLoopIndVarMemDecl (currentLoop ),
446+ std::move (looputils::extractIndVarUpdateOps (currentLoop ))});
385447
386- auto directlyNestedLoops = outerLoop .getRegion ().getOps <fir::DoLoopOp>();
448+ auto directlyNestedLoops = currentLoop .getRegion ().getOps <fir::DoLoopOp>();
387449 llvm::SmallVector<fir::DoLoopOp> unorderedLoops;
388450
389451 for (auto nestedLoop : directlyNestedLoops)
@@ -398,69 +460,10 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop,
398460
399461 fir::DoLoopOp nestedUnorderedLoop = unorderedLoops.front ();
400462
401- if ((nestedUnorderedLoop.getLowerBound ().getDefiningOp () == nullptr ) ||
402- (nestedUnorderedLoop.getUpperBound ().getDefiningOp () == nullptr ) ||
403- (nestedUnorderedLoop.getStep ().getDefiningOp () == nullptr ))
404- return mlir::failure ();
405-
406- llvm::SmallVector<mlir::Value> nestedLiveIns;
407- collectLoopLiveIns (nestedUnorderedLoop, nestedLiveIns);
408-
409- llvm::DenseSet<mlir::Value> outerLiveInsSet;
410- llvm::DenseSet<mlir::Value> nestedLiveInsSet;
411-
412- // Returns a "unified" view of an mlir::Value. This utility checks if the
413- // value is defined by an op, and if so, return the first value defined by
414- // that op (if there are many), otherwise just returns the value.
415- //
416- // This serves the purpose that if, for example, `%op_res#0` is used in the
417- // outer loop and `%op_res#1` is used in the nested loop (or vice versa),
418- // that we detect both as the same value. If we did not do so, we might
419- // falesely detect that the 2 loops are not perfectly nested since they use
420- // "different" sets of values.
421- auto getUnifiedLiveInView = [](mlir::Value liveIn) {
422- return liveIn.getDefiningOp () != nullptr
423- ? liveIn.getDefiningOp ()->getResult (0 )
424- : liveIn;
425- };
426-
427- // Re-package both lists of live-ins into sets so that we can use set
428- // equality to compare the values used in the outerloop vs. the nestd one.
429-
430- for (auto liveIn : nestedLiveIns)
431- nestedLiveInsSet.insert (getUnifiedLiveInView (liveIn));
432-
433- mlir::Value outerLoopIV;
434- for (auto liveIn : outerLoopLiveIns) {
435- outerLiveInsSet.insert (getUnifiedLiveInView (liveIn));
436-
437- // Keep track of the IV of the outerloop. See `isPerfectlyNested` for more
438- // info on the reason.
439- if (outerLoopIV == nullptr )
440- outerLoopIV = getUnifiedLiveInView (liveIn);
441- }
442-
443- // For the 2 loops to be perfectly nested, either:
444- // * both would have exactly the same set of live-in values or,
445- // * the outer loop would have exactly 1 extra live-in value: the outer
446- // loop's induction variable; this happens when the outer loop's IV is
447- // *not* referenced in the nested loop.
448- bool isPerfectlyNested = [&]() {
449- if (outerLiveInsSet == nestedLiveInsSet)
450- return true ;
451-
452- if ((outerLiveInsSet.size () == nestedLiveIns.size () + 1 ) &&
453- !nestedLiveInsSet.contains (outerLoopIV))
454- return true ;
455-
456- return false ;
457- }();
458-
459- if (!isPerfectlyNested)
463+ if (!isPerfectlyNested (currentLoop, nestedUnorderedLoop))
460464 return mlir::failure ();
461465
462- outerLoop = nestedUnorderedLoop;
463- outerLoopLiveIns = std::move (nestedLiveIns);
466+ currentLoop = nestedUnorderedLoop;
464467 }
465468
466469 return mlir::success ();
@@ -634,10 +637,6 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
634637 " defining operation." );
635638 }
636639
637- llvm::SmallVector<mlir::Value> outermostLoopLiveIns;
638- looputils::collectLoopLiveIns (doLoop, outermostLoopLiveIns);
639- assert (!outermostLoopLiveIns.empty ());
640-
641640 looputils::LoopNestToIndVarMap loopNest;
642641 bool hasRemainingNestedLoops =
643642 failed (looputils::collectLoopNest (doLoop, loopNest));
@@ -646,15 +645,19 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
646645 " Some `do concurent` loops are not perfectly-nested. "
647646 " These will be serialzied." );
648647
648+ llvm::SmallVector<mlir::Value> loopNestLiveIns;
649+ looputils::collectLoopLiveIns (loopNest.back ().first , loopNestLiveIns);
650+ assert (!loopNestLiveIns.empty ());
651+
649652 llvm::SetVector<mlir::Value> locals;
650653 looputils::collectLoopLocalValues (loopNest.back ().first , locals);
651654 // We do not want to map "loop-local" values to the device through
652655 // `omp.map.info` ops. Therefore, we remove them from the list of live-ins.
653- outermostLoopLiveIns .erase (llvm::remove_if (outermostLoopLiveIns ,
654- [&](mlir::Value liveIn) {
655- return locals.contains (liveIn);
656- }),
657- outermostLoopLiveIns .end ());
656+ loopNestLiveIns .erase (llvm::remove_if (loopNestLiveIns ,
657+ [&](mlir::Value liveIn) {
658+ return locals.contains (liveIn);
659+ }),
660+ loopNestLiveIns .end ());
658661
659662 looputils::sinkLoopIVArgs (rewriter, loopNest);
660663
@@ -669,12 +672,12 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
669672 // The outermost loop will contain all the live-in values in all nested
670673 // loops since live-in values are collected recursively for all nested
671674 // ops.
672- for (mlir::Value liveIn : outermostLoopLiveIns )
675+ for (mlir::Value liveIn : loopNestLiveIns )
673676 targetClauseOps.mapVars .push_back (
674677 genMapInfoOpForLiveIn (rewriter, liveIn));
675678
676- targetOp = genTargetOp (doLoop.getLoc (), rewriter, mapper,
677- outermostLoopLiveIns, targetClauseOps);
679+ targetOp = genTargetOp (doLoop.getLoc (), rewriter, mapper, loopNestLiveIns,
680+ targetClauseOps);
678681 genTeamsOp (doLoop.getLoc (), rewriter);
679682 }
680683
@@ -1062,10 +1065,11 @@ class DoConcurrentConversionPass
10621065 context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
10631066 concurrentLoopsToSkip);
10641067 mlir::ConversionTarget target (*context);
1065- target.addLegalDialect <
1066- fir::FIROpsDialect, hlfir::hlfirDialect, mlir::arith::ArithDialect,
1067- mlir::func::FuncDialect, mlir::omp::OpenMPDialect,
1068- mlir::cf::ControlFlowDialect, mlir::math::MathDialect>();
1068+ target
1069+ .addLegalDialect <fir::FIROpsDialect, hlfir::hlfirDialect,
1070+ mlir::arith::ArithDialect, mlir::func::FuncDialect,
1071+ mlir::omp::OpenMPDialect, mlir::cf::ControlFlowDialect,
1072+ mlir::math::MathDialect, mlir::LLVM::LLVMDialect>();
10691073
10701074 target.addDynamicallyLegalOp <fir::DoLoopOp>([&](fir::DoLoopOp op) {
10711075 return !op.getUnordered () || concurrentLoopsToSkip.contains (op);
0 commit comments