Skip to content

Commit 3c62ead

Browse files
authored
[flang][OpenACC] lower acc loops with early exits (#164992)
Lower acc loops with early exits using the newly added "unstructured" attribute. The core change of this patch is to refactor the handling of loop control variables so that, for loops with early exits, the induction variables are privatized but no bounds operands are added to the acc.loop. The looping logic is instead implemented by the regular FIR loop lowering, which generates explicit control flow inside the acc.loop body.
1 parent 3149c7c commit 3c62ead

File tree

2 files changed

+144
-103
lines changed

2 files changed

+144
-103
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 142 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,6 +2003,49 @@ static void determineDefaultLoopParMode(
20032003
}
20042004
}
20052005

2006+
// Helper to visit Bounds of DO LOOP nest.
2007+
static void visitLoopControl(
2008+
Fortran::lower::AbstractConverter &converter,
2009+
const Fortran::parser::DoConstruct &outerDoConstruct,
2010+
uint64_t loopsToProcess, Fortran::lower::pft::Evaluation &eval,
2011+
std::function<void(const Fortran::parser::LoopControl::Bounds &,
2012+
mlir::Location)>
2013+
callback) {
2014+
Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
2015+
for (uint64_t i = 0; i < loopsToProcess; ++i) {
2016+
const Fortran::parser::LoopControl *loopControl;
2017+
if (i == 0) {
2018+
loopControl = &*outerDoConstruct.GetLoopControl();
2019+
mlir::Location loc = converter.genLocation(
2020+
Fortran::parser::FindSourceLocation(outerDoConstruct));
2021+
callback(std::get<Fortran::parser::LoopControl::Bounds>(loopControl->u),
2022+
loc);
2023+
} else {
2024+
// Safely locate the next inner DoConstruct within this eval.
2025+
const Fortran::parser::DoConstruct *innerDo = nullptr;
2026+
if (crtEval && crtEval->hasNestedEvaluations()) {
2027+
for (Fortran::lower::pft::Evaluation &child :
2028+
crtEval->getNestedEvaluations()) {
2029+
if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
2030+
innerDo = stmt;
2031+
// Prepare to descend for the next iteration
2032+
crtEval = &child;
2033+
break;
2034+
}
2035+
}
2036+
}
2037+
if (!innerDo)
2038+
break; // No deeper loop; stop collecting collapsed bounds.
2039+
2040+
loopControl = &*innerDo->GetLoopControl();
2041+
mlir::Location loc =
2042+
converter.genLocation(Fortran::parser::FindSourceLocation(*innerDo));
2043+
callback(std::get<Fortran::parser::LoopControl::Bounds>(loopControl->u),
2044+
loc);
2045+
}
2046+
}
2047+
}
2048+
20062049
// Extract loop bounds, steps, induction variables, and privatization info
20072050
// for both DO CONCURRENT and regular do loops
20082051
static void processDoLoopBounds(
@@ -2024,7 +2067,6 @@ static void processDoLoopBounds(
20242067
llvm::SmallVector<mlir::Location> &locs, uint64_t loopsToProcess) {
20252068
assert(loopsToProcess > 0 && "expect at least one loop");
20262069
locs.push_back(currentLocation); // Location of the directive
2027-
Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
20282070
bool isDoConcurrent = outerDoConstruct.IsDoConcurrent();
20292071

20302072
if (isDoConcurrent) {
@@ -2065,57 +2107,29 @@ static void processDoLoopBounds(
20652107
inclusiveBounds.push_back(true);
20662108
}
20672109
} else {
2068-
for (uint64_t i = 0; i < loopsToProcess; ++i) {
2069-
const Fortran::parser::LoopControl *loopControl;
2070-
if (i == 0) {
2071-
loopControl = &*outerDoConstruct.GetLoopControl();
2072-
locs.push_back(converter.genLocation(
2073-
Fortran::parser::FindSourceLocation(outerDoConstruct)));
2074-
} else {
2075-
// Safely locate the next inner DoConstruct within this eval.
2076-
const Fortran::parser::DoConstruct *innerDo = nullptr;
2077-
if (crtEval && crtEval->hasNestedEvaluations()) {
2078-
for (Fortran::lower::pft::Evaluation &child :
2079-
crtEval->getNestedEvaluations()) {
2080-
if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
2081-
innerDo = stmt;
2082-
// Prepare to descend for the next iteration
2083-
crtEval = &child;
2084-
break;
2085-
}
2086-
}
2087-
}
2088-
if (!innerDo)
2089-
break; // No deeper loop; stop collecting collapsed bounds.
2090-
2091-
loopControl = &*innerDo->GetLoopControl();
2092-
locs.push_back(converter.genLocation(
2093-
Fortran::parser::FindSourceLocation(*innerDo)));
2094-
}
2095-
2096-
const Fortran::parser::LoopControl::Bounds *bounds =
2097-
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
2098-
assert(bounds && "Expected bounds on the loop construct");
2099-
lowerbounds.push_back(fir::getBase(converter.genExprValue(
2100-
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
2101-
upperbounds.push_back(fir::getBase(converter.genExprValue(
2102-
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
2103-
if (bounds->step)
2104-
steps.push_back(fir::getBase(converter.genExprValue(
2105-
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
2106-
else // If `step` is not present, assume it is `1`.
2107-
steps.push_back(builder.createIntegerConstant(
2108-
currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));
2109-
2110-
Fortran::semantics::Symbol &ivSym =
2111-
bounds->name.thing.symbol->GetUltimate();
2112-
privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs,
2113-
privateOperands, ivPrivate, privatizationRecipes);
2114-
2115-
inclusiveBounds.push_back(true);
2116-
2117-
// crtEval already updated when descending; no blind increment here.
2118-
}
2110+
visitLoopControl(
2111+
converter, outerDoConstruct, loopsToProcess, eval,
2112+
[&](const Fortran::parser::LoopControl::Bounds &bounds,
2113+
mlir::Location loc) {
2114+
locs.push_back(loc);
2115+
lowerbounds.push_back(fir::getBase(converter.genExprValue(
2116+
*Fortran::semantics::GetExpr(bounds.lower), stmtCtx)));
2117+
upperbounds.push_back(fir::getBase(converter.genExprValue(
2118+
*Fortran::semantics::GetExpr(bounds.upper), stmtCtx)));
2119+
if (bounds.step)
2120+
steps.push_back(fir::getBase(converter.genExprValue(
2121+
*Fortran::semantics::GetExpr(bounds.step), stmtCtx)));
2122+
else // If `step` is not present, assume it is `1`.
2123+
steps.push_back(builder.createIntegerConstant(
2124+
currentLocation, upperbounds[upperbounds.size() - 1].getType(),
2125+
1));
2126+
Fortran::semantics::Symbol &ivSym =
2127+
bounds.name.thing.symbol->GetUltimate();
2128+
privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs,
2129+
privateOperands, ivPrivate, privatizationRecipes);
2130+
2131+
inclusiveBounds.push_back(true);
2132+
});
21192133
}
21202134
}
21212135

@@ -2251,6 +2265,34 @@ static void remapDataOperandSymbols(
22512265
}
22522266
}
22532267

2268+
static void privatizeInductionVariables(
2269+
Fortran::lower::AbstractConverter &converter,
2270+
mlir::Location currentLocation,
2271+
const Fortran::parser::DoConstruct &outerDoConstruct,
2272+
Fortran::lower::pft::Evaluation &eval,
2273+
llvm::SmallVector<mlir::Value> &privateOperands,
2274+
llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
2275+
&ivPrivate,
2276+
llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
2277+
llvm::SmallVector<mlir::Location> &locs, uint64_t loopsToProcess) {
2278+
// ivTypes and locs will be ignored since no acc.loop control arguments will
2279+
// be created.
2280+
llvm::SmallVector<mlir::Type> ivTypes;
2281+
llvm::SmallVector<mlir::Location> ivLocs;
2282+
assert(!outerDoConstruct.IsDoConcurrent() &&
2283+
"do concurrent loops are not expected to contained earlty exits");
2284+
visitLoopControl(converter, outerDoConstruct, loopsToProcess, eval,
2285+
[&](const Fortran::parser::LoopControl::Bounds &bounds,
2286+
mlir::Location loc) {
2287+
locs.push_back(loc);
2288+
Fortran::semantics::Symbol &ivSym =
2289+
bounds.name.thing.symbol->GetUltimate();
2290+
privatizeIv(converter, ivSym, currentLocation, ivTypes,
2291+
ivLocs, privateOperands, ivPrivate,
2292+
privatizationRecipes);
2293+
});
2294+
}
2295+
22542296
static mlir::acc::LoopOp buildACCLoopOp(
22552297
Fortran::lower::AbstractConverter &converter,
22562298
mlir::Location currentLocation,
@@ -2280,13 +2322,22 @@ static mlir::acc::LoopOp buildACCLoopOp(
22802322
llvm::SmallVector<mlir::Location> locs;
22812323
llvm::SmallVector<mlir::Value> lowerbounds, upperbounds, steps;
22822324

2283-
// Look at the do/do concurrent loops to extract bounds information.
2284-
processDoLoopBounds(converter, currentLocation, stmtCtx, builder,
2285-
outerDoConstruct, eval, lowerbounds, upperbounds, steps,
2286-
privateOperands, ivPrivate, privatizationRecipes, ivTypes,
2287-
ivLocs, inclusiveBounds, locs, loopsToProcess);
2288-
2289-
// Prepare the operand segment size attribute and the operands value range.
2325+
// Look at the do/do concurrent loops to extract bounds information unless
2326+
// this loop is lowered in an unstructured fashion, in which case bounds are
2327+
// not represented on acc.loop and explicit control flow is used inside body.
2328+
if (!eval.lowerAsUnstructured()) {
2329+
processDoLoopBounds(converter, currentLocation, stmtCtx, builder,
2330+
outerDoConstruct, eval, lowerbounds, upperbounds, steps,
2331+
privateOperands, ivPrivate, privatizationRecipes,
2332+
ivTypes, ivLocs, inclusiveBounds, locs, loopsToProcess);
2333+
} else {
2334+
// When the loop contains early exits, privatize induction variables, but do
2335+
// not create acc.loop bounds. The control flow of the loop will be
2336+
// generated explicitly in the acc.loop body that is just a container.
2337+
privatizeInductionVariables(converter, currentLocation, outerDoConstruct,
2338+
eval, privateOperands, ivPrivate,
2339+
privatizationRecipes, locs, loopsToProcess);
2340+
}
22902341
llvm::SmallVector<mlir::Value> operands;
22912342
llvm::SmallVector<int32_t> operandSegments;
22922343
addOperands(operands, operandSegments, lowerbounds);
@@ -2315,20 +2366,36 @@ static mlir::acc::LoopOp buildACCLoopOp(
23152366
// Remap symbols from data clauses to use data operation results
23162367
remapDataOperandSymbols(converter, builder, loopOp, dataOperandSymbolPairs);
23172368

2318-
for (auto [arg, iv] :
2319-
llvm::zip(loopOp.getLoopRegions().front()->front().getArguments(),
2320-
ivPrivate)) {
2321-
// Store block argument to the related iv private variable.
2322-
mlir::Value privateValue =
2323-
converter.getSymbolAddress(std::get<Fortran::semantics::SymbolRef>(iv));
2324-
fir::StoreOp::create(builder, currentLocation, arg, privateValue);
2369+
if (!eval.lowerAsUnstructured()) {
2370+
for (auto [arg, iv] :
2371+
llvm::zip(loopOp.getLoopRegions().front()->front().getArguments(),
2372+
ivPrivate)) {
2373+
// Store block argument to the related iv private variable.
2374+
mlir::Value privateValue = converter.getSymbolAddress(
2375+
std::get<Fortran::semantics::SymbolRef>(iv));
2376+
fir::StoreOp::create(builder, currentLocation, arg, privateValue);
2377+
}
2378+
loopOp.setInclusiveUpperbound(inclusiveBounds);
2379+
} else {
2380+
loopOp.setUnstructuredAttr(builder.getUnitAttr());
23252381
}
23262382

2327-
loopOp.setInclusiveUpperbound(inclusiveBounds);
2328-
23292383
return loopOp;
23302384
}
23312385

2386+
/// Return true if a RETURN statement appears anywhere in the evaluations
/// nested under \p eval, at any depth.
///
/// The search stops at the first RETURN found. A result from one nested
/// evaluation is never overwritten by a later sibling that contains no
/// RETURN.
static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
  for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
    if (e.getIf<Fortran::parser::ReturnStmt>())
      return true;
    // Only recurse into evaluations that actually have children.
    if (e.hasNestedEvaluations() && hasEarlyReturn(e))
      return true;
  }
  return false;
}
2398+
23322399
static mlir::acc::LoopOp createLoopOp(
23332400
Fortran::lower::AbstractConverter &converter,
23342401
mlir::Location currentLocation,
@@ -2338,8 +2405,7 @@ static mlir::acc::LoopOp createLoopOp(
23382405
Fortran::lower::pft::Evaluation &eval,
23392406
const Fortran::parser::AccClauseList &accClauseList,
23402407
std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
2341-
std::nullopt,
2342-
bool needEarlyReturnHandling = false) {
2408+
std::nullopt) {
23432409
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
23442410
llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
23452411
reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
@@ -2515,7 +2581,10 @@ static mlir::acc::LoopOp createLoopOp(
25152581

25162582
llvm::SmallVector<mlir::Type> retTy;
25172583
mlir::Value yieldValue;
2518-
if (needEarlyReturnHandling) {
2584+
if (eval.lowerAsUnstructured() && hasEarlyReturn(eval)) {
2585+
// When there is a return statement inside the loop, add a result to the
2586+
// acc.loop that will be used in a conditional branch after the loop to
2587+
// return.
25192588
mlir::Type i1Ty = builder.getI1Type();
25202589
yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
25212590
retTy.push_back(i1Ty);
@@ -2596,19 +2665,6 @@ static mlir::acc::LoopOp createLoopOp(
25962665
return loopOp;
25972666
}
25982667

2599-
/// Return true if a RETURN statement appears anywhere in the evaluations
/// nested under \p eval, at any depth.
///
/// The search stops at the first RETURN found. A result from one nested
/// evaluation is never overwritten by a later sibling that contains no
/// RETURN.
static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
  for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
    if (e.getIf<Fortran::parser::ReturnStmt>())
      return true;
    // Only recurse into evaluations that actually have children.
    if (e.hasNestedEvaluations() && hasEarlyReturn(e))
      return true;
  }
  return false;
}
2611-
26122668
static mlir::Value
26132669
genACC(Fortran::lower::AbstractConverter &converter,
26142670
Fortran::semantics::SemanticsContext &semanticsContext,
@@ -2622,17 +2678,6 @@ genACC(Fortran::lower::AbstractConverter &converter,
26222678

26232679
mlir::Location currentLocation =
26242680
converter.genLocation(beginLoopDirective.source);
2625-
bool needEarlyExitHandling = false;
2626-
if (eval.lowerAsUnstructured()) {
2627-
needEarlyExitHandling = hasEarlyReturn(eval);
2628-
// If the loop is lowered in an unstructured fashion, lowering generates
2629-
// explicit control flow that duplicates the looping semantics of the
2630-
// loops.
2631-
if (!needEarlyExitHandling)
2632-
TODO(currentLocation,
2633-
"loop with early exit inside OpenACC loop construct");
2634-
}
2635-
26362681
Fortran::lower::StatementContext stmtCtx;
26372682

26382683
assert(loopDirective.v == llvm::acc::ACCD_loop &&
@@ -2645,8 +2690,8 @@ genACC(Fortran::lower::AbstractConverter &converter,
26452690
std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t);
26462691
auto loopOp = createLoopOp(converter, currentLocation, semanticsContext,
26472692
stmtCtx, *outerDoConstruct, eval, accClauseList,
2648-
/*combinedConstructs=*/{}, needEarlyExitHandling);
2649-
if (needEarlyExitHandling)
2693+
/*combinedConstructs=*/{});
2694+
if (loopOp.getNumResults() == 1)
26502695
return loopOp.getResult(0);
26512696

26522697
return mlir::Value{};
@@ -3431,10 +3476,6 @@ genACC(Fortran::lower::AbstractConverter &converter,
34313476
converter.genLocation(beginCombinedDirective.source);
34323477
Fortran::lower::StatementContext stmtCtx;
34333478

3434-
if (eval.lowerAsUnstructured())
3435-
TODO(currentLocation,
3436-
"loop with early exit inside OpenACC combined construct");
3437-
34383479
if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
34393480
createComputeOp<mlir::acc::KernelsOp>(
34403481
converter, currentLocation, eval, semanticsContext, stmtCtx,

flang/test/Lower/OpenACC/acc-unstructured.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
2-
! XFAIL: *
32

43
subroutine test_unstructured1(a, b, c)
54
integer :: i, j, k
@@ -55,10 +54,11 @@ subroutine test_unstructured2(a, b, c)
5554

5655
! CHECK-LABEL: func.func @_QPtest_unstructured2
5756
! CHECK: acc.parallel
58-
! CHECK: acc.loop
57+
! CHECK: acc.loop combined(parallel) private(@privatization_ref_i32 -> %{{.*}} : !fir.ref<i32>) {
5958
! CHECK: fir.call @_FortranAStopStatementText
6059
! CHECK: acc.yield
6160
! CHECK: acc.yield
61+
! CHECK: } attributes {independent = [#acc.device_type<none>], unstructured}
6262
! CHECK: acc.yield
6363

6464
end subroutine

0 commit comments

Comments
 (0)