@@ -90,23 +90,29 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
90
90
}
91
91
92
92
template <typename Op>
93
- static void
94
- createBodyOfOp (Op &op, Fortran::lower::AbstractConverter &converter,
95
- mlir::Location &loc,
96
- const Fortran::parser::OmpClauseList *clauses = nullptr ,
97
- const Fortran::semantics::Symbol *arg = nullptr ) {
93
+ static void createBodyOfOp (
94
+ Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
95
+ const Fortran::parser::OmpClauseList *clauses = nullptr ,
96
+ const SmallVector<const Fortran::semantics::Symbol *> &args = {}) {
98
97
auto &firOpBuilder = converter.getFirOpBuilder ();
99
98
// If an argument for the region is provided then create the block with that
100
99
// argument. Also update the symbol's address with the mlir argument value.
101
100
// e.g. For loops the argument is the induction variable. And all further
102
101
// uses of the induction variable should use this mlir value.
103
- if (arg) {
104
- firOpBuilder.createBlock (&op.getRegion (), {}, {converter.genType (*arg)});
105
- fir::ExtendedValue exval = op.getRegion ().front ().getArgument (0 );
106
- [[maybe_unused]] bool success = converter.bindSymbol (*arg, exval);
107
- assert (
108
- success &&
109
- " Existing binding prevents setting MLIR value for the index variable" );
102
+ if (args.size ()) {
103
+ SmallVector<Type> tiv;
104
+ int argIndex = 0 ;
105
+ for (auto &arg : args) {
106
+ tiv.push_back (converter.genType (*arg));
107
+ }
108
+ firOpBuilder.createBlock (&op.getRegion (), {}, tiv);
109
+ for (auto &arg : args) {
110
+ fir::ExtendedValue exval = op.getRegion ().front ().getArgument (argIndex);
111
+ [[maybe_unused]] bool success = converter.bindSymbol (*arg, exval);
112
+ assert (success && " Existing binding prevents setting MLIR value for the "
113
+ " index variable" );
114
+ argIndex++;
115
+ }
110
116
} else {
111
117
firOpBuilder.createBlock (&op.getRegion ());
112
118
}
@@ -397,6 +403,18 @@ getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) {
397
403
return mlir::omp::ScheduleModifier::none;
398
404
}
399
405
406
+ int64_t Fortran::lower::getCollapseValue (
407
+ const Fortran::parser::OmpClauseList &clauseList) {
408
+ for (const auto &clause : clauseList.v ) {
409
+ if (const auto &collapseClause =
410
+ std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u )) {
411
+ const auto *expr = Fortran::semantics::GetExpr (collapseClause->v );
412
+ return Fortran::evaluate::ToInt64 (*expr).value ();
413
+ }
414
+ }
415
+ return 1 ;
416
+ }
417
+
400
418
static void genOMP (Fortran::lower::AbstractConverter &converter,
401
419
Fortran::lower::pft::Evaluation &eval,
402
420
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
@@ -437,34 +455,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
437
455
genObjectList (ompObjectList, converter, lastPrivateClauseOperands);
438
456
}
439
457
}
440
- // FIXME: Can be done in a better way ?
441
- auto &doConstructEval =
442
- eval.getFirstNestedEvaluation ().getFirstNestedEvaluation ();
443
- auto *doStmt = doConstructEval.getIf <Fortran::parser::NonLabelDoStmt>();
444
-
445
- const auto &loopControl =
446
- std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t );
447
- const Fortran::parser::LoopControl::Bounds *bounds =
448
- std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
449
- Fortran::semantics::Symbol *iv = nullptr ;
450
- if (bounds) {
451
- Fortran::lower::StatementContext stmtCtx;
452
- lowerBound.push_back (fir::getBase (converter.genExprValue (
453
- *Fortran::semantics::GetExpr (bounds->lower ), stmtCtx)));
454
- upperBound.push_back (fir::getBase (converter.genExprValue (
455
- *Fortran::semantics::GetExpr (bounds->upper ), stmtCtx)));
456
- if (bounds->step ) {
457
- step.push_back (fir::getBase (converter.genExprValue (
458
- *Fortran::semantics::GetExpr (bounds->step ), stmtCtx)));
459
- }
460
- // If `step` is not present, assume it as `1`.
461
- else {
462
- step.push_back (firOpBuilder.createIntegerConstant (
463
- currentLocation, firOpBuilder.getIntegerType (32 ), 1 ));
464
- }
465
- iv = bounds->name .thing .symbol ;
466
- }
467
458
459
+ int64_t collapseValue = Fortran::lower::getCollapseValue (wsLoopOpClauseList);
468
460
for (const auto &clause : wsLoopOpClauseList.v ) {
469
461
if (const auto &scheduleClause =
470
462
std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u )) {
@@ -480,6 +472,41 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
480
472
}
481
473
}
482
474
475
+ // Collect the loops to collapse.
476
+ auto *doConstructEval = &eval.getFirstNestedEvaluation ();
477
+
478
+ SmallVector<const Fortran::semantics::Symbol *> iv;
479
+ do {
480
+ auto *doLoop = &doConstructEval->getFirstNestedEvaluation ();
481
+ auto *doStmt = doLoop->getIf <Fortran::parser::NonLabelDoStmt>();
482
+ assert (doStmt && " Expected do loop to be in the nested evaluation" );
483
+ const auto &loopControl =
484
+ std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t );
485
+ const Fortran::parser::LoopControl::Bounds *bounds =
486
+ std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u );
487
+ if (bounds) {
488
+ Fortran::lower::StatementContext stmtCtx;
489
+ lowerBound.push_back (fir::getBase (converter.genExprValue (
490
+ *Fortran::semantics::GetExpr (bounds->lower ), stmtCtx)));
491
+ upperBound.push_back (fir::getBase (converter.genExprValue (
492
+ *Fortran::semantics::GetExpr (bounds->upper ), stmtCtx)));
493
+ if (bounds->step ) {
494
+ step.push_back (fir::getBase (converter.genExprValue (
495
+ *Fortran::semantics::GetExpr (bounds->step ), stmtCtx)));
496
+ }
497
+ // If `step` is not present, assume it as `1`.
498
+ else {
499
+ step.push_back (firOpBuilder.createIntegerConstant (
500
+ currentLocation, firOpBuilder.getIntegerType (32 ), 1 ));
501
+ }
502
+ iv.push_back (bounds->name .thing .symbol );
503
+ }
504
+
505
+ collapseValue--;
506
+ doConstructEval =
507
+ &*std::next (doConstructEval->getNestedEvaluations ().begin ());
508
+ } while (collapseValue > 0 );
509
+
483
510
// FIXME: Add support for following clauses:
484
511
// 1. linear
485
512
// 2. order
0 commit comments