1010#include " flang/Parser/parse-tree-visitor.h"
1111#include " flang/Parser/parse-tree.h"
1212#include " flang/Semantics/semantics.h"
13+ #include " flang/Semantics/openmp-directive-sets.h"
1314
1415// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
1516// Constructs more structured which provide explicit scopes for later
@@ -137,30 +138,42 @@ class CanonicalizationOfOmp {
137138 " A DO loop must follow the %s directive" _err_en_US,
138139 parser::ToUpperCaseLetters (dirName.source .ToString ()));
139140 };
140- auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
141+ auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
141142 parser::Messages &messages) {
142143 messages.Say (dirName.source ,
143- " If a loop construct has been fully unrolled, it cannot then be tiled " _err_en_US,
144+ " If a loop construct has been fully unrolled, it cannot then be further transformed " _err_en_US,
144145 parser::ToUpperCaseLetters (dirName.source .ToString ()));
145146 };
147+ auto missingEndFuse = [](auto &dir, auto &messages) {
148+ messages.Say (dir.source ,
149+ " The %s construct requires the END FUSE directive" _err_en_US,
150+ parser::ToUpperCaseLetters (dir.source .ToString ()));
151+ };
152+
153+ bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;
146154
147155 nextIt = it;
148- while (++nextIt != block.end ()) {
156+ nextIt++;
157+ while (nextIt != block.end ()) {
149158 // Ignore compiler directives.
150- if (GetConstructIf<parser::CompilerDirective>(*nextIt))
159+ if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
160+ nextIt++;
151161 continue ;
162+ }
152163
153164 if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
154165 if (doCons->GetLoopControl ()) {
155166 // move DoConstruct
156- std::get<std::optional<std::variant<parser::DoConstruct,
157- common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t ) =
158- std::move (*doCons);
167+ std::get<std::list<parser::NestedConstruct>>(x.t ).push_back (
168+ std::move (*doCons));
159169 nextIt = block.erase (nextIt);
160170 // try to match OmpEndLoopDirective
161171 if (nextIt != block.end ()) {
162172 if (auto *endDir{
163173 GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
174+ auto &endDirName = endDir->DirName ();
175+ if (endDirName.v == llvm::omp::Directive::OMPD_fuse)
176+ endFuseNeeded = false ;
164177 std::get<std::optional<parser::OmpEndLoopDirective>>(x.t ) =
165178 std::move (*endDir);
166179 nextIt = block.erase (nextIt);
@@ -170,17 +183,37 @@ class CanonicalizationOfOmp {
170183 messages_.Say (beginName.source ,
171184 " DO loop after the %s directive must have loop control" _err_en_US,
172185 parser::ToUpperCaseLetters (beginName.source .ToString ()));
186+ endFuseNeeded = false ;
173187 }
174188 } else if (auto *ompLoopCons{
175189 GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
176190 // We should allow UNROLL and TILE constructs to be inserted between an
177191 // OpenMP Loop Construct and the DO loop itself
178192 auto &nestedBeginDirective = ompLoopCons->BeginDir ();
179193 auto &nestedBeginName = nestedBeginDirective.DirName ();
180- if ((nestedBeginName.v == llvm::omp::Directive::OMPD_unroll ||
181- nestedBeginName.v == llvm::omp::Directive::OMPD_tile) &&
182- !(nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
183- beginName.v == llvm::omp::Directive::OMPD_tile)) {
194+ if (llvm::omp::loopTransformationSet.test (nestedBeginName.v )) {
195+ if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
196+ llvm::omp::loopTransformationSet.test (beginName.v )) {
197+ // if a loop has been unrolled, the user can not then tile that loop
198+ // as it has been unrolled
199+ const parser::OmpClauseList &unrollClauseList{
200+ nestedBeginDirective.Clauses ()};
201+ if (unrollClauseList.v .empty ()) {
202+ // if the clause list is empty for an unroll construct, we assume
203+ // the loop is being fully unrolled
204+ transformUnrollError (beginName, messages_);
205+ endFuseNeeded = false ;
206+ } else {
207+ // parse the clauses for the unroll directive to find the full
208+ // clause
209+ for (auto &clause : unrollClauseList.v ) {
210+ if (clause.Id () == llvm::omp::OMPC_full) {
211+ transformUnrollError (beginName, messages_);
212+ endFuseNeeded = false ;
213+ }
214+ }
215+ }
216+ }
184217 // iterate through the remaining block items to find the end directive
185218 // for the unroll/tile directive.
186219 parser::Block::iterator endIt;
@@ -190,6 +223,8 @@ class CanonicalizationOfOmp {
190223 GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
191224 auto &endDirName = endDir->DirName ();
192225 if (endDirName.v == beginName.v ) {
226+ if (endDirName.v == llvm::omp::Directive::OMPD_fuse)
227+ endFuseNeeded = false ;
193228 std::get<std::optional<parser::OmpEndLoopDirective>>(x.t ) =
194229 std::move (*endDir);
195230 endIt = block.erase (endIt);
@@ -199,43 +234,30 @@ class CanonicalizationOfOmp {
199234 ++endIt;
200235 }
201236 RewriteOpenMPLoopConstruct (*ompLoopCons, block, nextIt);
202- auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t );
203- ompLoop =
204- std::optional<parser::NestedConstruct>{parser::NestedConstruct{
205- common::Indirection{std::move (*ompLoopCons)}}};
237+ auto &loopConsList = std::get<std::list<parser::NestedConstruct>>(x.t );
238+ loopConsList.push_back (parser::NestedConstruct{
239+ common::Indirection{std::move (*ompLoopCons)}});
206240 nextIt = block.erase (nextIt);
207- } else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
208- beginName.v == llvm::omp::Directive::OMPD_tile) {
209- // if a loop has been unrolled, the user can not then tile that loop
210- // as it has been unrolled
211- const parser::OmpClauseList &unrollClauseList{
212- nestedBeginDirective.Clauses ()};
213- if (unrollClauseList.v .empty ()) {
214- // if the clause list is empty for an unroll construct, we assume
215- // the loop is being fully unrolled
216- tileUnrollError (beginName, messages_);
217- } else {
218- // parse the clauses for the unroll directive to find the full
219- // clause
220- for (auto &clause : unrollClauseList.v ) {
221- if (clause.Id () == llvm::omp::OMPC_full) {
222- tileUnrollError (beginName, messages_);
223- }
224- }
225- }
226241 } else {
227242 messages_.Say (nestedBeginName.source ,
228243 " Only Loop Transformation Constructs or Loop Nests can be nested within Loop Constructs" _err_en_US,
229244 parser::ToUpperCaseLetters (nestedBeginName.source .ToString ()));
245+ endFuseNeeded = false ;
230246 }
231247 } else {
232248 missingDoConstruct (beginName, messages_);
249+ endFuseNeeded = false ;
233250 }
251+ if (endFuseNeeded)
252+ continue ;
234253 // If we get here, we either found a loop, or issued an error message.
235254 return ;
236255 }
237256 if (nextIt == block.end ()) {
238- missingDoConstruct (beginName, messages_);
257+ if (endFuseNeeded)
258+ missingEndFuse (beginName, messages_);
259+ else
260+ missingDoConstruct (beginName, messages_);
239261 }
240262 }
241263
0 commit comments