Skip to content

Commit 131ade1

Browse files
committed
Avoid attaching the sizes clause to the parent construct, instead find the tile
sizes through the parse tree when getting the information needed to create the loop nest ops.
1 parent b045881 commit 131ade1

File tree

4 files changed

+99
-23
lines changed

4 files changed

+99
-23
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
using namespace Fortran::lower::omp;
4848
using namespace Fortran::common::openmp;
49+
using namespace Fortran::semantics;
4950

5051
//===----------------------------------------------------------------------===//
5152
// Code generation helper functions
@@ -1690,6 +1691,7 @@ genLoopNestClauses(lower::AbstractConverter &converter,
16901691
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
16911692
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
16921693
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
1694+
// This case handles the stand-alone tiling construct
16931695
const auto &sizes = std::get<clause::Sizes>(clause.u);
16941696
llvm::SmallVector<int64_t> sizeValues;
16951697
for (auto &size : sizes.v) {
@@ -1699,6 +1701,12 @@ genLoopNestClauses(lower::AbstractConverter &converter,
16991701
clauseOps.tileSizes = sizeValues;
17001702
}
17011703
}
1704+
1705+
llvm::SmallVector<int64_t> sizeValues;
1706+
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
1707+
collectTileSizesFromOpenMPConstruct (ompCons, sizeValues, semaCtx);
1708+
if (sizeValues.size() > 0)
1709+
clauseOps.tileSizes = sizeValues;
17021710
}
17031711

17041712
static void genLoopClauses(
@@ -3936,21 +3944,6 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
39363944
List<Clause> clauses = makeClauses(
39373945
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);
39383946

3939-
const auto &innerOptional =
3940-
std::get<std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
3941-
loopConstruct.t);
3942-
if (innerOptional.has_value()) {
3943-
const auto &innerLoopDirective = innerOptional.value().value();
3944-
const auto &innerBegin =
3945-
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
3946-
const auto &innerDirective =
3947-
std::get<parser::OmpLoopDirective>(innerBegin.t);
3948-
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
3949-
clauses.append(
3950-
makeClauses(std::get<parser::OmpClauseList>(innerBegin.t), semaCtx));
3951-
}
3952-
}
3953-
39543947
if (auto &endLoopDirective =
39553948
std::get<std::optional<parser::OmpEndLoopDirective>>(
39563949
loopConstruct.t)) {

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "ClauseFinder.h"
1616
#include "flang/Lower/OpenMP/Clauses.h"
17+
#include "flang/Evaluate/fold.h"
1718
#include <flang/Lower/AbstractConverter.h>
1819
#include <flang/Lower/ConvertType.h>
1920
#include <flang/Lower/DirectivesCommon.h>
@@ -24,10 +25,31 @@
2425
#include <flang/Parser/parse-tree.h>
2526
#include <flang/Parser/tools.h>
2627
#include <flang/Semantics/tools.h>
28+
#include <flang/Semantics/type.h>
2729
#include <llvm/Support/CommandLine.h>
2830

2931
#include <iterator>
3032

33+
using namespace Fortran::semantics;
34+
35+
template <typename T>
36+
MaybeIntExpr
37+
EvaluateIntExpr(SemanticsContext &context, const T &expr) {
38+
if (MaybeExpr maybeExpr{
39+
Fold(context.foldingContext(), AnalyzeExpr(context, expr))}) {
40+
if (auto *intExpr{Fortran::evaluate::UnwrapExpr<SomeIntExpr>(*maybeExpr)}) {
41+
return std::move(*intExpr);
42+
}
43+
}
44+
return std::nullopt;
45+
}
46+
47+
template <typename T>
48+
std::optional<std::int64_t>
49+
EvaluateInt64(SemanticsContext &context, const T &expr) {
50+
return Fortran::evaluate::ToInt64(EvaluateIntExpr(context, expr));
51+
}
52+
3153
llvm::cl::opt<bool> treatIndexAsSection(
3254
"openmp-treat-index-as-section",
3355
llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
@@ -614,6 +636,43 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
614636
}
615637
}
616638

