Skip to content

Commit 3029793

Browse files
committed
Avoid attaching the sizes clause to the parent construct, instead find the tile
sizes through the parse tree when getting the information needed to create the loop nest ops.
1 parent 5a3d8d2 commit 3029793

File tree

4 files changed

+100
-23
lines changed

4 files changed

+100
-23
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
using namespace Fortran::lower::omp;
4747
using namespace Fortran::common::openmp;
48+
using namespace Fortran::semantics;
4849

4950
static llvm::cl::opt<bool> DumpAtomicAnalysis("fdebug-dump-atomic-analysis");
5051

@@ -1742,6 +1743,7 @@ genLoopNestClauses(lower::AbstractConverter &converter,
17421743
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
17431744
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
17441745
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
1746+
// This case handles the stand-alone tiling construct
17451747
const auto &sizes = std::get<clause::Sizes>(clause.u);
17461748
llvm::SmallVector<int64_t> sizeValues;
17471749
for (auto &size : sizes.v) {
@@ -1751,6 +1753,12 @@ genLoopNestClauses(lower::AbstractConverter &converter,
17511753
clauseOps.tileSizes = sizeValues;
17521754
}
17531755
}
1756+
1757+
llvm::SmallVector<int64_t> sizeValues;
1758+
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
1759+
collectTileSizesFromOpenMPConstruct (ompCons, sizeValues, semaCtx);
1760+
if (sizeValues.size() > 0)
1761+
clauseOps.tileSizes = sizeValues;
17541762
}
17551763

17561764
static void genLoopClauses(
@@ -4228,21 +4236,6 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
42284236
List<Clause> clauses = makeClauses(
42294237
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
42304238

4231-
const auto &innerOptional =
4232-
std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
4233-
loopConstruct.t);
4234-
if (innerOptional.has_value()) {
4235-
const auto &innerLoopDirective = innerOptional.value().value();
4236-
const auto &innerBegin =
4237-
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
4238-
const auto &innerDirective =
4239-
std::get<parser::OmpLoopDirective>(innerBegin.t);
4240-
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
4241-
clauses.append(
4242-
makeClauses(std::get<parser::OmpClauseList>(innerBegin.t), semaCtx));
4243-
}
4244-
}
4245-
42464239
if (auto &endLoopDirective =
42474240
std::get<std::optional<parser::OmpEndLoopDirective>>(
42484241
loopConstruct.t)) {

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "Clauses.h"
1616

1717
#include "ClauseFinder.h"
18+
#include "flang/Evaluate/fold.h"
1819
#include <flang/Lower/AbstractConverter.h>
1920
#include <flang/Lower/ConvertType.h>
2021
#include <flang/Lower/DirectivesCommon.h>
@@ -24,10 +25,31 @@
2425
#include <flang/Parser/parse-tree.h>
2526
#include <flang/Parser/tools.h>
2627
#include <flang/Semantics/tools.h>
28+
#include <flang/Semantics/type.h>
2729
#include <llvm/Support/CommandLine.h>
2830

2931
#include <iterator>
3032

33+
using namespace Fortran::semantics;
34+
35+
template <typename T>
36+
MaybeIntExpr
37+
EvaluateIntExpr(SemanticsContext &context, const T &expr) {
38+
if (MaybeExpr maybeExpr{
39+
Fold(context.foldingContext(), AnalyzeExpr(context, expr))}) {
40+
if (auto *intExpr{Fortran::evaluate::UnwrapExpr<SomeIntExpr>(*maybeExpr)}) {
41+
return std::move(*intExpr);
42+
}
43+
}
44+
return std::nullopt;
45+
}
46+
47+
template <typename T>
48+
std::optional<std::int64_t>
49+
EvaluateInt64(SemanticsContext &context, const T &expr) {
50+
return Fortran::evaluate::ToInt64(EvaluateIntExpr(context, expr));
51+
}
52+
3153
llvm::cl::opt<bool> treatIndexAsSection(
3254
"openmp-treat-index-as-section",
3355
llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
@@ -613,6 +635,43 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
613635
}
614636
}
615637

