Skip to content

Commit e459d4d

Browse files
[flang][OpenMP] Support for Collapse
Convert Fortran parse-tree into MLIR for collapse-clause. Includes simple Fortran to LLVM-IR test, with auto-generated check-lines (some of which have been edited by hand).
1 parent f8144e6 commit e459d4d

File tree

4 files changed

+331
-43
lines changed

4 files changed

+331
-43
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
#ifndef FORTRAN_LOWER_OPENMP_H
1414
#define FORTRAN_LOWER_OPENMP_H
1515

16+
#include <cinttypes>
17+
1618
namespace Fortran {
1719
namespace parser {
1820
struct OpenMPConstruct;
1921
struct OmpEndLoopDirective;
22+
struct OmpClauseList;
2023
} // namespace parser
2124

2225
namespace lower {
@@ -29,6 +32,9 @@ struct Evaluation;
2932

3033
void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &,
3134
const parser::OpenMPConstruct &);
35+
36+
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
37+
3238
} // namespace lower
3339
} // namespace Fortran
3440

flang/lib/Lower/Bridge.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,13 +1283,28 @@ class FirConverter : public Fortran::lower::AbstractConverter {
12831283
auto insertPt = builder->saveInsertionPoint();
12841284
localSymbols.pushScope();
12851285
genOpenMPConstruct(*this, getEval(), omp);
1286+
1287+
auto ompLoop = std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);
1288+
12861289
// If loop is part of an OpenMP Construct then the OpenMP dialect
12871290
// workshare loop operation has already been created. Only the
12881291
// body needs to be created here and the do_loop can be skipped.
1289-
Fortran::lower::pft::Evaluation *curEval =
1290-
std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)
1291-
? &getEval().getFirstNestedEvaluation()
1292-
: &getEval();
1292+
// Skip the number of collapsed loops, which is 1 when there is a
1293+
// no collapse requested.
1294+
1295+
Fortran::lower::pft::Evaluation *curEval = &getEval();
1296+
if (ompLoop) {
1297+
const auto &wsLoopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
1298+
std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t);
1299+
int64_t collapseValue =
1300+
Fortran::lower::getCollapseValue(wsLoopOpClauseList);
1301+
1302+
curEval = &curEval->getFirstNestedEvaluation();
1303+
for (auto i = 1; i < collapseValue; i++) {
1304+
curEval = &*std::next(curEval->getNestedEvaluations().begin());
1305+
}
1306+
}
1307+
12931308
for (auto &e : curEval->getNestedEvaluations())
12941309
genFIR(e);
12951310
localSymbols.popScope();

flang/lib/Lower/OpenMP.cpp

Lines changed: 66 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,29 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
9090
}
9191

