@@ -403,6 +403,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
403
403
return ;
404
404
405
405
const parser::OmpClauseList *beginClauseList = nullptr ;
406
+ const parser::OmpClauseList *middleClauseList = nullptr ;
406
407
const parser::OmpClauseList *endClauseList = nullptr ;
407
408
common::visit (
408
409
common::visitors{
@@ -417,6 +418,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
417
418
beginClauseList =
418
419
&std::get<parser::OmpClauseList>(beginDirective.t );
419
420
421
+ // FIXME(JAN): For now we check if there is an inner
422
+ // OpenMPLoopConstruct, and extract the size clause from there
423
+ const auto &innerOptional = std::get<std::optional<
424
+ common::Indirection<parser::OpenMPLoopConstruct>>>(
425
+ ompConstruct.t );
426
+ if (innerOptional.has_value ()) {
427
+ const auto &innerLoopDirective = innerOptional.value ().value ();
428
+ const auto &innerBegin =
429
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
430
+ const auto &innerDirective =
431
+ std::get<parser::OmpLoopDirective>(innerBegin.t );
432
+ if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
433
+ middleClauseList =
434
+ &std::get<parser::OmpClauseList>(innerBegin.t );
435
+ }
436
+ }
420
437
if (auto &endDirective =
421
438
std::get<std::optional<parser::OmpEndLoopDirective>>(
422
439
ompConstruct.t )) {
@@ -430,6 +447,9 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
430
447
assert (beginClauseList && " expected begin directive" );
431
448
clauses.append (makeClauses (*beginClauseList, semaCtx));
432
449
450
+ if (middleClauseList)
451
+ clauses.append (makeClauses (*middleClauseList, semaCtx));
452
+
433
453
if (endClauseList)
434
454
clauses.append (makeClauses (*endClauseList, semaCtx));
435
455
};
@@ -910,6 +930,7 @@ static void genLoopVars(
910
930
storeOp =
911
931
createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
912
932
}
933
+
913
934
firOpBuilder.setInsertionPointAfter (storeOp);
914
935
}
915
936
@@ -1660,6 +1681,23 @@ genLoopNestClauses(lower::AbstractConverter &converter,
1660
1681
cp.processCollapse (loc, eval, clauseOps, iv);
1661
1682
1662
1683
clauseOps.loopInclusive = converter.getFirOpBuilder ().getUnitAttr ();
1684
+
1685
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
1686
+ for (auto &clause : clauses) {
1687
+ if (clause.id == llvm::omp::Clause::OMPC_collapse) {
1688
+ const auto &collapse = std::get<clause::Collapse>(clause.u );
1689
+ int64_t collapseValue = evaluate::ToInt64 (collapse.v ).value ();
1690
+ clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr (collapseValue);
1691
+ } else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
1692
+ const auto &sizes = std::get<clause::Sizes>(clause.u );
1693
+ llvm::SmallVector<int64_t > sizeValues;
1694
+ for (auto &size : sizes.v ) {
1695
+ int64_t sizeValue = evaluate::ToInt64 (size).value ();
1696
+ sizeValues.push_back (sizeValue);
1697
+ }
1698
+ clauseOps.tileSizes = sizeValues;
1699
+ }
1700
+ }
1663
1701
}
1664
1702
1665
1703
static void genLoopClauses (
@@ -2036,9 +2074,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
2036
2074
return llvm::SmallVector<const semantics::Symbol *>(iv);
2037
2075
};
2038
2076
2039
- auto *nestedEval =
2040
- getCollapsedLoopEval (eval, getCollapseValue (item-> clauses )) ;
2041
-
2077
+ uint64_t nestValue = getCollapseValue (item-> clauses );
2078
+ nestValue = nestValue < iv. size () ? iv. size () : nestValue ;
2079
+ auto *nestedEval = getCollapsedLoopEval (eval, nestValue);
2042
2080
return genOpWithBody<mlir::omp::LoopNestOp>(
2043
2081
OpWithBodyGenInfo (converter, symTable, semaCtx, loc, *nestedEval,
2044
2082
directive)
@@ -3863,6 +3901,20 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
3863
3901
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t );
3864
3902
List<Clause> clauses = makeClauses (
3865
3903
std::get<parser::OmpClauseList>(beginLoopDirective.t ), semaCtx);
3904
+
3905
+ const auto &innerOptional = std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(loopConstruct.t );
3906
+ if (innerOptional.has_value ()) {
3907
+ const auto &innerLoopDirective = innerOptional.value ().value ();
3908
+ const auto &innerBegin =
3909
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
3910
+ const auto &innerDirective =
3911
+ std::get<parser::OmpLoopDirective>(innerBegin.t );
3912
+ if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
3913
+ clauses.append (
3914
+ makeClauses (std::get<parser::OmpClauseList>(innerBegin.t ), semaCtx));
3915
+ }
3916
+ }
3917
+
3866
3918
if (auto &endLoopDirective =
3867
3919
std::get<std::optional<parser::OmpEndLoopDirective>>(
3868
3920
loopConstruct.t )) {
@@ -3994,18 +4046,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
3994
4046
lower::genDeclareTargetIntGlobal (converter, var);
3995
4047
}
3996
4048
3997
- int64_t
3998
- Fortran::lower::getCollapseValue (const parser::OmpClauseList &clauseList) {
3999
- for (const parser::OmpClause &clause : clauseList.v ) {
4000
- if (const auto &collapseClause =
4001
- std::get_if<parser::OmpClause::Collapse>(&clause.u )) {
4002
- const auto *expr = semantics::GetExpr (collapseClause->v );
4003
- return evaluate::ToInt64 (*expr).value ();
4004
- }
4005
- }
4006
- return 1 ;
4007
- }
4008
-
4009
4049
void Fortran::lower::genThreadprivateOp (lower::AbstractConverter &converter,
4010
4050
const lower::pft::Variable &var) {
4011
4051
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
0 commit comments