Skip to content

Commit 08aacb4

Browse files
committed
[flang][OpenMP][DoConcurrent] Simplify loop-nest detection logic
With llvm#114020, do-concurrent loop-nests are more conformant to the spec and easier to detect. All we need to do is check that the only operations inside `loop A`, which perfectly wraps `loop B`, are: * the operations needed to update `loop A`'s iteration variable and * `loop B` itself. This PR simplifies the pass a bit using the above logic and replaces #127.
1 parent 8634cb1 commit 08aacb4

File tree

3 files changed

+190
-145
lines changed

3 files changed

+190
-145
lines changed

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 103 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ namespace flangomp {
3636
#include "flang/Optimizer/OpenMP/Passes.h.inc"
3737
} // namespace flangomp
3838

39-
#define DEBUG_TYPE "fopenmp-do-concurrent-conversion"
39+
#define DEBUG_TYPE "do-concurrent-conversion"
40+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
4041

4142
namespace Fortran {
4243
namespace lower {
@@ -45,14 +46,12 @@ namespace internal {
4546
// TODO The following 2 functions are copied from "flang/Lower/OpenMP/Utils.h".
4647
// This duplication is temporary until we find a solution for a shared location
4748
// for these utils that does not introduce circular CMake deps.
48-
mlir::omp::MapInfoOp
49-
createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc,
50-
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
51-
llvm::ArrayRef<mlir::Value> bounds,
52-
llvm::ArrayRef<mlir::Value> members,
53-
mlir::ArrayAttr membersIndex, uint64_t mapType,
54-
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
55-
bool partialMap = false) {
49+
mlir::omp::MapInfoOp createMapInfoOp(
50+
mlir::OpBuilder &builder, mlir::Location loc, mlir::Value baseAddr,
51+
mlir::Value varPtrPtr, std::string name, llvm::ArrayRef<mlir::Value> bounds,
52+
llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
53+
uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType,
54+
mlir::Type retTy, bool partialMap = false) {
5655
if (auto boxTy = llvm::dyn_cast<fir::BaseBoxType>(baseAddr.getType())) {
5756
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
5857
retTy = baseAddr.getType();
@@ -255,9 +254,21 @@ bool isIndVarUltimateOperand(mlir::Operation *op, fir::DoLoopOp doLoop) {
255254
return false;
256255
}
257256

257+
/// For the \p doLoop parameter, find the operation that declares its induction
258+
/// variable or allocates memory for it.
259+
mlir::Operation *findLoopIndVarMemDecl(fir::DoLoopOp doLoop) {
260+
mlir::Value result = nullptr;
261+
mlir::visitUsedValuesDefinedAbove(
262+
doLoop.getRegion(), [&](mlir::OpOperand *operand) {
263+
if (isIndVarUltimateOperand(operand->getOwner(), doLoop))
264+
result = operand->get();
265+
});
266+
267+
assert(result.getDefiningOp() != nullptr);
268+
return result.getDefiningOp();
269+
}
270+
258271
/// Collect the list of values used inside the loop but defined outside of it.
259-
/// The first item in the returned list is always the loop's induction
260-
/// variable.
261272
void collectLoopLiveIns(fir::DoLoopOp doLoop,
262273
llvm::SmallVectorImpl<mlir::Value> &liveIns) {
263274
llvm::SmallDenseSet<mlir::Value> seenValues;
@@ -274,9 +285,6 @@ void collectLoopLiveIns(fir::DoLoopOp doLoop,
274285
return;
275286

276287
liveIns.push_back(operand->get());
277-
278-
if (isIndVarUltimateOperand(operand->getOwner(), doLoop))
279-
std::swap(*liveIns.begin(), *liveIns.rbegin());
280288
});
281289
}
282290

@@ -366,24 +374,78 @@ void collectIndirectConstOpChain(mlir::Operation *link,
366374
opChain.insert(link);
367375
}
368376

377+
/// Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
378+
/// there are no operations in \p outerLoop's body other than:
379+
///
380+
/// 1. the operations needed to assign/update \p outerLoop's induction variable.
381+
/// 2. \p innerLoop itself.
382+
///
383+
/// \return true if \p innerLoop is perfectly nested inside \p outerLoop
384+
/// according to the above definition.
385+
bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
386+
mlir::BackwardSliceOptions backwardSliceOptions;
387+
backwardSliceOptions.inclusive = true;
388+
// We will collect the backward slices for innerLoop's LB, UB, and step.
389+
// However, we want to limit the scope of these slices to the scope of
390+
// outerLoop's region.
391+
backwardSliceOptions.filter = [&](mlir::Operation *op) {
392+
return !mlir::areValuesDefinedAbove(op->getResults(),
393+
outerLoop.getRegion());
394+
};
395+
396+
mlir::ForwardSliceOptions forwardSliceOptions;
397+
forwardSliceOptions.inclusive = true;
398+
// We don't care about the outer-loop's induction variable's uses within the
399+
// inner-loop, so we filter out these uses.
400+
forwardSliceOptions.filter = [&](mlir::Operation *op) {
401+
return mlir::areValuesDefinedAbove(op->getResults(), innerLoop.getRegion());
402+
};
403+
404+
llvm::SetVector<mlir::Operation *> indVarSlice;
405+
mlir::getForwardSlice(outerLoop.getInductionVar(), &indVarSlice,
406+
forwardSliceOptions);
407+
llvm::DenseSet<mlir::Operation *> innerLoopSetupOpsSet(indVarSlice.begin(),
408+
indVarSlice.end());
409+
410+
llvm::DenseSet<mlir::Operation *> loopBodySet;
411+
outerLoop.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
412+
if (op == outerLoop)
413+
return mlir::WalkResult::advance();
414+
415+
if (op == innerLoop)
416+
return mlir::WalkResult::skip();
417+
418+
if (op->hasTrait<mlir::OpTrait::IsTerminator>())
419+
return mlir::WalkResult::advance();
420+
421+
loopBodySet.insert(op);
422+
return mlir::WalkResult::advance();
423+
});
424+
425+
bool result = (loopBodySet == innerLoopSetupOpsSet);
426+
mlir::Location loc = outerLoop.getLoc();
427+
LLVM_DEBUG(DBGS() << "Loop pair starting at location " << loc << " is"
428+
<< (result ? "" : " not") << " perfectly nested\n");
429+
430+
return result;
431+
}
432+
369433
/// Starting with `outerLoop` collect a perfectly nested loop nest, if any. This
370434
/// function collects as many loops as possible in the nest; in case it fails to
371435
/// recognize a certain nested loop as part of the nest it just returns the
372436
/// parent loops it discovered before.
373-
mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop,
437+
mlir::LogicalResult collectLoopNest(fir::DoLoopOp currentLoop,
374438
LoopNestToIndVarMap &loopNest) {
375-
assert(outerLoop.getUnordered());
376-
llvm::SmallVector<mlir::Value> outerLoopLiveIns;
377-
collectLoopLiveIns(outerLoop, outerLoopLiveIns);
439+
assert(currentLoop.getUnordered());
378440

379441
while (true) {
380442
loopNest.try_emplace(
381-
outerLoop,
443+
currentLoop,
382444
InductionVariableInfo{
383-
outerLoopLiveIns.front().getDefiningOp(),
384-
std::move(looputils::extractIndVarUpdateOps(outerLoop))});
445+
findLoopIndVarMemDecl(currentLoop),
446+
std::move(looputils::extractIndVarUpdateOps(currentLoop))});
385447

386-
auto directlyNestedLoops = outerLoop.getRegion().getOps<fir::DoLoopOp>();
448+
auto directlyNestedLoops = currentLoop.getRegion().getOps<fir::DoLoopOp>();
387449
llvm::SmallVector<fir::DoLoopOp> unorderedLoops;
388450

389451
for (auto nestedLoop : directlyNestedLoops)
@@ -398,69 +460,10 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop,
398460

399461
fir::DoLoopOp nestedUnorderedLoop = unorderedLoops.front();
400462

401-
if ((nestedUnorderedLoop.getLowerBound().getDefiningOp() == nullptr) ||
402-
(nestedUnorderedLoop.getUpperBound().getDefiningOp() == nullptr) ||
403-
(nestedUnorderedLoop.getStep().getDefiningOp() == nullptr))
404-
return mlir::failure();
405-
406-
llvm::SmallVector<mlir::Value> nestedLiveIns;
407-
collectLoopLiveIns(nestedUnorderedLoop, nestedLiveIns);
408-
409-
llvm::DenseSet<mlir::Value> outerLiveInsSet;
410-
llvm::DenseSet<mlir::Value> nestedLiveInsSet;
411-
412-
// Returns a "unified" view of an mlir::Value. This utility checks if the
413-
// value is defined by an op, and if so, return the first value defined by
414-
// that op (if there are many), otherwise just returns the value.
415-
//
416-
// This serves the purpose that if, for example, `%op_res#0` is used in the
417-
// outer loop and `%op_res#1` is used in the nested loop (or vice versa),
418-
// that we detect both as the same value. If we did not do so, we might
419-
// falesely detect that the 2 loops are not perfectly nested since they use
420-
// "different" sets of values.
421-
auto getUnifiedLiveInView = [](mlir::Value liveIn) {
422-
return liveIn.getDefiningOp() != nullptr
423-
? liveIn.getDefiningOp()->getResult(0)
424-
: liveIn;
425-
};
426-
427-
// Re-package both lists of live-ins into sets so that we can use set
428-
// equality to compare the values used in the outerloop vs. the nestd one.
429-
430-
for (auto liveIn : nestedLiveIns)
431-
nestedLiveInsSet.insert(getUnifiedLiveInView(liveIn));
432-
433-
mlir::Value outerLoopIV;
434-
for (auto liveIn : outerLoopLiveIns) {
435-
outerLiveInsSet.insert(getUnifiedLiveInView(liveIn));
436-
437-
// Keep track of the IV of the outerloop. See `isPerfectlyNested` for more
438-
// info on the reason.
439-
if (outerLoopIV == nullptr)
440-
outerLoopIV = getUnifiedLiveInView(liveIn);
441-
}
442-
443-
// For the 2 loops to be perfectly nested, either:
444-
// * both would have exactly the same set of live-in values or,
445-
// * the outer loop would have exactly 1 extra live-in value: the outer
446-
// loop's induction variable; this happens when the outer loop's IV is
447-
// *not* referenced in the nested loop.
448-
bool isPerfectlyNested = [&]() {
449-
if (outerLiveInsSet == nestedLiveInsSet)
450-
return true;
451-
452-
if ((outerLiveInsSet.size() == nestedLiveIns.size() + 1) &&
453-
!nestedLiveInsSet.contains(outerLoopIV))
454-
return true;
455-
456-
return false;
457-
}();
458-
459-
if (!isPerfectlyNested)
463+
if (!isPerfectlyNested(currentLoop, nestedUnorderedLoop))
460464
return mlir::failure();
461465

462-
outerLoop = nestedUnorderedLoop;
463-
outerLoopLiveIns = std::move(nestedLiveIns);
466+
currentLoop = nestedUnorderedLoop;
464467
}
465468