9292
template <typename Op>
93-
static void
94-
createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
95-
mlir::Location &loc,
96-
const Fortran::parser::OmpClauseList *clauses = nullptr,
97-
const Fortran::semantics::Symbol *arg = nullptr) {
93+
static void createBodyOfOp(
94+
Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
95+
const Fortran::parser::OmpClauseList *clauses = nullptr,
96+
const SmallVector<const Fortran::semantics::Symbol *> &args = {}) {
9897
auto &firOpBuilder = converter.getFirOpBuilder();
9998
// If an argument for the region is provided then create the block with that
10099
// argument. Also update the symbol's address with the mlir argument value.
101100
// e.g. For loops the argument is the induction variable. And all further
102101
// uses of the induction variable should use this mlir value.
103-
if (arg) {
104-
firOpBuilder.createBlock(&op.getRegion(), {}, {converter.genType(*arg)});
105-
fir::ExtendedValue exval = op.getRegion().front().getArgument(0);
106-
[[maybe_unused]] bool success = converter.bindSymbol(*arg, exval);
107-
assert(
108-
success &&
109-
"Existing binding prevents setting MLIR value for the index variable");
102+
if (args.size()) {
103+
SmallVector<Type> tiv;
104+
int argIndex = 0;
105+
for (auto &arg : args) {
106+
tiv.push_back(converter.genType(*arg));
107+
}
108+
firOpBuilder.createBlock(&op.getRegion(), {}, tiv);
109+
for (auto &arg : args) {
110+
fir::ExtendedValue exval = op.getRegion().front().getArgument(argIndex);
111+
[[maybe_unused]] bool success = converter.bindSymbol(*arg, exval);
112+
assert(success && "Existing binding prevents setting MLIR value for the "
113+
"index variable");
114+
argIndex++;
115+
}
110116
} else {
111117
firOpBuilder.createBlock(&op.getRegion());
112118
}
@@ -397,6 +403,18 @@ getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) {
397403
return mlir::omp::ScheduleModifier::none;
398404
}
399405

406+
int64_t Fortran::lower::getCollapseValue(
407+
const Fortran::parser::OmpClauseList &clauseList) {
408+
for (const auto &clause : clauseList.v) {
409+
if (const auto &collapseClause =
410+
std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
411+
const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
412+
return Fortran::evaluate::ToInt64(*expr).value();
413+
}
414+
}
415+
return 1;
416+
}
417+
400418
static void genOMP(Fortran::lower::AbstractConverter &converter,
401419
Fortran::lower::pft::Evaluation &eval,
402420
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
@@ -437,34 +455,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
437455
genObjectList(ompObjectList, converter, lastPrivateClauseOperands);
438456
}
439457
}
440-
// FIXME: Can be done in a better way ?
441-
auto &doConstructEval =
442-
eval.getFirstNestedEvaluation().getFirstNestedEvaluation();
443-
auto *doStmt = doConstructEval.getIf<Fortran::parser::NonLabelDoStmt>();
444-
445-
const auto &loopControl =
446-
std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
447-
const Fortran::parser::LoopControl::Bounds *bounds =
448-
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
449-
Fortran::semantics::Symbol *iv = nullptr;
450-
if (bounds) {
451-
Fortran::lower::StatementContext stmtCtx;
452-
lowerBound.push_back(fir::getBase(converter.genExprValue(
453-
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
454-
upperBound.push_back(fir::getBase(converter.genExprValue(
455-
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
456-
if (bounds->step) {
457-
step.push_back(fir::getBase(converter.genExprValue(
458-
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
459-
}
460-
// If `step` is not present, assume it as `1`.
461-
else {
462-
step.push_back(firOpBuilder.createIntegerConstant(
463-
currentLocation, firOpBuilder.getIntegerType(32), 1));
464-
}
465-
iv = bounds->name.thing.symbol;
466-
}
467458

459+
int64_t collapseValue = Fortran::lower::getCollapseValue(wsLoopOpClauseList);
468460
for (const auto &clause : wsLoopOpClauseList.v) {
469461
if (const auto &scheduleClause =
470462
std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u)) {
@@ -480,6 +472,41 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
480472
}
481473
}
482474

475+
// Collect the loops to collapse.
476+
auto *doConstructEval = &eval.getFirstNestedEvaluation();
477+
478+
SmallVector<const Fortran::semantics::Symbol *> iv;
479+
do {
480+
auto *doLoop = &doConstructEval->getFirstNestedEvaluation();
481+
auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
482+
assert(doStmt && "Expected do loop to be in the nested evaluation");
483+
const auto &loopControl =
484+
std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
485+
const Fortran::parser::LoopControl::Bounds *bounds =
486+
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
487+
if (bounds) {
488+
Fortran::lower::StatementContext stmtCtx;
489+
lowerBound.push_back(fir::getBase(converter.genExprValue(
490+
*Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
491+
upperBound.push_back(fir::getBase(converter.genExprValue(
492+
*Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
493+
if (bounds->step) {
494+
step.push_back(fir::getBase(converter.genExprValue(
495+
*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
496+
}
497+
// If `step` is not present, assume it as `1`.
498+
else {
499+
step.push_back(firOpBuilder.createIntegerConstant(
500+
currentLocation, firOpBuilder.getIntegerType(32), 1));
501+
}
502+
iv.push_back(bounds->name.thing.symbol);
503+
}
504+
505+
collapseValue--;
506+
doConstructEval =
507+
&*std::next(doConstructEval->getNestedEvaluations().begin());
508+
} while (collapseValue > 0);
509+
483510
// FIXME: Add support for following clauses:
484511
// 1. linear
485512
// 2. order

0 commit comments

Comments
 (0)