Skip to content

Commit 67697c6

Browse files
authored
Merge pull request #199 from ergawy/simplify_loop_nest_detection
[flang][OpenMP][DoConcurrent] Simplify loop-nest detection logic
2 parents 109ac78 + 8e6a70e commit 67697c6

File tree

3 files changed

+223
-171
lines changed

3 files changed

+223
-171
lines changed

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 103 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ namespace flangomp {
3636
#include "flang/Optimizer/OpenMP/Passes.h.inc"
3737
} // namespace flangomp
3838

39-
#define DEBUG_TYPE "fopenmp-do-concurrent-conversion"
39+
#define DEBUG_TYPE "do-concurrent-conversion"
40+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
4041

4142
namespace Fortran {
4243
namespace lower {
@@ -175,9 +176,24 @@ bool isIndVarUltimateOperand(mlir::Operation *op, fir::DoLoopOp doLoop) {
175176
return false;
176177
}
177178

179+
/// For the \p doLoop parameter, find the operations that declares its induction
180+
/// variable or allocates memory for it.
181+
mlir::Operation *findLoopIndVarMemDecl(fir::DoLoopOp doLoop) {
182+
mlir::Value result = nullptr;
183+
mlir::visitUsedValuesDefinedAbove(
184+
doLoop.getRegion(), [&](mlir::OpOperand *operand) {
185+
if (isIndVarUltimateOperand(operand->getOwner(), doLoop)) {
186+
assert(result == nullptr &&
187+
"loop can have only one induction variable");
188+
result = operand->get();
189+
}
190+
});
191+
192+
assert(result != nullptr && result.getDefiningOp() != nullptr);
193+
return result.getDefiningOp();
194+
}
195+
178196
/// Collect the list of values used inside the loop but defined outside of it.
179-
/// The first item in the returned list is always the loop's induction
180-
/// variable.
181197
void collectLoopLiveIns(fir::DoLoopOp doLoop,
182198
llvm::SmallVectorImpl<mlir::Value> &liveIns) {
183199
llvm::SmallDenseSet<mlir::Value> seenValues;
@@ -194,9 +210,6 @@ void collectLoopLiveIns(fir::DoLoopOp doLoop,
194210
return;
195211

196212
liveIns.push_back(operand->get());
197-
198-
if (isIndVarUltimateOperand(operand->getOwner(), doLoop))
199-
std::swap(*liveIns.begin(), *liveIns.rbegin());
200213
});
201214
}
202215

@@ -286,24 +299,78 @@ void collectIndirectConstOpChain(mlir::Operation *link,
286299
opChain.insert(link);
287300
}
288301

302+
/// Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
303+
/// there are no operations in \p outerloop's other than:
304+
///
305+
/// 1. the operations needed to assing/update \p outerLoop's induction variable.
306+
/// 2. \p innerLoop itself.
307+
///
308+
/// \p return true if \p innerLoop is perfectly nested inside \p outerLoop
309+
/// according to the above definition.
310+
bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
311+
mlir::BackwardSliceOptions backwardSliceOptions;
312+
backwardSliceOptions.inclusive = true;
313+
// We will collect the backward slices for innerLoop's LB, UB, and step.
314+
// However, we want to limit the scope of these slices to the scope of
315+
// outerLoop's region.
316+
backwardSliceOptions.filter = [&](mlir::Operation *op) {
317+
return !mlir::areValuesDefinedAbove(op->getResults(),
318+
outerLoop.getRegion());
319+
};
320+
321+
mlir::ForwardSliceOptions forwardSliceOptions;
322+
forwardSliceOptions.inclusive = true;
323+
// We don't care about the outer-loop's induction variable's uses within the
324+
// inner-loop, so we filter out these uses.
325+
forwardSliceOptions.filter = [&](mlir::Operation *op) {
326+
return mlir::areValuesDefinedAbove(op->getResults(), innerLoop.getRegion());
327+
};
328+
329+
llvm::SetVector<mlir::Operation *> indVarSlice;
330+
mlir::getForwardSlice(outerLoop.getInductionVar(), &indVarSlice,
331+
forwardSliceOptions);
332+
llvm::DenseSet<mlir::Operation *> innerLoopSetupOpsSet(indVarSlice.begin(),
333+
indVarSlice.end());
334+
335+
llvm::DenseSet<mlir::Operation *> loopBodySet;
336+
outerLoop.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
337+
if (op == outerLoop)
338+
return mlir::WalkResult::advance();
339+
340+
if (op == innerLoop)
341+
return mlir::WalkResult::skip();
342+
343+
if (mlir::isa<fir::ResultOp>(op))
344+
return mlir::WalkResult::advance();
345+
346+
loopBodySet.insert(op);
347+
return mlir::WalkResult::advance();
348+
});
349+
350+
bool result = (loopBodySet == innerLoopSetupOpsSet);
351+
mlir::Location loc = outerLoop.getLoc();
352+
LLVM_DEBUG(DBGS() << "Loop pair starting at location " << loc << " is"
353+
<< (result ? "" : " not") << " perfectly nested\n");
354+
355+
return result;
356+
}
357+
289358
/// Starting with `outerLoop` collect a perfectly nested loop nest, if any. This
290359
/// function collects as many loops as possible in the nest; in case it fails to
291360
/// recognize a certain nested loop as part of the nest it just returns the
292361
/// parent loops it discovered before.
293-
mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop,
362+
mlir::LogicalResult collectLoopNest(fir::DoLoopOp currentLoop,
294363
LoopNestToIndVarMap &loopNest) {
295-
assert(outerLoop.getUnordered());
296-
llvm::SmallVector<mlir::Value> outerLoopLiveIns;
297-
collectLoopLiveIns(outerLoop, outerLoopLiveIns);
364+
assert(currentLoop.getUnordered());
298365

299366
while (true) {
300367
loopNest.try_emplace(
301-
outerLoop,
368+
currentLoop,
302369
InductionVariableInfo{
303-
outerLoopLiveIns.front().getDefiningOp(),
304-
std::move(looputils::extractIndVarUpdateOps(outerLoop))});
370+
findLoopIndVarMemDecl(currentLoop),
371+
std::move(looputils::extractIndVarUpdateOps(currentLoop))});
305372