638+
// Populates the sizes vector with values if the given OpenMPConstruct
639+
// Contains a loop construct with an inner tiling construct.
640+
void collectTileSizesFromOpenMPConstruct(
641+
const parser::OpenMPConstruct *ompCons,
642+
llvm::SmallVectorImpl<int64_t> &tileSizes,
643+
SemanticsContext &semaCtx) {
644+
if (!ompCons)
645+
return;
646+
647+
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
648+
const auto &innerOptional = std::get<
649+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
650+
ompLoop->t);
651+
if (innerOptional.has_value()) {
652+
const auto &innerLoopDirective = innerOptional.value().value();
653+
const auto &innerBegin =
654+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
655+
const auto &innerDirective =
656+
std::get<parser::OmpLoopDirective>(innerBegin.t).v;
657+
658+
if (innerDirective == llvm::omp::Directive::OMPD_tile) {
659+
// Get the size values from parse tree and convert to a vector
660+
const auto &innerClauseList{
661+
std::get<parser::OmpClauseList>(innerBegin.t)};
662+
for (const auto &clause : innerClauseList.v)
663+
if (const auto tclause{
664+
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
665+
for (auto &tval : tclause->v) {
666+
if (const auto v{EvaluateInt64(semaCtx, tval)})
667+
tileSizes.push_back(*v);
668+
}
669+
}
670+
}
671+
}
672+
}
673+
}
674+
616675
bool collectLoopRelatedInfo(
617676
lower::AbstractConverter &converter, mlir::Location currentLocation,
618677
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
@@ -634,11 +693,34 @@ bool collectLoopRelatedInfo(
634693
collapseValue = evaluate::ToInt64(clause->v).value();
635694
found = true;
636695
}
696+
697+
// Collect sizes from tile directive if present
637698
std::int64_t sizesLengthValue = 0l;
638-
if (auto *clause =
639-
ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
640-
sizesLengthValue = clause->v.size();
641-
found = true;
699+
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
700+
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
701+
const auto &innerOptional = std::get<
702+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
703+
ompLoop->t);
704+
if (innerOptional.has_value()) {
705+
const auto &innerLoopDirective = innerOptional.value().value();
706+
const auto &innerBegin =
707+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
708+
const auto &innerDirective =
709+
std::get<parser::OmpLoopDirective>(innerBegin.t).v;
710+
711+
if (innerDirective == llvm::omp::Directive::OMPD_tile) {
712+
// Get the size values from parse tree and convert to a vector
713+
const auto &innerClauseList{
714+
std::get<parser::OmpClauseList>(innerBegin.t)};
715+
for (const auto &clause : innerClauseList.v)
716+
if (const auto tclause{
717+
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
718+
sizesLengthValue = tclause->v.size();
719+
found = true;
720+
}
721+
}
722+
}
723+
}
642724
}
643725

644726
collapseValue = collapseValue - sizesLengthValue;

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ bool collectLoopRelatedInfo(
166166
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
167167
mlir::omp::LoopRelatedClauseOps &result,
168168
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
169+
170+
void collectTileSizesFromOpenMPConstruct(
171+
const parser::OpenMPConstruct *ompCons,
172+
llvm::SmallVectorImpl<int64_t> &tileSizes,
173+
Fortran::semantics::SemanticsContext &semaCtx);
174+
169175
} // namespace omp
170176
} // namespace lower
171177
} // namespace Fortran

llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,10 +497,6 @@ bool ConstructDecompositionT<C, H>::applyClause(
497497
if (llvm::omp::isAllowedClauseForDirective(last.id, node->id, version)) {
498498
last.clauses.push_back(node);
499499
return true;
500-
} else {
501-
// llvm::errs() << "** OVERRIDING isAllowedClauseForDirective **\n";
502-
last.clauses.push_back(node);
503-
return true;
504500
}
505501
}
506502

0 commit comments

Comments
 (0)