466469
return mlir::success();
@@ -634,10 +637,6 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
634637
"defining operation.");
635638
}
636639

637-
llvm::SmallVector<mlir::Value> outermostLoopLiveIns;
638-
looputils::collectLoopLiveIns(doLoop, outermostLoopLiveIns);
639-
assert(!outermostLoopLiveIns.empty());
640-
641640
looputils::LoopNestToIndVarMap loopNest;
642641
bool hasRemainingNestedLoops =
643642
failed(looputils::collectLoopNest(doLoop, loopNest));
@@ -646,15 +645,19 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
646645
"Some `do concurent` loops are not perfectly-nested. "
647646
"These will be serialzied.");
648647

648+
llvm::SmallVector<mlir::Value> loopNestLiveIns;
649+
looputils::collectLoopLiveIns(loopNest.back().first, loopNestLiveIns);
650+
assert(!loopNestLiveIns.empty());
651+
649652
llvm::SetVector<mlir::Value> locals;
650653
looputils::collectLoopLocalValues(loopNest.back().first, locals);
651654
// We do not want to map "loop-local" values to the device through
652655
// `omp.map.info` ops. Therefore, we remove them from the list of live-ins.
653-
outermostLoopLiveIns.erase(llvm::remove_if(outermostLoopLiveIns,
654-
[&](mlir::Value liveIn) {
655-
return locals.contains(liveIn);
656-
}),
657-
outermostLoopLiveIns.end());
656+
loopNestLiveIns.erase(llvm::remove_if(loopNestLiveIns,
657+
[&](mlir::Value liveIn) {
658+
return locals.contains(liveIn);
659+
}),
660+
loopNestLiveIns.end());
658661