306-
auto directlyNestedLoops = outerLoop.getRegion().getOps<fir::DoLoopOp>();
373+
auto directlyNestedLoops = currentLoop.getRegion().getOps<fir::DoLoopOp>();
307374
llvm::SmallVector<fir::DoLoopOp> unorderedLoops;
308375

309376
for (auto nestedLoop : directlyNestedLoops)
@@ -318,69 +385,10 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop,
318385

319386
fir::DoLoopOp nestedUnorderedLoop = unorderedLoops.front();
320387

321-
if ((nestedUnorderedLoop.getLowerBound().getDefiningOp() == nullptr) ||
322-
(nestedUnorderedLoop.getUpperBound().getDefiningOp() == nullptr) ||
323-
(nestedUnorderedLoop.getStep().getDefiningOp() == nullptr))
324-
return mlir::failure();
325-
326-
llvm::SmallVector<mlir::Value> nestedLiveIns;
327-
collectLoopLiveIns(nestedUnorderedLoop, nestedLiveIns);
328-
329-
llvm::DenseSet<mlir::Value> outerLiveInsSet;
330-
llvm::DenseSet<mlir::Value> nestedLiveInsSet;
331-
332-
// Returns a "unified" view of an mlir::Value. This utility checks if the
333-
// value is defined by an op, and if so, return the first value defined by
334-
// that op (if there are many), otherwise just returns the value.
335-
//
336-
// This serves the purpose that if, for example, `%op_res#0` is used in the
337-
// outer loop and `%op_res#1` is used in the nested loop (or vice versa),
338-
// that we detect both as the same value. If we did not do so, we might
339-
// falesely detect that the 2 loops are not perfectly nested since they use
340-
// "different" sets of values.
341-
auto getUnifiedLiveInView = [](mlir::Value liveIn) {
342-
return liveIn.getDefiningOp() != nullptr
343-
? liveIn.getDefiningOp()->getResult(0)
344-
: liveIn;
345-
};
346-
347-
// Re-package both lists of live-ins into sets so that we can use set
348-
// equality to compare the values used in the outerloop vs. the nestd one.
349-
350-
for (auto liveIn : nestedLiveIns)
351-
nestedLiveInsSet.insert(getUnifiedLiveInView(liveIn));
352-
353-
mlir::Value outerLoopIV;
354-
for (auto liveIn : outerLoopLiveIns) {
355-
outerLiveInsSet.insert(getUnifiedLiveInView(liveIn));
356-
357-
// Keep track of the IV of the outerloop. See `isPerfectlyNested` for more
358-
// info on the reason.
359-
if (outerLoopIV == nullptr)
360-
outerLoopIV = getUnifiedLiveInView(liveIn);
361-
}
362-
363-
// For the 2 loops to be perfectly nested, either:
364-
// * both would have exactly the same set of live-in values or,
365-
// * the outer loop would have exactly 1 extra live-in value: the outer
366-
// loop's induction variable; this happens when the outer loop's IV is
367-
// *not* referenced in the nested loop.
368-
bool isPerfectlyNested = [&]() {
369-
if (outerLiveInsSet == nestedLiveInsSet)
370-
return true;
371-
372-
if ((outerLiveInsSet.size() == nestedLiveIns.size() + 1) &&
373-
!nestedLiveInsSet.contains(outerLoopIV))
374-
return true;
375-
376-
return false;
377-
}();
378-
379-
if (!isPerfectlyNested)
388+
if (!isPerfectlyNested(currentLoop, nestedUnorderedLoop))
380389
return mlir::failure();
381390

