Skip to content

Commit 5939e18

Browse files
committed
Avoid attaching the sizes clause to the parent construct, instead find the tile
sizes through the parse tree when getting the information needed to create the loop nest ops.
1 parent 2923141 commit 5939e18

File tree

4 files changed

+99
-23
lines changed

4 files changed

+99
-23
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
using namespace Fortran::lower::omp;
4949
using namespace Fortran::common::openmp;
50+
using namespace Fortran::semantics;
5051

5152
//===----------------------------------------------------------------------===//
5253
// Code generation helper functions
@@ -1690,6 +1691,7 @@ genLoopNestClauses(lower::AbstractConverter &converter,
16901691
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
16911692
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
16921693
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
1694+
// This case handles the stand-alone tiling construct
16931695
const auto &sizes = std::get<clause::Sizes>(clause.u);
16941696
llvm::SmallVector<int64_t> sizeValues;
16951697
for (auto &size : sizes.v) {
@@ -1699,6 +1701,12 @@ genLoopNestClauses(lower::AbstractConverter &converter,
16991701
clauseOps.tileSizes = sizeValues;
17001702
}
17011703
}
1704+
1705+
llvm::SmallVector<int64_t> sizeValues;
1706+
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
1707+
collectTileSizesFromOpenMPConstruct (ompCons, sizeValues, semaCtx);
1708+
if (sizeValues.size() > 0)
1709+
clauseOps.tileSizes = sizeValues;
17021710
}
17031711

17041712
static void genLoopClauses(
@@ -3961,21 +3969,6 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
39613969
List<Clause> clauses = makeClauses(
39623970
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
39633971

3964-
const auto &innerOptional =
3965-
std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
3966-
loopConstruct.t);
3967-
if (innerOptional.has_value()) {
3968-
const auto &innerLoopDirective = innerOptional.value().value();
3969-
const auto &innerBegin =
3970-
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
3971-
const auto &innerDirective =
3972-
std::get<parser::OmpLoopDirective>(innerBegin.t);
3973-
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
3974-
clauses.append(
3975-
makeClauses(std::get<parser::OmpClauseList>(innerBegin.t), semaCtx));
3976-
}
3977-
}
3978-
39793972
if (auto &endLoopDirective =
39803973
std::get<std::optional<parser::OmpEndLoopDirective>>(
39813974
loopConstruct.t)) {

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "ClauseFinder.h"
1616
#include "flang/Lower/OpenMP/Clauses.h"
17+
#include "flang/Evaluate/fold.h"
1718
#include <flang/Lower/AbstractConverter.h>
1819
#include <flang/Lower/ConvertType.h>
1920
#include <flang/Lower/DirectivesCommon.h>
@@ -24,10 +25,31 @@
2425
#include <flang/Parser/parse-tree.h>
2526
#include <flang/Parser/tools.h>
2627
#include <flang/Semantics/tools.h>
28+
#include <flang/Semantics/type.h>
2729
#include <llvm/Support/CommandLine.h>
2830

2931
#include <iterator>
3032

33+
using namespace Fortran::semantics;
34+
35+
template <typename T>
36+
MaybeIntExpr
37+
EvaluateIntExpr(SemanticsContext &context, const T &expr) {
38+
if (MaybeExpr maybeExpr{
39+
Fold(context.foldingContext(), AnalyzeExpr(context, expr))}) {
40+
if (auto *intExpr{Fortran::evaluate::UnwrapExpr<SomeIntExpr>(*maybeExpr)}) {
41+
return std::move(*intExpr);
42+
}
43+
}
44+
return std::nullopt;
45+
}
46+
47+
template <typename T>
48+
std::optional<std::int64_t>
49+
EvaluateInt64(SemanticsContext &context, const T &expr) {
50+
return Fortran::evaluate::ToInt64(EvaluateIntExpr(context, expr));
51+
}
52+
3153
llvm::cl::opt<bool> treatIndexAsSection(
3254
"openmp-treat-index-as-section",
3355
llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
@@ -615,6 +637,43 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
615637
}
616638
}
617639

