14
14
15
15
#include " ClauseFinder.h"
16
16
#include " flang/Lower/OpenMP/Clauses.h"
17
+ #include " flang/Evaluate/fold.h"
17
18
#include < flang/Lower/AbstractConverter.h>
18
19
#include < flang/Lower/ConvertType.h>
19
20
#include < flang/Lower/DirectivesCommon.h>
24
25
#include < flang/Parser/parse-tree.h>
25
26
#include < flang/Parser/tools.h>
26
27
#include < flang/Semantics/tools.h>
28
+ #include < flang/Semantics/type.h>
27
29
#include < llvm/Support/CommandLine.h>
28
30
29
31
#include < iterator>
30
32
33
+ using namespace Fortran ::semantics;
34
+
35
+ template <typename T>
36
+ MaybeIntExpr
37
+ EvaluateIntExpr (SemanticsContext &context, const T &expr) {
38
+ if (MaybeExpr maybeExpr{
39
+ Fold (context.foldingContext (), AnalyzeExpr (context, expr))}) {
40
+ if (auto *intExpr{Fortran::evaluate::UnwrapExpr<SomeIntExpr>(*maybeExpr)}) {
41
+ return std::move (*intExpr);
42
+ }
43
+ }
44
+ return std::nullopt;
45
+ }
46
+
47
+ template <typename T>
48
+ std::optional<std::int64_t >
49
+ EvaluateInt64 (SemanticsContext &context, const T &expr) {
50
+ return Fortran::evaluate::ToInt64 (EvaluateIntExpr (context, expr));
51
+ }
52
+
31
53
llvm::cl::opt<bool > treatIndexAsSection (
32
54
" openmp-treat-index-as-section" ,
33
55
llvm::cl::desc (" In the OpenMP data clauses treat `a(N)` as `a(N:N)`." ),
@@ -614,6 +636,43 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
614
636
}
615
637
}
616
638
639
+ // Populates the sizes vector with values if the given OpenMPConstruct
640
+ // Contains a loop construct with an inner tiling construct.
641
+ void collectTileSizesFromOpenMPConstruct (
642
+ const parser::OpenMPConstruct *ompCons,
643
+ llvm::SmallVectorImpl<int64_t > &tileSizes,
644
+ SemanticsContext &semaCtx) {
645
+ if (!ompCons)
646
+ return ;
647
+
648
+ if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u )}) {
649
+ const auto &innerOptional = std::get<
650
+ std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
651
+ ompLoop->t );
652
+ if (innerOptional.has_value ()) {
653
+ const auto &innerLoopDirective = innerOptional.value ().value ();
654
+ const auto &innerBegin =
655
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
656
+ const auto &innerDirective =
657
+ std::get<parser::OmpLoopDirective>(innerBegin.t ).v ;
658
+
659
+ if (innerDirective == llvm::omp::Directive::OMPD_tile) {
660
+ // Get the size values from parse tree and convert to a vector
661
+ const auto &innerClauseList{
662
+ std::get<parser::OmpClauseList>(innerBegin.t )};
663
+ for (const auto &clause : innerClauseList.v )
664
+ if (const auto tclause{
665
+ std::get_if<parser::OmpClause::Sizes>(&clause.u )}) {
666
+ for (auto &tval : tclause->v ) {
667
+ if (const auto v{EvaluateInt64 (semaCtx, tval)})
668
+ tileSizes.push_back (*v);
669
+ }
670
+ }
671
+ }
672
+ }
673
+ }
674
+ }
675
+
617
676
bool collectLoopRelatedInfo (
618
677
lower::AbstractConverter &converter, mlir::Location currentLocation,
619
678
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
@@ -635,11 +694,34 @@ bool collectLoopRelatedInfo(
635
694
collapseValue = evaluate::ToInt64 (clause->v ).value ();
636
695
found = true ;
637
696
}
697
+
698
+ // Collect sizes from tile directive if present
638
699
std::int64_t sizesLengthValue = 0l ;
639
- if (auto *clause =
640
- ClauseFinder::findUniqueClause<omp::clause::Sizes>(clauses)) {
641
- sizesLengthValue = clause->v .size ();
642
- found = true ;
700
+ if (auto *ompCons{eval.getIf <parser::OpenMPConstruct>()}) {
701
+ if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u )}) {
702
+ const auto &innerOptional = std::get<
703
+ std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
704
+ ompLoop->t );
705
+ if (innerOptional.has_value ()) {
706
+ const auto &innerLoopDirective = innerOptional.value ().value ();
707
+ const auto &innerBegin =
708
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
709
+ const auto &innerDirective =
710
+ std::get<parser::OmpLoopDirective>(innerBegin.t ).v ;
711
+
712
+ if (innerDirective == llvm::omp::Directive::OMPD_tile) {
713
+ // Get the size values from parse tree and convert to a vector
714
+ const auto &innerClauseList{
715
+ std::get<parser::OmpClauseList>(innerBegin.t )};
716
+ for (const auto &clause : innerClauseList.v )
717
+ if (const auto tclause{
718
+ std::get_if<parser::OmpClause::Sizes>(&clause.u )}) {
719
+ sizesLengthValue = tclause->v .size ();
720
+ found = true ;
721
+ }
722
+ }
723
+ }
724
+ }
643
725
}
644
726
645
727
collapseValue = collapseValue - sizesLengthValue;
0 commit comments