382-
outerLoop = nestedUnorderedLoop;
383-
outerLoopLiveIns = std::move(nestedLiveIns);
391+
currentLoop = nestedUnorderedLoop;
384392
}
385393

386394
return mlir::success();
@@ -554,10 +562,6 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
554562
"defining operation.");
555563
}
556564

557-
llvm::SmallVector<mlir::Value> outermostLoopLiveIns;
558-
looputils::collectLoopLiveIns(doLoop, outermostLoopLiveIns);
559-
assert(!outermostLoopLiveIns.empty());
560-
561565
looputils::LoopNestToIndVarMap loopNest;
562566
bool hasRemainingNestedLoops =
563567
failed(looputils::collectLoopNest(doLoop, loopNest));
@@ -566,15 +570,19 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
566570
"Some `do concurent` loops are not perfectly-nested. "
567571
"These will be serialzied.");
568572

573+
llvm::SmallVector<mlir::Value> loopNestLiveIns;
574+
looputils::collectLoopLiveIns(loopNest.back().first, loopNestLiveIns);
575+
assert(!loopNestLiveIns.empty());
576+
569577
llvm::SetVector<mlir::Value> locals;
570578
looputils::collectLoopLocalValues(loopNest.back().first, locals);
571579
// We do not want to map "loop-local" values to the device through
572580
// `omp.map.info` ops. Therefore, we remove them from the list of live-ins.
573-
outermostLoopLiveIns.erase(llvm::remove_if(outermostLoopLiveIns,
574-
[&](mlir::Value liveIn) {
575-
return locals.contains(liveIn);
576-
}),
577-
outermostLoopLiveIns.end());
581+
loopNestLiveIns.erase(llvm::remove_if(loopNestLiveIns,
582+
[&](mlir::Value liveIn) {
583+
return locals.contains(liveIn);
584+
}),
585+
loopNestLiveIns.end());
578586

579587
looputils::sinkLoopIVArgs(rewriter, loopNest);
580588

@@ -590,24 +598,25 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
590598
loopNestClauseOps, &targetClauseOps);
591599

