@@ -90,23 +90,29 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
9090}
9191
9292template <typename Op>
93- static void
94- createBodyOfOp (Op &op, Fortran::lower::AbstractConverter &converter,
95- mlir::Location &loc,
96- const Fortran::parser::OmpClauseList *clauses = nullptr ,
97- const Fortran::semantics::Symbol *arg = nullptr ) {
93+ static void createBodyOfOp (
94+ Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
95+ const Fortran::parser::OmpClauseList *clauses = nullptr ,
96+ const SmallVector<const Fortran::semantics::Symbol *> &args = {}) {
9897 auto &firOpBuilder = converter.getFirOpBuilder ();
9998 // If an argument for the region is provided then create the block with that
10099 // argument. Also update the symbol's address with the mlir argument value.
101100 // e.g. For loops the argument is the induction variable. And all further
102101 // uses of the induction variable should use this mlir value.
103- if (arg) {
104- firOpBuilder.createBlock (&op.getRegion (), {}, {converter.genType (*arg)});
105- fir::ExtendedValue exval = op.getRegion ().front ().getArgument (0 );
106- [[maybe_unused]] bool success = converter.bindSymbol (*arg, exval);
107- assert (
108- success &&
109- " Existing binding prevents setting MLIR value for the index variable" );
102+ if (args.size ()) {
103+ SmallVector<Type> tiv;
104+ int argIndex = 0 ;
105+ for (auto &arg : args) {
106+ tiv.push_back (converter.genType (*arg));
107+ }
108+ firOpBuilder.createBlock (&op.getRegion (), {}, tiv);
109+ for (auto &arg : args) {
110+ fir::ExtendedValue exval = op.getRegion ().front ().getArgument (argIndex);
111+ [[maybe_unused]] bool success = converter.bindSymbol (*arg, exval);
112+ assert (success && " Existing binding prevents setting MLIR value for the "
113+ " index variable" );
114+ argIndex++;
115+ }
110116 } else {
111117 firOpBuilder.createBlock (&op.getRegion ());
112118 }
@@ -397,6 +403,18 @@ getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) {
397403 return mlir::omp::ScheduleModifier::none;
398404}
399405
406+ int64_t Fortran::lower::getCollapseValue (
407+ const Fortran::parser::OmpClauseList &clauseList) {
408+ for (const auto &clause : clauseList.v ) {
409+ if (const auto &collapseClause =
410+ std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u )) {
411+ const auto *expr = Fortran::semantics::GetExpr (collapseClause->v );
412+ return Fortran::evaluate::ToInt64 (*expr).value ();
413+ }
414+ }
415+ return 1 ;
416+ }
417+
400418static void genOMP (Fortran::lower::AbstractConverter &converter,
401419 Fortran::lower::pft::Evaluation &eval,
402420 const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
@@ -437,34 +455,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
437455 genObjectList (ompObjectList, converter, lastPrivateClauseOperands);
438456 }
439457 }
440- // FIXME: Can be done in a better way ?
441- auto &doConstructEval =
442- eval.getFirstNestedEvaluation ().getFirstNestedEvaluation ();
443- auto *doStmt = doConstructEval.getIf <Fortran::parser::NonLabelDoStmt>();
444-
445- const auto &loopControl =
446- std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t );
447- const Fortran::parser::LoopControl::Bounds *bounds =
448- std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
449- Fortran::semantics::Symbol *iv = nullptr ;
450- if (bounds) {
451- Fortran::lower::StatementContext stmtCtx;
452- lowerBound.push_back (fir::getBase (converter.genExprValue (
453- *Fortran::semantics::GetExpr (bounds->lower ), stmtCtx)));
454- upperBound.push_back (fir::getBase (converter.genExprValue (
455- *Fortran::semantics::GetExpr (bounds->upper ), stmtCtx)));
456- if (bounds->step ) {
457- step.push_back (fir::getBase (converter.genExprValue (
458- *Fortran::semantics::GetExpr (bounds->step ), stmtCtx)));
459- }
460- // If `step` is not present, assume it as `1`.
461- else {
462- step.push_back (firOpBuilder.createIntegerConstant (
463- currentLocation, firOpBuilder.getIntegerType (32 ), 1 ));
464- }
465- iv = bounds->name .thing .symbol ;
466- }
467458
459+ int64_t collapseValue = Fortran::lower::getCollapseValue (wsLoopOpClauseList);
468460 for (const auto &clause : wsLoopOpClauseList.v ) {
469461 if (const auto &scheduleClause =
470462 std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u )) {
@@ -480,6 +472,41 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
480472 }
481473 }
482474
475+ // Collect the loops to collapse.
476+ auto *doConstructEval = &eval.getFirstNestedEvaluation ();
477+
478+ SmallVector<const Fortran::semantics::Symbol *> iv;
479+ do {
480+ auto *doLoop = &doConstructEval->getFirstNestedEvaluation ();
481+ auto *doStmt = doLoop->getIf <Fortran::parser::NonLabelDoStmt>();
482+ assert (doStmt && " Expected do loop to be in the nested evaluation" );
483+ const auto &loopControl =
484+ std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t );
485+ const Fortran::parser::LoopControl::Bounds *bounds =
486+ std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
487+ if (bounds) {
488+ Fortran::lower::StatementContext stmtCtx;
489+ lowerBound.push_back (fir::getBase (converter.genExprValue (
490+ *Fortran::semantics::GetExpr (bounds->lower ), stmtCtx)));
491+ upperBound.push_back (fir::getBase (converter.genExprValue (
492+ *Fortran::semantics::GetExpr (bounds->upper ), stmtCtx)));
493+ if (bounds->step ) {
494+ step.push_back (fir::getBase (converter.genExprValue (
495+ *Fortran::semantics::GetExpr (bounds->step ), stmtCtx)));
496+ }
497+ // If `step` is not present, assume it as `1`.
498+ else {
499+ step.push_back (firOpBuilder.createIntegerConstant (
500+ currentLocation, firOpBuilder.getIntegerType (32 ), 1 ));
501+ }
502+ iv.push_back (bounds->name .thing .symbol );
503+ }
504+
505+ collapseValue--;
506+ doConstructEval =
507+ &*std::next (doConstructEval->getNestedEvaluations ().begin ());
508+ } while (collapseValue > 0 );
509+
483510 // FIXME: Add support for following clauses:
484511 // 1. linear
485512 // 2. order
0 commit comments