640+
// Populates the sizes vector with values if the given OpenMPConstruct
641+
// Contains a loop construct with an inner tiling construct.
642+
void collectTileSizesFromOpenMPConstruct(
643+
const parser::OpenMPConstruct *ompCons,
644+
llvm::SmallVectorImpl<int64_t> &tileSizes,
645+
SemanticsContext &semaCtx) {
646+
if (!ompCons)
647+
return;
648+
649+
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
650+
const auto &innerOptional = std::get<
651+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
652+
ompLoop->t);
653+
if (innerOptional.has_value()) {
654+
const auto &innerLoopDirective = innerOptional.value().value();
655+
const auto &innerBegin =
656+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
657+
const auto &innerDirective =
658+
std::get<parser::OmpLoopDirective>(innerBegin.t).v;
659+
660+
if (innerDirective == llvm::omp::Directive::OMPD_tile) {
661+
// Get the size values from parse tree and convert to a vector
662+
const auto &innerClauseList{
663+
std::get<parser::OmpClauseList>(innerBegin.t)};
664+
for (const auto &clause : innerClauseList.v)
665+
if (const auto tclause{
666+
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
667+
for (auto &tval : tclause->v) {
668+
if (const auto v{EvaluateInt64(semaCtx, tval)})
669+
tileSizes.push_back(*v);
670+
}
671+
}
672+
}
673+
}
674+
}
675+
}
676+
618677
bool collectLoopRelatedInfo(
619678
lower::AbstractConverter &converter, mlir::Location currentLocation,
620679
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
@@ -636,11 +695,34 @@ bool collectLoopRelatedInfo(
636695
collapseValue = evaluate::ToInt64(clause->v).value();
637696
found = true;
638697
}
698+
699+
// Collect sizes from tile directive if present
639700
std::int64_t sizesLengthValue = 0l;
640-
if (auto *clause =
641-
ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
642-
sizesLengthValue = clause->v.size();
643-
found = true;
701+
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
702+
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
703+
const auto &innerOptional = std::get<
704+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
705+
ompLoop->t);
706+
if (innerOptional.has_value()) {
707+
const auto &innerLoopDirective = innerOptional.value().value();
708+
const auto &innerBegin =
709+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
710+
const auto &innerDirective =
711+
std::get<parser::OmpLoopDirective>(innerBegin.t).v;
712+
713+
if (innerDirective == llvm::omp::Directive::OMPD_tile) {
714+
// Get the size values from parse tree and convert to a vector
715+
const auto &innerClauseList{
716+
std::get<parser::OmpClauseList>(innerBegin.t)};
717+
for (const auto &clause : innerClauseList.v)
718+
if (const auto tclause{
719+
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
720+
sizesLengthValue = tclause->v.size();
721+
found = true;
722+
}
723+
}
724+
}
725+
}
644726
}
645727

646728
collapseValue = collapseValue - sizesLengthValue;

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,11 @@ bool collectLoopRelatedInfo(
175175
mlir::omp::LoopRelatedClauseOps &result,
176176
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
177177

178+
void collectTileSizesFromOpenMPConstruct(
179+
const parser::OpenMPConstruct *ompCons,
180+
llvm::SmallVectorImpl<int64_t> &tileSizes,
181+
Fortran::semantics::SemanticsContext &semaCtx);
182+
178183
} // namespace omp
179184
} // namespace lower
180185
} // namespace Fortran

llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,10 +497,6 @@ bool ConstructDecompositionT<C, H>::applyClause(
497497
if (llvm::omp::isAllowedClauseForDirective(last.id, node->id, version)) {
498498
last.clauses.push_back(node);
499499
return true;
500-
} else {
501-
// llvm::errs() << "** OVERRIDING isAllowedClauseForDirective **\n";
502-
last.clauses.push_back(node);
503-
return true;
504500
}
505501
}
506502

0 commit comments

Comments
 (0)