592600
// Prevent mapping host-evaluated variables.
593-
outermostLoopLiveIns.erase(
594-
llvm::remove_if(outermostLoopLiveIns,
601+
loopNestLiveIns.erase(
602+
llvm::remove_if(loopNestLiveIns,
595603
[&](mlir::Value liveIn) {
596604
return llvm::is_contained(
597605
targetClauseOps.hostEvalVars, liveIn);
598606
}),
599-
outermostLoopLiveIns.end());
607+
loopNestLiveIns.end());
600608

601609
// The outermost loop will contain all the live-in values in all nested
602610
// loops since live-in values are collected recursively for all nested
603611
// ops.
604-
for (mlir::Value liveIn : outermostLoopLiveIns)
612+
for (mlir::Value liveIn : loopNestLiveIns)
605613
targetClauseOps.mapVars.push_back(
606614
genMapInfoOpForLiveIn(rewriter, liveIn));
607615

608616
targetOp =
609-
genTargetOp(doLoop.getLoc(), rewriter, mapper, outermostLoopLiveIns,
617+
genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
610618
targetClauseOps, loopNestClauseOps);
619+
611620
genTeamsOp(doLoop.getLoc(), rewriter);
612621
}
613622

@@ -998,10 +1007,11 @@ class DoConcurrentConversionPass
9981007
context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
9991008
concurrentLoopsToSkip);
10001009
mlir::ConversionTarget target(*context);
1001-
target.addLegalDialect<
1002-
fir::FIROpsDialect, hlfir::hlfirDialect, mlir::arith::ArithDialect,
1003-
mlir::func::FuncDialect, mlir::omp::OpenMPDialect,
1004-
mlir::cf::ControlFlowDialect, mlir::math::MathDialect>();
1010+
target
1011+
.addLegalDialect<fir::FIROpsDialect, hlfir::hlfirDialect,
1012+
mlir::arith::ArithDialect, mlir::func::FuncDialect,
1013+
mlir::omp::OpenMPDialect, mlir::cf::ControlFlowDialect,
1014+
mlir::math::MathDialect, mlir::LLVM::LLVMDialect>();
10051015

10061016
target.addDynamicallyLegalOp<fir::DoLoopOp>([&](fir::DoLoopOp op) {
10071017
return !op.getUnordered() || concurrentLoopsToSkip.contains(op);
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
! Tests loop-nest detection algorithm for do-concurrent mapping.
2+
3+
! REQUIRES: asserts
4+
5+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host \
6+
! RUN: -mmlir -debug %s -o - 2> %t.log || true
7+
8+
! RUN: FileCheck %s < %t.log
9+
10+
program main
11+
implicit none
12+
13+
contains
14+
15+
subroutine foo(n)
16+
implicit none
17+
integer :: n, m
18+
integer :: i, j, k
19+
integer :: x
20+
integer, dimension(n) :: a
21+
integer, dimension(n, n, n) :: b
22+
23+
! CHECK: Loop pair starting at location
24+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
25+
do concurrent(i=1:n, j=1:bar(n*m, n/m))
26+
a(i) = n
27+
end do
28+
29+
! CHECK: Loop pair starting at location
30+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
31+
do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m))
32+
a(i) = n
33+
end do
34+
35+
! CHECK: Loop pair starting at location
36+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
37+
do concurrent(i=bar(n, x):n)
38+
do concurrent(j=1:bar(n*m, n/m))
39+
a(i) = n
40+
end do
41+
end do
42+
43+
! CHECK: Loop pair starting at location
44+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
45+
do concurrent(i=1:n)
46+
x = 10
47+
do concurrent(j=1:m)
48+
b(i,j,k) = i * j + k
49+
end do
50+
end do
51+
52+
! CHECK: Loop pair starting at location
53+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
54+
do concurrent(i=1:n)
55+
do concurrent(j=1:m)
56+
b(i,j,k) = i * j + k
57+
end do
58+
x = 10
59+
end do
60+
61+
! CHECK: Loop pair starting at location
62+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested
63+
do concurrent(i=1:n)
64+
do concurrent(j=1:m)
65+
b(i,j,k) = i * j + k
66+
x = 10
67+
end do
68+
end do
69+
70+
! CHECK: Loop pair starting at location
71+
! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested
72+
do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m), k=1:bar(n*m, bar(n*m, n/m)))
73+
a(i) = n
74+
end do
75+
76+
77+
end subroutine
78+
79+
pure function bar(n, m)
80+
implicit none
81+
integer, intent(in) :: n, m
82+
integer :: bar
83+
84+
bar = n + m
85+
end function
86+
87+
end program main

0 commit comments

Comments
 (0)