659662
looputils::sinkLoopIVArgs(rewriter, loopNest);
660663

@@ -669,12 +672,12 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
669672
// The outermost loop will contain all the live-in values in all nested
670673
// loops since live-in values are collected recursively for all nested
671674
// ops.
672-
for (mlir::Value liveIn : outermostLoopLiveIns)
675+
for (mlir::Value liveIn : loopNestLiveIns)
673676
targetClauseOps.mapVars.push_back(
674677
genMapInfoOpForLiveIn(rewriter, liveIn));
675678

676-
targetOp = genTargetOp(doLoop.getLoc(), rewriter, mapper,
677-
outermostLoopLiveIns, targetClauseOps);
679+
targetOp = genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
680+
targetClauseOps);
678681
genTeamsOp(doLoop.getLoc(), rewriter);
679682
}
680683

@@ -1062,10 +1065,11 @@ class DoConcurrentConversionPass
10621065
context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
10631066
concurrentLoopsToSkip);
10641067
mlir::ConversionTarget target(*context);
1065-
target.addLegalDialect<
1066-
fir::FIROpsDialect, hlfir::hlfirDialect, mlir::arith::ArithDialect,
1067-
mlir::func::FuncDialect, mlir::omp::OpenMPDialect,
1068-
mlir::cf::ControlFlowDialect, mlir::math::MathDialect>();
1068+
target
1069+
.addLegalDialect<fir::FIROpsDialect, hlfir::hlfirDialect,
1070+
mlir::arith::ArithDialect, mlir::func::FuncDialect,
1071+
mlir::omp::OpenMPDialect, mlir::cf::ControlFlowDialect,
1072+
mlir::math::MathDialect, mlir::LLVM::LLVMDialect>();
10691073

