Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 78 additions & 3 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3190,22 +3190,47 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);

Fortran::lower::pft::Evaluation *curEval = &getEval();
// Determine collapse depth/force and loopCount
bool collapseForce = false;
uint64_t collapseDepth = 1;
uint64_t loopCount = 1;
auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl)
-> std::pair<bool, uint64_t> {
bool force = false;
uint64_t depth = 1;
for (const Fortran::parser::AccClause &clause : cl.v) {
if (const auto *collapseClause =
std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
force = std::get<bool>(arg.t);
const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having collapse clause parsing in a single place would be ideal. We currently have getLoopCountForCollapseAndTile in OpenACC.cpp. Any chance to make something similar - like getLoopCountForCollapse - and then use it both here and in getLoopCountForCollapseAndTile.

if (auto v = Fortran::evaluate::ToInt64(*expr))
depth = *v;
}
break;
}
}
return {force, depth};
};

if (accLoop || accCombined) {
uint64_t loopCount;
if (accLoop) {
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
} else if (accCombined) {
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
std::get<Fortran::parser::AccBeginCombinedDirective>(
accCombined->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
}

if (curEval->lowerAsStructured()) {
Expand All @@ -3215,8 +3240,58 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}

for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
genFIR(e);
const bool isStructured = curEval && curEval->lowerAsStructured();
if (isStructured && collapseForce && collapseDepth > 1) {
// force: collect prologue/epilogue for the first collapseDepth nested loops
// and sink them into the innermost loop body at that depth
llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
Fortran::lower::pft::Evaluation *parent = &getEval();
Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
epilogue.clear();
auto &kids = parent->getNestedEvaluations();
// Collect all non-loop statements before the next inner loop as prologue,
// then mark remaining siblings as epilogue and descend into the inner loop.
Fortran::lower::pft::Evaluation *childLoop = nullptr;
for (auto it = kids.begin(); it != kids.end(); ++it) {
if (it->getIf<Fortran::parser::DoConstruct>()) {
childLoop = &*it;
for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
epilogue.push_back(&*it2);
break;
}
prologue.push_back(&*it);
}
// Semantics guarantees collapseDepth does not exceed nest depth
// so childLoop must be found here.
assert(childLoop && "Expected inner DoConstruct for collapse");
parent = childLoop;
innermostLoopEval = childLoop;
}

// Track sunk evaluations (avoid double-lowering)
llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
for (auto *e : prologue) sunk.insert(e);
for (auto *e : epilogue) sunk.insert(e);

auto sink =
[&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
for (auto *e : lst)
genFIR(*e);
};

sink(prologue);

// Lower innermost loop body, skipping sunk
for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations())
if (!sunk.contains(&e)) genFIR(e);

sink(epilogue);
} else {
// Normal lowering
for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
genFIR(e);
}
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);

Expand Down
26 changes: 18 additions & 8 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2144,8 +2144,23 @@ static void processDoLoopBounds(
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct)));
} else {
auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
assert(doCons && "expect do construct");
// Safely locate the next inner DoConstruct within this eval.
const Fortran::parser::DoConstruct *doCons = nullptr;
if (crtEval && crtEval->hasNestedEvaluations()) {
for (Fortran::lower::pft::Evaluation &child :
crtEval->getNestedEvaluations()) {
if (auto *cand = child.getIf<Fortran::parser::DoConstruct>()) {
doCons = cand;
// Prepare to descend for the next iteration
crtEval = &child;
break;
}
}
}
if (!doCons) {
// No deeper loop; stop collecting collapsed bounds.
break;
}
loopControl = &*doCons->GetLoopControl();
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(*doCons)));
Expand All @@ -2172,8 +2187,7 @@ static void processDoLoopBounds(

inclusiveBounds.push_back(true);

if (i < loopsToProcess - 1)
crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
// crtEval already updated when descending; no blind increment here.
}
}
}
Expand Down Expand Up @@ -2406,10 +2420,6 @@ static mlir::acc::LoopOp createLoopOp(
std::get_if<Fortran::parser::AccClause::Collapse>(
&clause.u)) {
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
const auto &force = std::get<bool>(arg.t);
if (force)
TODO(clauseLocation, "OpenACC collapse force modifier");

const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
const auto *expr = Fortran::semantics::GetExpr(intExpr);
Expand Down
41 changes: 41 additions & 0 deletions flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s

! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop)
! into the acc.loop region body.

subroutine collapse_force_sink(n, m)
integer, intent(in) :: n, m
real, dimension(n,m) :: a
real, dimension(n) :: bb, cc
integer :: i, j

!$acc parallel loop collapse(force:2)
do i = 1, n
bb(i) = 4.2 ! prologue (between loops)
do j = 1, m
a(i,j) = a(i,j) + 2.0
end do
cc(i) = 7.3 ! epilogue (after inner loop)
end do
!$acc end parallel loop
end subroutine

! CHECK: func.func @_QPcollapse_force_sink(
! CHECK: acc.parallel
! Ensure outer acc.loop is combined(parallel)
! CHECK: acc.loop combined(parallel)
! Prologue: constant 4.2 and an assign before inner loop
! CHECK: arith.constant 4.200000e+00
! CHECK: hlfir.assign
! Inner loop and its body include 2.0 add and an assign
! CHECK: acc.loop
! CHECK: arith.constant 2.000000e+00
! CHECK: arith.addf
! CHECK: hlfir.assign
! Epilogue: constant 7.3 and an assign after inner loop
! CHECK: arith.constant 7.300000e+00
! CHECK: hlfir.assign
! And the outer acc.loop has collapse = [2]
! CHECK: } attributes {collapse = [2]


Loading