15
15
#include " Clauses.h"
16
16
17
17
#include " ClauseFinder.h"
18
+ #include " flang/Evaluate/fold.h"
18
19
#include < flang/Lower/AbstractConverter.h>
19
20
#include < flang/Lower/ConvertType.h>
20
21
#include < flang/Lower/DirectivesCommon.h>
24
25
#include < flang/Parser/parse-tree.h>
25
26
#include < flang/Parser/tools.h>
26
27
#include < flang/Semantics/tools.h>
28
+ #include < flang/Semantics/type.h>
27
29
#include < llvm/Support/CommandLine.h>
28
30
29
31
#include < iterator>
30
32
33
+ using namespace Fortran ::semantics;
34
+
35
+ template <typename T>
36
+ MaybeIntExpr
37
+ EvaluateIntExpr (SemanticsContext &context, const T &expr) {
38
+ if (MaybeExpr maybeExpr{
39
+ Fold (context.foldingContext (), AnalyzeExpr (context, expr))}) {
40
+ if (auto *intExpr{Fortran::evaluate::UnwrapExpr<SomeIntExpr>(*maybeExpr)}) {
41
+ return std::move (*intExpr);
42
+ }
43
+ }
44
+ return std::nullopt;
45
+ }
46
+
47
+ template <typename T>
48
+ std::optional<std::int64_t >
49
+ EvaluateInt64 (SemanticsContext &context, const T &expr) {
50
+ return Fortran::evaluate::ToInt64 (EvaluateIntExpr (context, expr));
51
+ }
52
+
31
53
llvm::cl::opt<bool > treatIndexAsSection (
32
54
" openmp-treat-index-as-section" ,
33
55
llvm::cl::desc (" In the OpenMP data clauses treat `a(N)` as `a(N:N)`." ),
@@ -613,6 +635,43 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
613
635
}
614
636
}
615
637
638
+ // Populates the sizes vector with values if the given OpenMPConstruct
639
+ // Contains a loop construct with an inner tiling construct.
640
+ void collectTileSizesFromOpenMPConstruct (
641
+ const parser::OpenMPConstruct *ompCons,
642
+ llvm::SmallVectorImpl<int64_t > &tileSizes,
643
+ SemanticsContext &semaCtx) {
644
+ if (!ompCons)
645
+ return ;
646
+
647
+ if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u )}) {
648
+ const auto &innerOptional = std::get<
649
+ std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
650
+ ompLoop->t );
651
+ if (innerOptional.has_value ()) {
652
+ const auto &innerLoopDirective = innerOptional.value ().value ();
653
+ const auto &innerBegin =
654
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
655
+ const auto &innerDirective =
656
+ std::get<parser::OmpLoopDirective>(innerBegin.t ).v ;
657
+
658
+ if (innerDirective == llvm::omp::Directive::OMPD_tile) {
659
+ // Get the size values from parse tree and convert to a vector
660
+ const auto &innerClauseList{
661
+ std::get<parser::OmpClauseList>(innerBegin.t )};
662
+ for (const auto &clause : innerClauseList.v )
663
+ if (const auto tclause{
664
+ std::get_if<parser::OmpClause::Sizes>(&clause.u )}) {
665
+ for (auto &tval : tclause->v ) {
666
+ if (const auto v{EvaluateInt64 (semaCtx, tval)})
667
+ tileSizes.push_back (*v);
668
+ }
669
+ }
670
+ }
671
+ }
672
+ }
673
+ }
674
+
616
675
bool collectLoopRelatedInfo (
617
676
lower::AbstractConverter &converter, mlir::Location currentLocation,
618
677
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
@@ -634,11 +693,34 @@ bool collectLoopRelatedInfo(
634
693
collapseValue = evaluate::ToInt64 (clause->v ).value ();
635
694
found = true ;
636
695
}
696
+
697
+ // Collect sizes from tile directive if present
637
698
std::int64_t sizesLengthValue = 0l ;
638
- if (auto *clause =
639
- ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
640
- sizesLengthValue = clause->v .size ();
641
- found = true ;
699
+ if (auto *ompCons{eval.getIf <parser::OpenMPConstruct>()}) {
700
+ if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u )}) {
701
+ const auto &innerOptional = std::get<
702
+ std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
703
+ ompLoop->t );
704
+ if (innerOptional.has_value ()) {
705
+ const auto &innerLoopDirective = innerOptional.value ().value ();
706
+ const auto &innerBegin =
707
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
708
+ const auto &innerDirective =
709
+ std::get<parser::OmpLoopDirective>(innerBegin.t ).v ;
710
+
711
+ if (innerDirective == llvm::omp::Directive::OMPD_tile) {
712
+ // Get the size values from parse tree and convert to a vector
713
+ const auto &innerClauseList{
714
+ std::get<parser::OmpClauseList>(innerBegin.t )};
715
+ for (const auto &clause : innerClauseList.v )
716
+ if (const auto tclause{
717
+ std::get_if<parser::OmpClause::Sizes>(&clause.u )}) {
718
+ sizesLengthValue = tclause->v .size ();
719
+ found = true ;
720
+ }
721
+ }
722
+ }
723
+ }
642
724
}
643
725
644
726
collapseValue = collapseValue - sizesLengthValue;
0 commit comments