Skip to content

Commit cf6eb8f

Browse files
committed
[flang][OpenMP] Implement more robust loop-nest detection logic
The previous loop-nest detection algorithm failed, in some cases, to detect whether a pair of `do concurrent` loops is perfectly nested. This is a re-implementation that uses forward and backward slice extraction algorithms to compare the set of ops required to set up the inner loop bounds vs. the set of ops nested in the outer loop other than the nested loop itself.
1 parent 40a16e7 commit cf6eb8f

File tree

2 files changed

+167
-52
lines changed

2 files changed

+167
-52
lines changed

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 78 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ namespace flangomp {
3636
#include "flang/Optimizer/OpenMP/Passes.h.inc"
3737
} // namespace flangomp
3838

39-
#define DEBUG_TYPE "fopenmp-do-concurrent-conversion"
39+
#define DEBUG_TYPE "do-concurrent-conversion"
40+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
4041

4142
namespace Fortran {
4243
namespace lower {
@@ -366,6 +367,81 @@ void collectIndirectConstOpChain(mlir::Operation *link,
366367
opChain.insert(link);
367368
}
368369

370+
/// Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
371+
/// there are no operations in \p outerLoop's body other than:
372+
///
373+
/// 1. those operations needed to set up \p innerLoop's LB, UB, and step values,
374+
/// 2. the operations needed to assign/update \p outerLoop's induction variable,
375+
/// 3. \p innerLoop itself.
376+
///
377+
/// \return true if \p innerLoop is perfectly nested inside \p outerLoop
378+
/// according to the above definition.
379+
bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
380+
mlir::BackwardSliceOptions backwardSliceOptions;
381+
backwardSliceOptions.inclusive = true;
382+
// We will collect the backward slices for innerLoop's LB, UB, and step.
383+
// However, we want to limit the scope of these slices to the scope of
384+
// outerLoop's region: the filter rejects any op whose results are all
385+
// defined above outerLoop's region.
385+
backwardSliceOptions.filter = [&](mlir::Operation *op) {
386+
return !mlir::areValuesDefinedAbove(op->getResults(),
387+
outerLoop.getRegion());
388+
};
389+
390+
llvm::SetVector<mlir::Operation *> lbSlice;
391+
mlir::getBackwardSlice(innerLoop.getLowerBound(), &lbSlice,
392+
backwardSliceOptions);
393+
394+
llvm::SetVector<mlir::Operation *> ubSlice;
395+
mlir::getBackwardSlice(innerLoop.getUpperBound(), &ubSlice,
396+
backwardSliceOptions);
397+
398+
llvm::SetVector<mlir::Operation *> stepSlice;
399+
mlir::getBackwardSlice(innerLoop.getStep(), &stepSlice, backwardSliceOptions);
400+
401+
mlir::ForwardSliceOptions forwardSliceOptions;
402+
forwardSliceOptions.inclusive = true;
403+
// We don't care about the uses of the outer loop's induction variable within
404+
// the inner loop, so we filter out these uses: only ops whose results are
405+
// defined above innerLoop's region are kept in the forward slice.
405+
forwardSliceOptions.filter = [&](mlir::Operation *op) {
406+
return mlir::areValuesDefinedAbove(op->getResults(), innerLoop.getRegion());
407+
};
408+
409+
llvm::SetVector<mlir::Operation *> indVarSlice;
410+
mlir::getForwardSlice(outerLoop.getInductionVar(), &indVarSlice,
411+
forwardSliceOptions);
412+
413+
// Merge the 4 slices into the single set of ops that are allowed to appear
// in outerLoop's body next to innerLoop.
llvm::SetVector<mlir::Operation *> innerLoopSetupOpsVec;
414+
innerLoopSetupOpsVec.set_union(indVarSlice);
415+
innerLoopSetupOpsVec.set_union(lbSlice);
416+
innerLoopSetupOpsVec.set_union(ubSlice);
417+
innerLoopSetupOpsVec.set_union(stepSlice);
418+
// Re-package the merged slices as a DenseSet so that they can be compared
// with loopBodySet using set equality below.
llvm::DenseSet<mlir::Operation *> innerLoopSetupOpsSet;
419+
420+
for (mlir::Operation *op : innerLoopSetupOpsVec)
421+
innerLoopSetupOpsSet.insert(op);
422+
423+
// Collect the ops visited when walking outerLoop, excluding outerLoop
// itself, innerLoop (the walk is told to skip it), and terminators.
llvm::DenseSet<mlir::Operation *> loopBodySet;
424+
outerLoop.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
425+
if (op == outerLoop)
426+
return mlir::WalkResult::advance();
427+
428+
if (op == innerLoop)
429+
return mlir::WalkResult::skip();
430+
431+
if (op->hasTrait<mlir::OpTrait::IsTerminator>())
432+
return mlir::WalkResult::advance();
433+
434+
loopBodySet.insert(op);
435+
return mlir::WalkResult::advance();
436+
});
437+
438+
// The pair is reported as perfectly nested only when the collected body ops
// are exactly the setup ops computed above.
bool result = (loopBodySet == innerLoopSetupOpsSet);
439+
mlir::Location loc = outerLoop.getLoc();
440+
LLVM_DEBUG(DBGS() << "Loop pair starting at location " << loc << " is"
441+
<< (result ? "" : " not") << " perfectly nested\n");
442+
return result;
443+
}
444+
369445
/// Starting with `outerLoop` collect a perfectly nested loop nest, if any. This
370446
/// function collects as many loops as possible in the nest; in case it fails to
371447
/// recognize a certain nested loop as part of the nest it just returns the
@@ -406,57 +482,7 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop,
406482
llvm::SmallVector<mlir::Value> nestedLiveIns;
407483
collectLoopLiveIns(nestedUnorderedLoop, nestedLiveIns);
408484

