99#include " canonicalize-omp.h"
1010#include " flang/Parser/parse-tree-visitor.h"
1111#include " flang/Parser/parse-tree.h"
12+ #include " flang/Semantics/openmp-directive-sets.h"
1213#include " flang/Semantics/semantics.h"
1314
1415// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
@@ -136,20 +137,30 @@ class CanonicalizationOfOmp {
136137 " A DO loop must follow the %s directive" _err_en_US,
137138 parser::ToUpperCaseLetters (dirName.source .ToString ()));
138139 };
139- auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
140- parser::Messages &messages) {
140+ auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
141+ parser::Messages &messages) {
141142 messages.Say (dirName.source ,
142- " If a loop construct has been fully unrolled, it cannot then be tiled " _err_en_US,
143+ " If a loop construct has been fully unrolled, it cannot then be further transformed " _err_en_US,
143144 parser::ToUpperCaseLetters (dirName.source .ToString ()));
144145 };
146+ auto missingEndFuse = [](auto &dir, auto &messages) {
147+ messages.Say (dir.source ,
148+ " The %s construct requires the END FUSE directive" _err_en_US,
149+ parser::ToUpperCaseLetters (dir.source .ToString ()));
150+ };
151+
152+ bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;
145153
146154 auto &body{std::get<parser::Block>(x.t )};
147155
148156 nextIt = it;
149- while (++nextIt != block.end ()) {
157+ nextIt++;
158+ while (nextIt != block.end ()) {
150159 // Ignore compiler directives.
151- if (GetConstructIf<parser::CompilerDirective>(*nextIt))
160+ if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
161+ nextIt++;
152162 continue ;
163+ }
153164
154165 if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
155166 if (doCons->GetLoopControl ()) {
@@ -160,9 +171,12 @@ class CanonicalizationOfOmp {
160171 if (nextIt != block.end ()) {
161172 if (auto *endDir{
162173 GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
163- std::get<std::optional<parser::OmpEndLoopDirective>>(x.t ) =
164- std::move (*endDir);
165- nextIt = block.erase (nextIt);
174+ auto &endDirName = endDir->DirName ();
175+ if (endDirName.v != llvm::omp::Directive::OMPD_fuse) {
176+ std::get<std::optional<parser::OmpEndLoopDirective>>(x.t ) =
177+ std::move (*endDir);
178+ nextIt = block.erase (nextIt);
179+ }
166180 }
167181 }
168182 } else {
@@ -172,50 +186,45 @@ class CanonicalizationOfOmp {
172186 }
173187 } else if (auto *ompLoopCons{
174188 GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
175- // We should allow UNROLL and TILE constructs to be inserted between an
176- // OpenMP Loop Construct and the DO loop itself
189+ // We should allow loop transformation constructs to be inserted between
190+ // an OpenMP Loop Construct and the DO loop itself
177191 auto &nestedBeginDirective = ompLoopCons->BeginDir ();
178192 auto &nestedBeginName = nestedBeginDirective.DirName ();
179- if ((nestedBeginName.v == llvm::omp::Directive::OMPD_unroll ||
180- nestedBeginName.v == llvm::omp::Directive::OMPD_tile) &&
181- !(nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
182- beginName.v == llvm::omp::Directive::OMPD_tile)) {
183- // iterate through the remaining block items to find the end directive
184- // for the unroll/tile directive.
185- parser::Block::iterator endIt;
186- endIt = nextIt;
187- while (endIt != block.end ()) {
188- if (auto *endDir{
189- GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
190- auto &endDirName = endDir->DirName ();
191- if (endDirName.v == beginName.v ) {
192- std::get<std::optional<parser::OmpEndLoopDirective>>(x.t ) =
193- std::move (*endDir);
194- endIt = block.erase (endIt);
195- continue ;
193+ if (llvm::omp::loopTransformationSet.test (nestedBeginName.v )) {
194+ if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
195+ llvm::omp::loopTransformationSet.test (beginName.v )) {
196+ // if a loop has been unrolled, the user can not then transform that
197+ // loop as it has been unrolled
198+ const parser::OmpClauseList &unrollClauseList{
199+ nestedBeginDirective.Clauses ()};
200+ if (unrollClauseList.v .empty ()) {
201+ // if the clause list is empty for an unroll construct, we assume
202+ // the loop is being fully unrolled
203+ transformUnrollError (beginName, messages_);
204+ } else {
205+ // parse the clauses for the unroll directive to find the full
206+ // clause
207+ for (auto &clause : unrollClauseList.v ) {
208+ if (clause.Id () == llvm::omp::OMPC_full) {
209+ transformUnrollError (beginName, messages_);
210+ }
196211 }
197212 }
198- ++endIt;
199213 }
200214 RewriteOpenMPLoopConstruct (*ompLoopCons, block, nextIt);
201215 body.push_back (std::move (*nextIt));
202216 nextIt = block.erase (nextIt);
203- } else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
204- beginName.v == llvm::omp::Directive::OMPD_tile) {
205- // if a loop has been unrolled, the user can not then tile that loop
206- // as it has been unrolled
207- const parser::OmpClauseList &unrollClauseList{
208- nestedBeginDirective.Clauses ()};
209- if (unrollClauseList.v .empty ()) {
210- // if the clause list is empty for an unroll construct, we assume
211- // the loop is being fully unrolled
212- tileUnrollError (beginName, messages_);
213- } else {
214- // parse the clauses for the unroll directive to find the full
215- // clause
216- for (auto &clause : unrollClauseList.v ) {
217- if (clause.Id () == llvm::omp::OMPC_full) {
218- tileUnrollError (beginName, messages_);
217+ // check the following block item to find the end directive
218+ // for the loop transform directive.
219+ if (nextIt != block.end ()) {
220+ if (auto *endDir{
221+ GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
222+ auto &endDirName = endDir->DirName ();
223+ if (endDirName.v == beginName.v &&
224+ endDirName.v != llvm::omp::Directive::OMPD_fuse) {
225+ std::get<std::optional<parser::OmpEndLoopDirective>>(x.t ) =
226+ std::move (*endDir);
227+ nextIt = block.erase (nextIt);
219228 }
220229 }
221230 }
@@ -227,11 +236,29 @@ class CanonicalizationOfOmp {
227236 } else {
228237 missingDoConstruct (beginName, messages_);
229238 }
239+
240+ if (endFuseNeeded && nextIt != block.end ()) {
241+ if (auto *endDir{
242+ GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
243+ auto &endDirName = endDir->DirName ();
244+ if (endDirName.v == llvm::omp::Directive::OMPD_fuse) {
245+ endFuseNeeded = false ;
246+ std::get<std::optional<parser::OmpEndLoopDirective>>(x.t ) =
247+ std::move (*endDir);
248+ nextIt = block.erase (nextIt);
249+ }
250+ }
251+ }
252+ if (endFuseNeeded)
253+ continue ;
230254 // If we get here, we either found a loop, or issued an error message.
231255 return ;
232256 }
233257 if (nextIt == block.end ()) {
234- missingDoConstruct (beginName, messages_);
258+ if (endFuseNeeded)
259+ missingEndFuse (beginName, messages_);
260+ else
261+ missingDoConstruct (beginName, messages_);
235262 }
236263 }
237264
0 commit comments