Skip to content

Commit 9bb33e3

Browse files
committed
[flang][OpenACC] lower acc loops with early exits
1 parent b63833f commit 9bb33e3

File tree

2 files changed

+144
-103
lines changed

2 files changed

+144
-103
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 142 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -2251,6 +2251,49 @@ static void determineDefaultLoopParMode(
22512251
}
22522252
}
22532253

2254+
// Helper to visit Bounds of DO LOOP nest.
2255+
static void visitLoopControl(
2256+
Fortran::lower::AbstractConverter &converter,
2257+
const Fortran::parser::DoConstruct &outerDoConstruct,
2258+
uint64_t loopsToProcess, Fortran::lower::pft::Evaluation &eval,
2259+
std::function<void(const Fortran::parser::LoopControl::Bounds &,
2260+
mlir::Location)>
2261+
callback) {
2262+
Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
2263+
for (uint64_t i = 0; i < loopsToProcess; ++i) {
2264+
const Fortran::parser::LoopControl *loopControl;
2265+
if (i == 0) {
2266+
loopControl = &*outerDoConstruct.GetLoopControl();
2267+
mlir::Location loc = converter.genLocation(
2268+
Fortran::parser::FindSourceLocation(outerDoConstruct));
2269+
callback(std::get<Fortran::parser::LoopControl::Bounds>(loopControl->u),
2270+
loc);
2271+
} else {
2272+
// Safely locate the next inner DoConstruct within this eval.
2273+
const Fortran::parser::DoConstruct *innerDo = nullptr;
2274+
if (crtEval && crtEval->hasNestedEvaluations()) {
2275+
for (Fortran::lower::pft::Evaluation &child :
2276+
crtEval->getNestedEvaluations()) {
2277+
if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
2278+
innerDo = stmt;
2279+
// Prepare to descend for the next iteration
2280+
crtEval = &child;
2281+
break;
2282+
}
2283+
}
2284+
}
2285+
if (!innerDo)
2286+
break; // No deeper loop; stop collecting collapsed bounds.
2287+
2288+
loopControl = &*innerDo->GetLoopControl();
2289+
mlir::Location loc =
2290+
converter.genLocation(Fortran::parser::FindSourceLocation(*innerDo));
2291+
callback(std::get<Fortran::parser::LoopControl::Bounds>(loopControl->u),
2292+
loc);
2293+
}
2294+
}
2295+
}
2296+
22542297
// Extract loop bounds, steps, induction variables, and privatization info
22552298
// for both DO CONCURRENT and regular do loops
22562299
static void processDoLoopBounds(
@@ -2272,7 +2315,6 @@ static void processDoLoopBounds(
22722315
llvm::SmallVector<mlir::Location> &locs, uint64_t loopsToProcess) {
22732316
assert(loopsToProcess > 0 && "expect at least one loop");
22742317
locs.push_back(currentLocation); // Location of the directive
2275-
Fortran::lower::pft::Evaluation *crtEval = &eval.getFirstNestedEvaluation();
22762318
bool isDoConcurrent = outerDoConstruct.IsDoConcurrent();
22772319

22782320
if (isDoConcurrent) {
@@ -2313,57 +2355,29 @@ static void processDoLoopBounds(
23132355
inclusiveBounds.push_back(true);
23142356
}
23152357
} else {
2316-
for (uint64_t i = 0; i < loopsToProcess; ++i) {
2317-
const Fortran::parser::LoopControl *loopControl;
2318-
if (i == 0) {
2319-
loopControl = &*outerDoConstruct.GetLoopControl();
2320-
locs.push_back(converter.genLocation(
2321-
Fortran::parser::FindSourceLocation(outerDoConstruct)));
2322-
} else {
2323-
// Safely locate the next inner DoConstruct within this eval.
2324-
const Fortran::parser::DoConstruct *innerDo = nullptr;
2325-
if (crtEval && crtEval->hasNestedEvaluations()) {
2326-
for (Fortran::lower::pft::Evaluation &child :
2327-
crtEval->getNestedEvaluations()) {
2328-
if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
2329-
innerDo = stmt;
2330-
// Prepare to descend for the next iteration
2331-
crtEval = &child;
2332-
break;
2333-
}
2334-
}
2335-
}
2336-
if (!innerDo)
2337-
break; // No deeper loop; stop collecting collapsed bounds.
2338-
2339-
loopControl = &*innerDo->GetLoopControl();
2340-
locs.push_back(converter.genLocation(
2341-
Fortran::parser::FindSourceLocation(*innerDo)));
2342-
}
2343-
2344-
const Fortran::parser::LoopControl::Bounds *bounds =
2345-
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
2346-
assert(bounds && "Expected bounds on the loop construct");
2347-
lowerbounds.push_back(fir::getBase(converter.genExprValue(
2348-
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
2349-
upperbounds.push_back(fir::getBase(converter.genExprValue(
2350-
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
2351-
if (bounds->step)
2352-
steps.push_back(fir::getBase(converter.genExprValue(
2353-
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
2354-
else // If `step` is not present, assume it is `1`.
2355-
steps.push_back(builder.createIntegerConstant(
2356-
currentLocation, upperbounds[upperbounds.size() - 1].getType(), 1));
2357-
2358-
Fortran::semantics::Symbol &ivSym =
2359-
bounds->name.thing.symbol->GetUltimate();
2360-
privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs,
2361-
privateOperands, ivPrivate, privatizationRecipes);
2362-
2363-
inclusiveBounds.push_back(true);
2364-
2365-
// crtEval already updated when descending; no blind increment here.
2366-
}
2358+
visitLoopControl(
2359+
converter, outerDoConstruct, loopsToProcess, eval,
2360+
[&](const Fortran::parser::LoopControl::Bounds &bounds,
2361+
mlir::Location loc) {
2362+
locs.push_back(loc);
2363+
lowerbounds.push_back(fir::getBase(converter.genExprValue(
2364+
*Fortran::semantics::GetExpr(bounds.lower), stmtCtx)));
2365+
upperbounds.push_back(fir::getBase(converter.genExprValue(
2366+
*Fortran::semantics::GetExpr(bounds.upper), stmtCtx)));
2367+
if (bounds.step)
2368+
steps.push_back(fir::getBase(converter.genExprValue(
2369+
*Fortran::semantics::GetExpr(bounds.step), stmtCtx)));
2370+
else // If `step` is not present, assume it is `1`.
2371+
steps.push_back(builder.createIntegerConstant(
2372+
currentLocation, upperbounds[upperbounds.size() - 1].getType(),
2373+
1));
2374+
Fortran::semantics::Symbol &ivSym =
2375+
bounds.name.thing.symbol->GetUltimate();
2376+
privatizeIv(converter, ivSym, currentLocation, ivTypes, ivLocs,
2377+
privateOperands, ivPrivate, privatizationRecipes);
2378+
2379+
inclusiveBounds.push_back(true);
2380+
});
23672381
}
23682382
}
23692383

@@ -2499,6 +2513,34 @@ static void remapDataOperandSymbols(
24992513
}
25002514
}
25012515

2516+
static void privatizeInductionVariables(
2517+
Fortran::lower::AbstractConverter &converter,
2518+
mlir::Location currentLocation,
2519+
const Fortran::parser::DoConstruct &outerDoConstruct,
2520+
Fortran::lower::pft::Evaluation &eval,
2521+
llvm::SmallVector<mlir::Value> &privateOperands,
2522+
llvm::SmallVector<std::pair<mlir::Value, Fortran::semantics::SymbolRef>>
2523+
&ivPrivate,
2524+
llvm::SmallVector<mlir::Attribute> &privatizationRecipes,
2525+
llvm::SmallVector<mlir::Location> &locs, uint64_t loopsToProcess) {
2526+
// ivTypes and locs will be ignored since no acc.loop control arguments will
2527+
// be created.
2528+
llvm::SmallVector<mlir::Type> ivTypes;
2529+
llvm::SmallVector<mlir::Location> ivLocs;
2530+
assert(!outerDoConstruct.IsDoConcurrent() &&
2531+
"do concurrent loops are not expected to contained earlty exits");
2532+
visitLoopControl(converter, outerDoConstruct, loopsToProcess, eval,
2533+
[&](const Fortran::parser::LoopControl::Bounds &bounds,
2534+
mlir::Location loc) {
2535+
locs.push_back(loc);
2536+
Fortran::semantics::Symbol &ivSym =
2537+
bounds.name.thing.symbol->GetUltimate();
2538+
privatizeIv(converter, ivSym, currentLocation, ivTypes,
2539+
ivLocs, privateOperands, ivPrivate,
2540+
privatizationRecipes);
2541+
});
2542+
}
2543+
25022544
static mlir::acc::LoopOp buildACCLoopOp(
25032545
Fortran::lower::AbstractConverter &converter,
25042546
mlir::Location currentLocation,
@@ -2528,13 +2570,22 @@ static mlir::acc::LoopOp buildACCLoopOp(
25282570
llvm::SmallVector<mlir::Location> locs;
25292571
llvm::SmallVector<mlir::Value> lowerbounds, upperbounds, steps;
25302572

2531-
// Look at the do/do concurrent loops to extract bounds information.
2532-
processDoLoopBounds(converter, currentLocation, stmtCtx, builder,
2533-
outerDoConstruct, eval, lowerbounds, upperbounds, steps,
2534-
privateOperands, ivPrivate, privatizationRecipes, ivTypes,
2535-
ivLocs, inclusiveBounds, locs, loopsToProcess);
2536-
2537-
// Prepare the operand segment size attribute and the operands value range.
2573+
// Look at the do/do concurrent loops to extract bounds information unless
2574+
// this loop is lowered in an unstructured fashion, in which case bounds are
2575+
// not represented on acc.loop and explicit control flow is used inside body.
2576+
if (!eval.lowerAsUnstructured()) {
2577+
processDoLoopBounds(converter, currentLocation, stmtCtx, builder,
2578+
outerDoConstruct, eval, lowerbounds, upperbounds, steps,
2579+
privateOperands, ivPrivate, privatizationRecipes,
2580+
ivTypes, ivLocs, inclusiveBounds, locs, loopsToProcess);
2581+
} else {
2582+
// When the loop contains early exits, privatize induction variables, but do
2583+
// not create acc.loop bounds. The control flow of the loop will be
2584+
// generated explicitly in the acc.loop body that is just a container.
2585+
privatizeInductionVariables(converter, currentLocation, outerDoConstruct,
2586+
eval, privateOperands, ivPrivate,
2587+
privatizationRecipes, locs, loopsToProcess);
2588+
}
25382589
llvm::SmallVector<mlir::Value> operands;
25392590
llvm::SmallVector<int32_t> operandSegments;
25402591
addOperands(operands, operandSegments, lowerbounds);
@@ -2563,20 +2614,36 @@ static mlir::acc::LoopOp buildACCLoopOp(
25632614
// Remap symbols from data clauses to use data operation results
25642615
remapDataOperandSymbols(converter, builder, loopOp, dataOperandSymbolPairs);
25652616

2566-
for (auto [arg, iv] :
2567-
llvm::zip(loopOp.getLoopRegions().front()->front().getArguments(),
2568-
ivPrivate)) {
2569-
// Store block argument to the related iv private variable.
2570-
mlir::Value privateValue =
2571-
converter.getSymbolAddress(std::get<Fortran::semantics::SymbolRef>(iv));
2572-
fir::StoreOp::create(builder, currentLocation, arg, privateValue);
2617+
if (!eval.lowerAsUnstructured()) {
2618+
for (auto [arg, iv] :
2619+
llvm::zip(loopOp.getLoopRegions().front()->front().getArguments(),
2620+
ivPrivate)) {
2621+
// Store block argument to the related iv private variable.
2622+
mlir::Value privateValue = converter.getSymbolAddress(
2623+
std::get<Fortran::semantics::SymbolRef>(iv));
2624+
fir::StoreOp::create(builder, currentLocation, arg, privateValue);
2625+
}
2626+
loopOp.setInclusiveUpperbound(inclusiveBounds);
2627+
} else {
2628+
loopOp.setUnstructuredAttr(builder.getUnitAttr());
25732629
}
25742630

2575-
loopOp.setInclusiveUpperbound(inclusiveBounds);
2576-
25772631
return loopOp;
25782632
}
25792633

2634+
static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
2635+
bool hasReturnStmt = false;
2636+
for (auto &e : eval.getNestedEvaluations()) {
2637+
e.visit(Fortran::common::visitors{
2638+
[&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
2639+
[&](const auto &s) {},
2640+
});
2641+
if (e.hasNestedEvaluations())
2642+
hasReturnStmt = hasEarlyReturn(e);
2643+
}
2644+
return hasReturnStmt;
2645+
}
2646+
25802647
static mlir::acc::LoopOp createLoopOp(
25812648
Fortran::lower::AbstractConverter &converter,
25822649
mlir::Location currentLocation,
@@ -2586,8 +2653,7 @@ static mlir::acc::LoopOp createLoopOp(
25862653
Fortran::lower::pft::Evaluation &eval,
25872654
const Fortran::parser::AccClauseList &accClauseList,
25882655
std::optional<mlir::acc::CombinedConstructsType> combinedConstructs =
2589-
std::nullopt,
2590-
bool needEarlyReturnHandling = false) {
2656+
std::nullopt) {
25912657
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
25922658
llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
25932659
reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
@@ -2763,7 +2829,10 @@ static mlir::acc::LoopOp createLoopOp(
27632829

27642830
llvm::SmallVector<mlir::Type> retTy;
27652831
mlir::Value yieldValue;
2766-
if (needEarlyReturnHandling) {
2832+
if (eval.lowerAsUnstructured() && hasEarlyReturn(eval)) {
2833+
// When there is a return statement inside the loop, add a result to the
2834+
// acc.loop that will be used in a conditional branch after the loop to
2835+
// return.
27672836
mlir::Type i1Ty = builder.getI1Type();
27682837
yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
27692838
retTy.push_back(i1Ty);
@@ -2844,19 +2913,6 @@ static mlir::acc::LoopOp createLoopOp(
28442913
return loopOp;
28452914
}
28462915

2847-
static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
2848-
bool hasReturnStmt = false;
2849-
for (auto &e : eval.getNestedEvaluations()) {
2850-
e.visit(Fortran::common::visitors{
2851-
[&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
2852-
[&](const auto &s) {},
2853-
});
2854-
if (e.hasNestedEvaluations())
2855-
hasReturnStmt = hasEarlyReturn(e);
2856-
}
2857-
return hasReturnStmt;
2858-
}
2859-
28602916
static mlir::Value
28612917
genACC(Fortran::lower::AbstractConverter &converter,
28622918
Fortran::semantics::SemanticsContext &semanticsContext,
@@ -2870,17 +2926,6 @@ genACC(Fortran::lower::AbstractConverter &converter,
28702926

28712927
mlir::Location currentLocation =
28722928
converter.genLocation(beginLoopDirective.source);
2873-
bool needEarlyExitHandling = false;
2874-
if (eval.lowerAsUnstructured()) {
2875-
needEarlyExitHandling = hasEarlyReturn(eval);
2876-
// If the loop is lowered in an unstructured fashion, lowering generates
2877-
// explicit control flow that duplicates the looping semantics of the
2878-
// loops.
2879-
if (!needEarlyExitHandling)
2880-
TODO(currentLocation,
2881-
"loop with early exit inside OpenACC loop construct");
2882-
}
2883-
28842929
Fortran::lower::StatementContext stmtCtx;
28852930

28862931
assert(loopDirective.v == llvm::acc::ACCD_loop &&
@@ -2893,8 +2938,8 @@ genACC(Fortran::lower::AbstractConverter &converter,
28932938
std::get<std::optional<Fortran::parser::DoConstruct>>(loopConstruct.t);
28942939
auto loopOp = createLoopOp(converter, currentLocation, semanticsContext,
28952940
stmtCtx, *outerDoConstruct, eval, accClauseList,
2896-
/*combinedConstructs=*/{}, needEarlyExitHandling);
2897-
if (needEarlyExitHandling)
2941+
/*combinedConstructs=*/{});
2942+
if (loopOp.getNumResults() == 1)
28982943
return loopOp.getResult(0);
28992944

29002945
return mlir::Value{};
@@ -3679,10 +3724,6 @@ genACC(Fortran::lower::AbstractConverter &converter,
36793724
converter.genLocation(beginCombinedDirective.source);
36803725
Fortran::lower::StatementContext stmtCtx;
36813726

3682-
if (eval.lowerAsUnstructured())
3683-
TODO(currentLocation,
3684-
"loop with early exit inside OpenACC combined construct");
3685-
36863727
if (combinedDirective.v == llvm::acc::ACCD_kernels_loop) {
36873728
createComputeOp<mlir::acc::KernelsOp>(
36883729
converter, currentLocation, eval, semanticsContext, stmtCtx,

flang/test/Lower/OpenACC/acc-unstructured.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
2-
! XFAIL: *
32

43
subroutine test_unstructured1(a, b, c)
54
integer :: i, j, k
@@ -55,10 +54,11 @@ subroutine test_unstructured2(a, b, c)
5554

5655
! CHECK-LABEL: func.func @_QPtest_unstructured2
5756
! CHECK: acc.parallel
58-
! CHECK: acc.loop
57+
! CHECK: acc.loop combined(parallel) private(@privatization_ref_i32 -> %{{.*}} : !fir.ref<i32>) {
5958
! CHECK: fir.call @_FortranAStopStatementText
6059
! CHECK: acc.yield
6160
! CHECK: acc.yield
61+
! CHECK: } attributes {independent = [#acc.device_type<none>], unstructured}
6262
! CHECK: acc.yield
6363

6464
end subroutine

0 commit comments

Comments
 (0)