@@ -404,6 +404,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
404
404
return ;
405
405
406
406
const parser::OmpClauseList *beginClauseList = nullptr ;
407
+ const parser::OmpClauseList *middleClauseList = nullptr ;
407
408
const parser::OmpClauseList *endClauseList = nullptr ;
408
409
common::visit (
409
410
common::visitors{
@@ -418,6 +419,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
418
419
beginClauseList =
419
420
&std::get<parser::OmpClauseList>(beginDirective.t );
420
421
422
+ // FIXME(JAN): For now we check if there is an inner
423
+ // OpenMPLoopConstruct, and extract the size clause from there
424
+ const auto &innerOptional = std::get<std::optional<
425
+ common::Indirection<parser::OpenMPLoopConstruct>>>(
426
+ ompConstruct.t );
427
+ if (innerOptional.has_value ()) {
428
+ const auto &innerLoopDirective = innerOptional.value ().value ();
429
+ const auto &innerBegin =
430
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
431
+ const auto &innerDirective =
432
+ std::get<parser::OmpLoopDirective>(innerBegin.t );
433
+ if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
434
+ middleClauseList =
435
+ &std::get<parser::OmpClauseList>(innerBegin.t );
436
+ }
437
+ }
421
438
if (auto &endDirective =
422
439
std::get<std::optional<parser::OmpEndLoopDirective>>(
423
440
ompConstruct.t )) {
@@ -431,6 +448,9 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
431
448
assert (beginClauseList && " expected begin directive" );
432
449
clauses.append (makeClauses (*beginClauseList, semaCtx));
433
450
451
+ if (middleClauseList)
452
+ clauses.append (makeClauses (*middleClauseList, semaCtx));
453
+
434
454
if (endClauseList)
435
455
clauses.append (makeClauses (*endClauseList, semaCtx));
436
456
};
@@ -910,6 +930,7 @@ static void genLoopVars(
910
930
storeOp =
911
931
createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
912
932
}
933
+
913
934
firOpBuilder.setInsertionPointAfter (storeOp);
914
935
}
915
936
@@ -1660,6 +1681,23 @@ genLoopNestClauses(lower::AbstractConverter &converter,
1660
1681
cp.processCollapse (loc, eval, clauseOps, iv);
1661
1682
1662
1683
clauseOps.loopInclusive = converter.getFirOpBuilder ().getUnitAttr ();
1684
+
1685
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
1686
+ for (auto &clause : clauses) {
1687
+ if (clause.id == llvm::omp::Clause::OMPC_collapse) {
1688
+ const auto &collapse = std::get<clause::Collapse>(clause.u );
1689
+ int64_t collapseValue = evaluate::ToInt64 (collapse.v ).value ();
1690
+ clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr (collapseValue);
1691
+ } else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
1692
+ const auto &sizes = std::get<clause::Sizes>(clause.u );
1693
+ llvm::SmallVector<int64_t > sizeValues;
1694
+ for (auto &size : sizes.v ) {
1695
+ int64_t sizeValue = evaluate::ToInt64 (size).value ();
1696
+ sizeValues.push_back (sizeValue);
1697
+ }
1698
+ clauseOps.tileSizes = sizeValues;
1699
+ }
1700
+ }
1663
1701
}
1664
1702
1665
1703
static void genLoopClauses (
@@ -2036,9 +2074,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
2036
2074
return llvm::SmallVector<const semantics::Symbol *>(iv);
2037
2075
};
2038
2076
2039
- auto *nestedEval =
2040
- getCollapsedLoopEval (eval, getCollapseValue (item-> clauses )) ;
2041
-
2077
+ uint64_t nestValue = getCollapseValue (item-> clauses );
2078
+ nestValue = nestValue < iv. size () ? iv. size () : nestValue ;
2079
+ auto *nestedEval = getCollapsedLoopEval (eval, nestValue);
2042
2080
return genOpWithBody<mlir::omp::LoopNestOp>(
2043
2081
OpWithBodyGenInfo (converter, symTable, semaCtx, loc, *nestedEval,
2044
2082
directive)
@@ -3890,6 +3928,20 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
3890
3928
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t );
3891
3929
List<Clause> clauses = makeClauses (
3892
3930
std::get<parser::OmpClauseList>(beginLoopDirective.t ), semaCtx);
3931
+
3932
+ const auto &innerOptional = std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(loopConstruct.t );
3933
+ if (innerOptional.has_value ()) {
3934
+ const auto &innerLoopDirective = innerOptional.value ().value ();
3935
+ const auto &innerBegin =
3936
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
3937
+ const auto &innerDirective =
3938
+ std::get<parser::OmpLoopDirective>(innerBegin.t );
3939
+ if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
3940
+ clauses.append (
3941
+ makeClauses (std::get<parser::OmpClauseList>(innerBegin.t ), semaCtx));
3942
+ }
3943
+ }
3944
+
3893
3945
if (auto &endLoopDirective =
3894
3946
std::get<std::optional<parser::OmpEndLoopDirective>>(
3895
3947
loopConstruct.t )) {
@@ -4021,18 +4073,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
4021
4073
lower::genDeclareTargetIntGlobal (converter, var);
4022
4074
}
4023
4075
4024
- int64_t
4025
- Fortran::lower::getCollapseValue (const parser::OmpClauseList &clauseList) {
4026
- for (const parser::OmpClause &clause : clauseList.v ) {
4027
- if (const auto &collapseClause =
4028
- std::get_if<parser::OmpClause::Collapse>(&clause.u )) {
4029
- const auto *expr = semantics::GetExpr (collapseClause->v );
4030
- return evaluate::ToInt64 (*expr).value ();
4031
- }
4032
- }
4033
- return 1 ;
4034
- }
4035
-
4036
4076
void Fortran::lower::genThreadprivateOp (lower::AbstractConverter &converter,
4037
4077
const lower::pft::Variable &var) {
4038
4078
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
0 commit comments