10701074
target.addDynamicallyLegalOp<fir::DoLoopOp>([&](fir::DoLoopOp op) {
10711075
return !op.getUnordered() || concurrentLoopsToSkip.contains(op);
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
! Tests loop-nest detection algorithm for do-concurrent mapping.
2+
3+
! REQUIRES: asserts
4+
5+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host \
6+
! RUN: -mmlir -debug %s -o - 2> %t.log || true
7+
8+
! RUN: FileCheck %s < %t.log
9+
10+
program main
11+
implicit none
12+
13+
contains
14+
15+
subroutine foo(n)
16+
implicit none
17+
integer :: n, m
18+
integer :: i, j, k
19+
integer :: x
20+
integer, dimension(n) :: a
21+
integer, dimension(n, n, n) :: b
22+
23+
! CHECK: Loop pair starting at location
24+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
25+
do concurrent(i=1:n, j=1:bar(n*m, n/m))
26+
a(i) = n
27+
end do
28+
29+
! CHECK: Loop pair starting at location
30+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
31+
do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m))
32+
a(i) = n
33+
end do
34+
35+
! CHECK: Loop pair starting at location
36+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
37+
do concurrent(i=bar(n, x):n)
38+
do concurrent(j=1:bar(n*m, n/m))
39+
a(i) = n
40+
end do
41+
end do
42+
43+
! CHECK: Loop pair starting at location
44+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
45+
do concurrent(i=1:n)
46+
x = 10
47+
do concurrent(j=1:m)
48+
b(i,j,k) = i * j + k
49+
end do
50+
end do
51+
52+
! CHECK: Loop pair starting at location
53+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
54+
do concurrent(i=1:n)
55+
do concurrent(j=1:m)
56+
b(i,j,k) = i * j + k
57+
end do
58+
x = 10
59+
end do
60+
61+
! CHECK: Loop pair starting at location
62+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
63+
do concurrent(i=1:n)
64+
do concurrent(j=1:m)
65+
b(i,j,k) = i * j + k
66+
x = 10
67+
end do
68+
end do
69+
70+
! CHECK: Loop pair starting at location
71+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
72+
do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m), k=1:bar(n*m, bar(n*m, n/m)))
73+
a(i) = n
74+
end do
75+
76+
77+
end subroutine
78+
79+
pure function bar(n, m)
80+
implicit none
81+
integer, intent(in) :: n, m
82+
integer :: bar
83+
84+
bar = n + m
85+
end function
86+
87+
end program main

0 commit comments

Comments
 (0)