Skip to content

Commit eb6207b

Browse files
SusanTan authored and svkeerthy committed
[flang][openacc] Add support for force clause for loop collapse (#162534)
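Currently the force modifier of the collapse clause, `collapse(force:num_level)`, is not yet implemented (NYI). This change adds support for it by sinking any prologue and epilogue code into the innermost loop of the collapsed nest, as specified by `num_level`.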
1 parent af51b87 commit eb6207b

File tree

4 files changed

+151
-25
lines changed

4 files changed

+151
-25
lines changed

flang/include/flang/Lower/OpenACC.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
122122
/// clause.
123123
uint64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &);
124124

125+
/// Parse collapse clause and return {size, force}. If absent, returns
126+
/// {1,false}.
127+
std::pair<uint64_t, bool>
128+
getCollapseSizeAndForce(const Fortran::parser::AccClauseList &);
129+
125130
/// Checks whether the current insertion point is inside OpenACC loop.
126131
bool isInOpenACCLoop(fir::FirOpBuilder &);
127132

flang/lib/Lower/Bridge.cpp

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3192,22 +3192,29 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31923192
std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);
31933193

31943194
Fortran::lower::pft::Evaluation *curEval = &getEval();
3195+
// Determine collapse depth/force and loopCount
3196+
bool collapseForce = false;
3197+
uint64_t collapseDepth = 1;
3198+
uint64_t loopCount = 1;
31953199

31963200
if (accLoop || accCombined) {
3197-
uint64_t loopCount;
31983201
if (accLoop) {
31993202
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
32003203
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
32013204
const Fortran::parser::AccClauseList &clauseList =
32023205
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
32033206
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
3207+
std::tie(collapseDepth, collapseForce) =
3208+
Fortran::lower::getCollapseSizeAndForce(clauseList);
32043209
} else if (accCombined) {
32053210
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
32063211
std::get<Fortran::parser::AccBeginCombinedDirective>(
32073212
accCombined->t);
32083213
const Fortran::parser::AccClauseList &clauseList =
32093214
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
32103215
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
3216+
std::tie(collapseDepth, collapseForce) =
3217+
Fortran::lower::getCollapseSizeAndForce(clauseList);
32113218
}
32123219

32133220
if (curEval->lowerAsStructured()) {
@@ -3217,8 +3224,63 @@ class FirConverter : public Fortran::lower::AbstractConverter {
32173224
}
32183225
}
32193226

3220-
for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
3221-
genFIR(e);
3227+
const bool isStructured = curEval && curEval->lowerAsStructured();
3228+
if (isStructured && collapseForce && collapseDepth > 1) {
3229+
// force: collect prologue/epilogue for the first collapseDepth nested
3230+
// loops and sink them into the innermost loop body at that depth
3231+
llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
3232+
Fortran::lower::pft::Evaluation *parent = &getEval();
3233+
Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
3234+
for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
3235+
epilogue.clear();
3236+
auto &kids = parent->getNestedEvaluations();
3237+
// Collect all non-loop statements before the next inner loop as
3238+
// prologue, then mark remaining siblings as epilogue and descend into
3239+
// the inner loop.
3240+
Fortran::lower::pft::Evaluation *childLoop = nullptr;
3241+
for (auto it = kids.begin(); it != kids.end(); ++it) {
3242+
if (it->getIf<Fortran::parser::DoConstruct>()) {
3243+
childLoop = &*it;
3244+
for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
3245+
epilogue.push_back(&*it2);
3246+
break;
3247+
}
3248+
prologue.push_back(&*it);
3249+
}
3250+
// Semantics guarantees collapseDepth does not exceed nest depth
3251+
// so childLoop must be found here.
3252+
assert(childLoop && "Expected inner DoConstruct for collapse");
3253+
parent = childLoop;
3254+
innermostLoopEval = childLoop;
3255+
}
3256+
3257+
// Track sunk evaluations (avoid double-lowering)
3258+
llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
3259+
for (auto *e : prologue)
3260+
sunk.insert(e);
3261+
for (auto *e : epilogue)
3262+
sunk.insert(e);
3263+
3264+
auto sink =
3265+
[&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
3266+
for (auto *e : lst)
3267+
genFIR(*e);
3268+
};
3269+
3270+
sink(prologue);
3271+
3272+
// Lower innermost loop body, skipping sunk
3273+
for (Fortran::lower::pft::Evaluation &e :
3274+
innermostLoopEval->getNestedEvaluations())
3275+
if (!sunk.contains(&e))
3276+
genFIR(e);
3277+
3278+
sink(epilogue);
3279+
} else {
3280+
// Normal lowering
3281+
for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
3282+
genFIR(e);
3283+
}
32223284
localSymbols.popScope();
32233285
builder->restoreInsertionPoint(insertPt);
32243286

flang/lib/Lower/OpenACC.cpp

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2178,11 +2178,25 @@ static void processDoLoopBounds(
21782178
locs.push_back(converter.genLocation(
21792179
Fortran::parser::FindSourceLocation(outerDoConstruct)));
21802180
} else {
2181-
auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
2182-
assert(doCons && "expect do construct");
2183-
loopControl = &*doCons->GetLoopControl();
2181+
// Safely locate the next inner DoConstruct within this eval.
2182+
const Fortran::parser::DoConstruct *innerDo = nullptr;
2183+
if (crtEval && crtEval->hasNestedEvaluations()) {
2184+
for (Fortran::lower::pft::Evaluation &child :
2185+
crtEval->getNestedEvaluations()) {
2186+
if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
2187+
innerDo = stmt;
2188+
// Prepare to descend for the next iteration
2189+
crtEval = &child;
2190+
break;
2191+
}
2192+
}
2193+
}
2194+
if (!innerDo)
2195+
break; // No deeper loop; stop collecting collapsed bounds.
2196+
2197+
loopControl = &*innerDo->GetLoopControl();
21842198
locs.push_back(converter.genLocation(
2185-
Fortran::parser::FindSourceLocation(*doCons)));
2199+
Fortran::parser::FindSourceLocation(*innerDo)));
21862200
}
21872201

21882202
const Fortran::parser::LoopControl::Bounds *bounds =
@@ -2206,8 +2220,7 @@ static void processDoLoopBounds(
22062220

22072221
inclusiveBounds.push_back(true);
22082222

2209-
if (i < loopsToProcess - 1)
2210-
crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
2223+
// crtEval already updated when descending; no blind increment here.
22112224
}
22122225
}
22132226
}
@@ -2553,10 +2566,6 @@ static mlir::acc::LoopOp createLoopOp(
25532566
std::get_if<Fortran::parser::AccClause::Collapse>(
25542567
&clause.u)) {
25552568
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
2556-
const auto &force = std::get<bool>(arg.t);
2557-
if (force)
2558-
TODO(clauseLocation, "OpenACC collapse force modifier");
2559-
25602569
const auto &intExpr =
25612570
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
25622571
const auto *expr = Fortran::semantics::GetExpr(intExpr);
@@ -5029,25 +5038,34 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
50295038

50305039
uint64_t Fortran::lower::getLoopCountForCollapseAndTile(
50315040
const Fortran::parser::AccClauseList &clauseList) {
5032-
uint64_t collapseLoopCount = 1;
5041+
uint64_t collapseLoopCount = getCollapseSizeAndForce(clauseList).first;
50335042
uint64_t tileLoopCount = 1;
50345043
for (const Fortran::parser::AccClause &clause : clauseList.v) {
5035-
if (const auto *collapseClause =
5036-
std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
5037-
const parser::AccCollapseArg &arg = collapseClause->v;
5038-
const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
5039-
collapseLoopCount = *Fortran::semantics::GetIntValue(collapseValue);
5040-
}
50415044
if (const auto *tileClause =
50425045
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
50435046
const parser::AccTileExprList &tileExprList = tileClause->v;
5044-
const std::list<parser::AccTileExpr> &listTileExpr = tileExprList.v;
5045-
tileLoopCount = listTileExpr.size();
5047+
tileLoopCount = tileExprList.v.size();
5048+
}
5049+
}
5050+
return tileLoopCount > collapseLoopCount ? tileLoopCount : collapseLoopCount;
5051+
}
5052+
5053+
std::pair<uint64_t, bool> Fortran::lower::getCollapseSizeAndForce(
5054+
const Fortran::parser::AccClauseList &clauseList) {
5055+
uint64_t size = 1;
5056+
bool force = false;
5057+
for (const Fortran::parser::AccClause &clause : clauseList.v) {
5058+
if (const auto *collapseClause =
5059+
std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
5060+
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
5061+
force = std::get<bool>(arg.t);
5062+
const auto &collapseValue =
5063+
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
5064+
size = *Fortran::semantics::GetIntValue(collapseValue);
5065+
break;
50465066
}
50475067
}
5048-
if (tileLoopCount > collapseLoopCount)
5049-
return tileLoopCount;
5050-
return collapseLoopCount;
5068+
return {size, force};
50515069
}
50525070

50535071
/// Create an ACC loop operation for a DO construct when inside ACC compute
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
2+
3+
! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop)
4+
! into the acc.loop region body.
5+
6+
subroutine collapse_force_sink(n, m)
7+
integer, intent(in) :: n, m
8+
real, dimension(n,m) :: a
9+
real, dimension(n) :: bb, cc
10+
integer :: i, j
11+
12+
!$acc parallel loop collapse(force:2)
13+
do i = 1, n
14+
bb(i) = 4.2 ! prologue (between loops)
15+
do j = 1, m
16+
a(i,j) = a(i,j) + 2.0
17+
end do
18+
cc(i) = 7.3 ! epilogue (after inner loop)
19+
end do
20+
!$acc end parallel loop
21+
end subroutine
22+
23+
! CHECK: func.func @_QPcollapse_force_sink(
24+
! CHECK: acc.parallel
25+
! Ensure outer acc.loop is combined(parallel)
26+
! CHECK: acc.loop combined(parallel)
27+
! Prologue: constant 4.2 and an assign before inner loop
28+
! CHECK: arith.constant 4.200000e+00
29+
! CHECK: hlfir.assign
30+
! Inner loop and its body include 2.0 add and an assign
31+
! CHECK: acc.loop
32+
! CHECK: arith.constant 2.000000e+00
33+
! CHECK: arith.addf
34+
! CHECK: hlfir.assign
35+
! Epilogue: constant 7.3 and an assign after inner loop
36+
! CHECK: arith.constant 7.300000e+00
37+
! CHECK: hlfir.assign
38+
! And the outer acc.loop has collapse = [2]
39+
! CHECK: } attributes {collapse = [2]
40+
41+

0 commit comments

Comments
 (0)