409-
llvm::DenseSet<mlir::Value> outerLiveInsSet;
410-
llvm::DenseSet<mlir::Value> nestedLiveInsSet;
411-
412-
// Returns a "unified" view of an mlir::Value. This utility checks if the
413-
// value is defined by an op, and if so, return the first value defined by
414-
// that op (if there are many), otherwise just returns the value.
415-
//
416-
// This serves the purpose that if, for example, `%op_res#0` is used in the
417-
// outer loop and `%op_res#1` is used in the nested loop (or vice versa),
418-
// that we detect both as the same value. If we did not do so, we might
419-
// falesely detect that the 2 loops are not perfectly nested since they use
420-
// "different" sets of values.
421-
auto getUnifiedLiveInView = [](mlir::Value liveIn) {
422-
return liveIn.getDefiningOp() != nullptr
423-
? liveIn.getDefiningOp()->getResult(0)
424-
: liveIn;
425-
};
426-
427-
// Re-package both lists of live-ins into sets so that we can use set
428-
// equality to compare the values used in the outerloop vs. the nestd one.
429-
430-
for (auto liveIn : nestedLiveIns)
431-
nestedLiveInsSet.insert(getUnifiedLiveInView(liveIn));
432-
433-
mlir::Value outerLoopIV;
434-
for (auto liveIn : outerLoopLiveIns) {
435-
outerLiveInsSet.insert(getUnifiedLiveInView(liveIn));
436-
437-
// Keep track of the IV of the outerloop. See `isPerfectlyNested` for more
438-
// info on the reason.
439-
if (outerLoopIV == nullptr)
440-
outerLoopIV = getUnifiedLiveInView(liveIn);
441-
}
442-
443-
// For the 2 loops to be perfectly nested, either:
444-
// * both would have exactly the same set of live-in values or,
445-
// * the outer loop would have exactly 1 extra live-in value: the outer
446-
// loop's induction variable; this happens when the outer loop's IV is
447-
// *not* referenced in the nested loop.
448-
bool isPerfectlyNested = [&]() {
449-
if (outerLiveInsSet == nestedLiveInsSet)
450-
return true;
451-
452-
if ((outerLiveInsSet.size() == nestedLiveIns.size() + 1) &&
453-
!nestedLiveInsSet.contains(outerLoopIV))
454-
return true;
455-
456-
return false;
457-
}();
458-
459-
if (!isPerfectlyNested)
485+
if (!isPerfectlyNested(outerLoop, nestedUnorderedLoop))
460486
return mlir::failure();
461487

462488
outerLoop = nestedUnorderedLoop;
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
! Tests loop-nest detection algorithm for do-concurrent mapping.
2+
3+
! REQUIRES: asserts
4+
5+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host \
6+
! RUN: -mmlir -debug %s -o - 2> %t.log || true
7+
8+
! RUN: FileCheck %s < %t.log
9+
10+
program main
11+
implicit none
12+
13+
contains
14+
15+
subroutine foo(n)
16+
implicit none
17+
integer :: n, m
18+
integer :: i, j, k
19+
integer :: x
20+
integer, dimension(n) :: a
21+
integer, dimension(n, n, n) :: b
22+
23+
! NOTE This for sure is a perfect loop nest. However, the way `do-concurrent`
24+
! loops are now emitted by flang is probably not correct. This is being looked
25+
! into at the moment and once we have flang emitting proper loop headers, we
26+
! will revisit this.
27+
!
28+
! CHECK: Loop pair starting at location
29+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
30+
do concurrent(i=1:n, j=1:bar(n*m, n/m))
31+
a(i) = n
32+
end do
33+
34+
! NOTE same as above.
35+
!
36+
! CHECK: Loop pair starting at location
37+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
38+
do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m))
39+
a(i) = n
40+
end do
41+
42+
! NOTE This is **not** a perfect nest since the inner call to `bar` will allocate
43+
! memory for the temp results of `n*m` and `n/m` **inside** the outer loop.
44+
!
45+
! CHECK: Loop pair starting at location
46+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
47+
do concurrent(i=bar(n, x):n)
48+
do concurrent(j=1:bar(n*m, n/m))
49+
a(i) = n
50+
end do
51+
end do
52+
53+
! NOTE `x = 10` sits in the outer loop's body before the inner loop, so this
! pair is expected to be reported as not perfectly nested.
!
! CHECK: Loop pair starting at location
54+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
55+
do concurrent(i=1:n)
56+
x = 10
57+
do concurrent(j=1:m)
58+
b(i,j,k) = i * j + k
59+
end do
60+
end do
61+
62+
! NOTE same as the previous case, except that `x = 10` follows the inner loop
! inside the outer loop's body.
!
! CHECK: Loop pair starting at location
63+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
64+
do concurrent(i=1:n)
65+
do concurrent(j=1:m)
66+
b(i,j,k) = i * j + k
67+
end do
68+
x = 10
69+
end do
70+
71+
! NOTE here `x = 10` is inside the inner loop, so at the source level the outer
! loop's body contains only the inner loop; this pair is expected to be
! reported as perfectly nested.
!
! CHECK: Loop pair starting at location
72+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
73+
do concurrent(i=1:n)
74+
do concurrent(j=1:m)
75+
b(i,j,k) = i * j + k
76+
x = 10
77+
end do
78+
end do
79+
end subroutine
80+
81+
! Pure helper that returns n + m; it is called in the loop bounds of several
! `do concurrent` loops in `foo`.
pure function bar(n, m)
82+
implicit none
83+
integer, intent(in) :: n, m
84+
integer :: bar
85+
86+
bar = n + m
87+
end function
88+
89+
end program main

0 commit comments

Comments
 (0)