639+
// Populates the sizes vector with values if the given OpenMPConstruct
640+
// Contains a loop construct with an inner tiling construct.
641+
void collectTileSizesFromOpenMPConstruct(
642+
const parser::OpenMPConstruct *ompCons,
643+
llvm::SmallVectorImpl<int64_t> &tileSizes,
644+
SemanticsContext &semaCtx) {
645+
if (!ompCons)
646+
return;
647+
648+
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
649+
const auto &innerOptional = std::get<
650+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
651+
ompLoop->t);
652+
if (innerOptional.has_value()) {
653+
const auto &innerLoopDirective = innerOptional.value().value();
654+
const auto &innerBegin =
655+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
656+
const auto &innerDirective =
657+
std::get<parser::OmpLoopDirective>(innerBegin.t).v;
658+
659+
if (innerDirective == llvm::omp::Directive::OMPD_tile) {
660+
// Get the size values from parse tree and convert to a vector
661+
const auto &innerClauseList{
662+
std::get<parser::OmpClauseList>(innerBegin.t)};
663+
for (const auto &clause : innerClauseList.v)
664+
if (const auto tclause{
665+
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
666+
for (auto &tval : tclause->v) {
667+
if (const auto v{EvaluateInt64(semaCtx, tval)})
668+
tileSizes.push_back(*v);
669+
}
670+
}
671+
}
672+
}
673+
}
674+
}
675+
617676
bool collectLoopRelatedInfo(
618677
lower::AbstractConverter &converter, mlir::Location currentLocation,
619678
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
@@ -635,11 +694,34 @@ bool collectLoopRelatedInfo(
635694
collapseValue = evaluate::ToInt64(clause->v).value();
636695
found = true;
637696
}
697+
698+
// Collect sizes from tile directive if present
638699
std::int64_t sizesLengthValue = 0l;
639-
if (auto *clause =
640-
ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
641-
sizesLengthValue = clause->v.size();
642-
found = true;
700+
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
701+
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
702+
const auto &innerOptional = std::get<
703+
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
704+
ompLoop->t);
705+
if (innerOptional.has_value()) {
706+
const auto &innerLoopDirective = innerOptional.value().value();
707+
const auto &innerBegin =
708+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
709+
const auto &innerDirective =
710+
std::get<parser::OmpLoopDirective>(innerBegin.t).v;
711+
712+
if (innerDirective == llvm::omp::Directive::OMPD_tile) {
713+
// Get the size values from parse tree and convert to a vector
714+
const auto &innerClauseList{
715+
std::get<parser::OmpClauseList>(innerBegin.t)};
716+
for (const auto &clause : innerClauseList.v)
717+
if (const auto tclause{
718+
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
719+
sizesLengthValue = tclause->v.size();
720+
found = true;
721+
}
722+
}
723+
}
724+
}
643725
}
644726

645727
collapseValue = collapseValue - sizesLengthValue;

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ bool collectLoopRelatedInfo(
167167
mlir::omp::LoopRelatedClauseOps &result,
168168
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
169169

170+
void collectTileSizesFromOpenMPConstruct(
171+
const parser::OpenMPConstruct *ompCons,
172+
llvm::SmallVectorImpl<int64_t> &tileSizes,
173+
Fortran::semantics::SemanticsContext &semaCtx);
174+
170175
} // namespace omp
171176
} // namespace lower
172177
} // namespace Fortran

llvm/include/llvm/Frontend/OpenMP/ConstructDecompositionT.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,10 +497,6 @@ bool ConstructDecompositionT<C, H>::applyClause(
497497
if (llvm::omp::isAllowedClauseForDirective(last.id, node->id, version)) {
498498
last.clauses.push_back(node);
499499
return true;
500-
} else {
501-
// llvm::errs() << "** OVERRIDING isAllowedClauseForDirective **\n";
502-
last.clauses.push_back(node);
503-
return true;
504500
}
505501
}
506502

0 commit comments

Comments
 (0)