14
14
15
15
#include " ClauseFinder.h"
16
16
#include " flang/Lower/OpenMP/Clauses.h"
17
+ #include " flang/Evaluate/fold.h"
17
18
#include < flang/Lower/AbstractConverter.h>
18
19
#include < flang/Lower/ConvertType.h>
19
20
#include < flang/Lower/DirectivesCommon.h>
24
25
#include < flang/Parser/parse-tree.h>
25
26
#include < flang/Parser/tools.h>
26
27
#include < flang/Semantics/tools.h>
28
+ #include < flang/Semantics/type.h>
27
29
#include < llvm/Support/CommandLine.h>
28
30
29
31
#include < iterator>
30
32
33
+ using namespace Fortran ::semantics;
34
+
35
+ template <typename T>
36
+ MaybeIntExpr
37
+ EvaluateIntExpr (SemanticsContext &context, const T &expr) {
38
+ if (MaybeExpr maybeExpr{
39
+ Fold (context.foldingContext (), AnalyzeExpr (context, expr))}) {
40
+ if (auto *intExpr{Fortran::evaluate::UnwrapExpr<SomeIntExpr>(*maybeExpr)}) {
41
+ return std::move (*intExpr);
42
+ }
43
+ }
44
+ return std::nullopt;
45
+ }
46
+
47
+ template <typename T>
48
+ std::optional<std::int64_t >
49
+ EvaluateInt64 (SemanticsContext &context, const T &expr) {
50
+ return Fortran::evaluate::ToInt64 (EvaluateIntExpr (context, expr));
51
+ }
52
+
31
53
llvm::cl::opt<bool > treatIndexAsSection (
32
54
" openmp-treat-index-as-section" ,
33
55
llvm::cl::desc (" In the OpenMP data clauses treat `a(N)` as `a(N:N)`." ),
@@ -615,6 +637,43 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
615
637
}
616
638
}
617
639
640
+ // Populates the sizes vector with values if the given OpenMPConstruct
641
+ // Contains a loop construct with an inner tiling construct.
642
+ void collectTileSizesFromOpenMPConstruct (
643
+ const parser::OpenMPConstruct *ompCons,
644
+ llvm::SmallVectorImpl<int64_t > &tileSizes,
645
+ SemanticsContext &semaCtx) {
646
+ if (!ompCons)
647
+ return ;
648
+
649
+ if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u )}) {
650
+ const auto &innerOptional = std::get<
651
+ std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
652
+ ompLoop->t );
653
+ if (innerOptional.has_value ()) {
654
+ const auto &innerLoopDirective = innerOptional.value ().value ();
655
+ const auto &innerBegin =
656
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
657
+ const auto &innerDirective =
658
+ std::get<parser::OmpLoopDirective>(innerBegin.t ).v ;
659
+
660
+ if (innerDirective == llvm::omp::Directive::OMPD_tile) {
661
+ // Get the size values from parse tree and convert to a vector
662
+ const auto &innerClauseList{
663
+ std::get<parser::OmpClauseList>(innerBegin.t )};
664
+ for (const auto &clause : innerClauseList.v )
665
+ if (const auto tclause{
666
+ std::get_if<parser::OmpClause::Sizes>(&clause.u )}) {
667
+ for (auto &tval : tclause->v ) {
668
+ if (const auto v{EvaluateInt64 (semaCtx, tval)})
669
+ tileSizes.push_back (*v);
670
+ }
671
+ }
672
+ }
673
+ }
674
+ }
675
+ }
676
+
618
677
bool collectLoopRelatedInfo (
619
678
lower::AbstractConverter &converter, mlir::Location currentLocation,
620
679
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
@@ -636,11 +695,34 @@ bool collectLoopRelatedInfo(
636
695
collapseValue = evaluate::ToInt64 (clause->v ).value ();
637
696
found = true ;
638
697
}
698
+
699
+ // Collect sizes from tile directive if present
639
700
std::int64_t sizesLengthValue = 0l ;
640
- if (auto *clause =
641
- ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
642
- sizesLengthValue = clause->v .size ();
643
- found = true ;
701
+ if (auto *ompCons{eval.getIf <parser::OpenMPConstruct>()}) {
702
+ if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u )}) {
703
+ const auto &innerOptional = std::get<
704
+ std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
705
+ ompLoop->t );
706
+ if (innerOptional.has_value ()) {
707
+ const auto &innerLoopDirective = innerOptional.value ().value ();
708
+ const auto &innerBegin =
709
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
710
+ const auto &innerDirective =
711
+ std::get<parser::OmpLoopDirective>(innerBegin.t ).v ;
712
+
713
+ if (innerDirective == llvm::omp::Directive::OMPD_tile) {
714
+ // Get the size values from parse tree and convert to a vector
715
+ const auto &innerClauseList{
716
+ std::get<parser::OmpClauseList>(innerBegin.t )};
717
+ for (const auto &clause : innerClauseList.v )
718
+ if (const auto tclause{
719
+ std::get_if<parser::OmpClause::Sizes>(&clause.u )}) {
720
+ sizesLengthValue = tclause->v .size ();
721
+ found = true ;
722
+ }
723
+ }
724
+ }
725
+ }
644
726
}
645
727
646
728
collapseValue = collapseValue - sizesLengthValue;
0 commit comments