Skip to content

Commit f4ebee0

Browse files
[Flang][OpenMP] Add semantic support for Loop Sequences and OpenMP loop fuse (#161213)
This patch adds semantics for the `omp fuse` directive in flang, as specified in OpenMP 6.0. This patch also enables semantic support for loop sequences which are needed for the fuse directive along with semantics for the `looprange` clause. These changes are only semantic. Relevant tests have been added , and previous behavior is retained with no changes. --------- Co-authored-by: Ferran Toda <[email protected]> Co-authored-by: Krzysztof Parzyszek <[email protected]>
1 parent a2dc4e0 commit f4ebee0

23 files changed

+963
-216
lines changed

flang/include/flang/Parser/openmp-utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
123123
const OpenMPDeclarativeConstruct *GetOmp(const DeclarationConstruct &x);
124124
const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
125125

126+
const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
127+
const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
128+
126129
const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
127130

128131
template <typename T>

flang/include/flang/Semantics/openmp-directive-sets.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
275275
Directive::OMPD_teams_distribute_parallel_do_simd,
276276
Directive::OMPD_teams_distribute_simd,
277277
Directive::OMPD_teams_loop,
278+
Directive::OMPD_fuse,
278279
Directive::OMPD_tile,
279280
Directive::OMPD_unroll,
280281
};
281282

283+
static const OmpDirectiveSet loopTransformationSet{
284+
Directive::OMPD_tile,
285+
Directive::OMPD_unroll,
286+
Directive::OMPD_fuse,
287+
};
288+
282289
static const OmpDirectiveSet nonPartialVarSet{
283290
Directive::OMPD_allocate,
284291
Directive::OMPD_allocators,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3507,6 +3507,13 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
35073507
case llvm::omp::Directive::OMPD_tile:
35083508
genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
35093509
break;
3510+
case llvm::omp::Directive::OMPD_fuse: {
3511+
unsigned version = semaCtx.langOptions().OpenMPVersion;
3512+
if (!semaCtx.langOptions().OpenMPSimd)
3513+
TODO(loc, "Unhandled loop directive (" +
3514+
llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
3515+
break;
3516+
}
35103517
case llvm::omp::Directive::OMPD_unroll:
35113518
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
35123519
break;
@@ -3962,22 +3969,24 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
39623969

39633970
mlir::Location currentLocation = converter.genLocation(beginSpec.source);
39643971

3965-
if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
3966-
loopConstruct.GetNestedConstruct()) {
3967-
llvm::omp::Directive nestedDirective =
3968-
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
3969-
switch (nestedDirective) {
3970-
case llvm::omp::Directive::OMPD_tile:
3971-
// Skip OMPD_tile since the tile sizes will be retrieved when
3972-
// generating the omp.loop_nest op.
3973-
break;
3974-
default: {
3975-
unsigned version = semaCtx.langOptions().OpenMPVersion;
3976-
TODO(currentLocation,
3977-
"Applying a loop-associated on the loop generated by the " +
3978-
llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
3979-
" construct");
3980-
}
3972+
for (auto &construct : std::get<parser::Block>(loopConstruct.t)) {
3973+
if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
3974+
parser::omp::GetOmpLoop(construct)) {
3975+
llvm::omp::Directive nestedDirective =
3976+
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
3977+
switch (nestedDirective) {
3978+
case llvm::omp::Directive::OMPD_tile:
3979+
// Skip OMPD_tile since the tile sizes will be retrieved when
3980+
// generating the omp.loop_nest op.
3981+
break;
3982+
default: {
3983+
unsigned version = semaCtx.langOptions().OpenMPVersion;
3984+
TODO(currentLocation,
3985+
"Applying a loop-associated on the loop generated by the " +
3986+
llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
3987+
" construct");
3988+
}
3989+
}
39813990
}
39823991
}
39833992

flang/lib/Parser/openmp-parsers.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2260,6 +2260,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
22602260
unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
22612261
unsigned(Directive::OMPD_teams_distribute_simd),
22622262
unsigned(Directive::OMPD_teams_loop),
2263+
unsigned(Directive::OMPD_fuse),
22632264
unsigned(Directive::OMPD_tile),
22642265
unsigned(Directive::OMPD_unroll),
22652266
};

flang/lib/Parser/openmp-utils.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,23 @@ const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x) {
4141
return nullptr;
4242
}
4343

44+
const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x) {
45+
if (auto *construct{GetOmp(x)}) {
46+
if (auto *omp{std::get_if<OpenMPLoopConstruct>(&construct->u)}) {
47+
return omp;
48+
}
49+
}
50+
return nullptr;
51+
}
52+
const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x) {
53+
if (auto *y{std::get_if<ExecutableConstruct>(&x.u)}) {
54+
if (auto *z{std::get_if<common::Indirection<DoConstruct>>(&y->u)}) {
55+
return &z->value();
56+
}
57+
}
58+
return nullptr;
59+
}
60+
4461
const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
4562
// Clauses with OmpObjectList as its data member
4663
using MemberObjectListClauses = std::tuple<OmpClause::Copyin,

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 72 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "canonicalize-omp.h"
1010
#include "flang/Parser/parse-tree-visitor.h"
1111
#include "flang/Parser/parse-tree.h"
12+
#include "flang/Semantics/openmp-directive-sets.h"
1213
#include "flang/Semantics/semantics.h"
1314

1415
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
@@ -136,20 +137,30 @@ class CanonicalizationOfOmp {
136137
"A DO loop must follow the %s directive"_err_en_US,
137138
parser::ToUpperCaseLetters(dirName.source.ToString()));
138139
};
139-
auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
140-
parser::Messages &messages) {
140+
auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
141+
parser::Messages &messages) {
141142
messages.Say(dirName.source,
142-
"If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
143+
"If a loop construct has been fully unrolled, it cannot then be further transformed"_err_en_US,
143144
parser::ToUpperCaseLetters(dirName.source.ToString()));
144145
};
146+
auto missingEndFuse = [](auto &dir, auto &messages) {
147+
messages.Say(dir.source,
148+
"The %s construct requires the END FUSE directive"_err_en_US,
149+
parser::ToUpperCaseLetters(dir.source.ToString()));
150+
};
151+
152+
bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;
145153

146154
auto &body{std::get<parser::Block>(x.t)};
147155

148156
nextIt = it;
149-
while (++nextIt != block.end()) {
157+
nextIt++;
158+
while (nextIt != block.end()) {
150159
// Ignore compiler directives.
151-
if (GetConstructIf<parser::CompilerDirective>(*nextIt))
160+
if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
161+
nextIt++;
152162
continue;
163+
}
153164

154165
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
155166
if (doCons->GetLoopControl()) {
@@ -160,9 +171,12 @@ class CanonicalizationOfOmp {
160171
if (nextIt != block.end()) {
161172
if (auto *endDir{
162173
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
163-
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
164-
std::move(*endDir);
165-
nextIt = block.erase(nextIt);
174+
auto &endDirName = endDir->DirName();
175+
if (endDirName.v != llvm::omp::Directive::OMPD_fuse) {
176+
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
177+
std::move(*endDir);
178+
nextIt = block.erase(nextIt);
179+
}
166180
}
167181
}
168182
} else {
@@ -172,50 +186,45 @@ class CanonicalizationOfOmp {
172186
}
173187
} else if (auto *ompLoopCons{
174188
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
175-
// We should allow UNROLL and TILE constructs to be inserted between an
176-
// OpenMP Loop Construct and the DO loop itself
189+
// We should allow loop transformation constructs to be inserted between
190+
// an OpenMP Loop Construct and the DO loop itself
177191
auto &nestedBeginDirective = ompLoopCons->BeginDir();
178192
auto &nestedBeginName = nestedBeginDirective.DirName();
179-
if ((nestedBeginName.v == llvm::omp::Directive::OMPD_unroll ||
180-
nestedBeginName.v == llvm::omp::Directive::OMPD_tile) &&
181-
!(nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
182-
beginName.v == llvm::omp::Directive::OMPD_tile)) {
183-
// iterate through the remaining block items to find the end directive
184-
// for the unroll/tile directive.
185-
parser::Block::iterator endIt;
186-
endIt = nextIt;
187-
while (endIt != block.end()) {
188-
if (auto *endDir{
189-
GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
190-
auto &endDirName = endDir->DirName();
191-
if (endDirName.v == beginName.v) {
192-
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
193-
std::move(*endDir);
194-
endIt = block.erase(endIt);
195-
continue;
193+
if (llvm::omp::loopTransformationSet.test(nestedBeginName.v)) {
194+
if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
195+
llvm::omp::loopTransformationSet.test(beginName.v)) {
196+
// if a loop has been unrolled, the user can not then transform that
197+
// loop as it has been unrolled
198+
const parser::OmpClauseList &unrollClauseList{
199+
nestedBeginDirective.Clauses()};
200+
if (unrollClauseList.v.empty()) {
201+
// if the clause list is empty for an unroll construct, we assume
202+
// the loop is being fully unrolled
203+
transformUnrollError(beginName, messages_);
204+
} else {
205+
// parse the clauses for the unroll directive to find the full
206+
// clause
207+
for (auto &clause : unrollClauseList.v) {
208+
if (clause.Id() == llvm::omp::OMPC_full) {
209+
transformUnrollError(beginName, messages_);
210+
}
196211
}
197212
}
198-
++endIt;
199213
}
200214
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
201215
body.push_back(std::move(*nextIt));
202216
nextIt = block.erase(nextIt);
203-
} else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
204-
beginName.v == llvm::omp::Directive::OMPD_tile) {
205-
// if a loop has been unrolled, the user can not then tile that loop
206-
// as it has been unrolled
207-
const parser::OmpClauseList &unrollClauseList{
208-
nestedBeginDirective.Clauses()};
209-
if (unrollClauseList.v.empty()) {
210-
// if the clause list is empty for an unroll construct, we assume
211-
// the loop is being fully unrolled
212-
tileUnrollError(beginName, messages_);
213-
} else {
214-
// parse the clauses for the unroll directive to find the full
215-
// clause
216-
for (auto &clause : unrollClauseList.v) {
217-
if (clause.Id() == llvm::omp::OMPC_full) {
218-
tileUnrollError(beginName, messages_);
217+
// check the following block item to find the end directive
218+
// for the loop transform directive.
219+
if (nextIt != block.end()) {
220+
if (auto *endDir{
221+
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
222+
auto &endDirName = endDir->DirName();
223+
if (endDirName.v == beginName.v &&
224+
endDirName.v != llvm::omp::Directive::OMPD_fuse) {
225+
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
226+
std::move(*endDir);
227+
nextIt = block.erase(nextIt);
219228
}
220229
}
221230
}
@@ -227,11 +236,29 @@ class CanonicalizationOfOmp {
227236
} else {
228237
missingDoConstruct(beginName, messages_);
229238
}
239+
240+
if (endFuseNeeded && nextIt != block.end()) {
241+
if (auto *endDir{
242+
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
243+
auto &endDirName = endDir->DirName();
244+
if (endDirName.v == llvm::omp::Directive::OMPD_fuse) {
245+
endFuseNeeded = false;
246+
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
247+
std::move(*endDir);
248+
nextIt = block.erase(nextIt);
249+
}
250+
}
251+
}
252+
if (endFuseNeeded)
253+
continue;
230254
// If we get here, we either found a loop, or issued an error message.
231255
return;
232256
}
233257
if (nextIt == block.end()) {
234-
missingDoConstruct(beginName, messages_);
258+
if (endFuseNeeded)
259+
missingEndFuse(beginName, messages_);
260+
else
261+
missingDoConstruct(beginName, messages_);
235262
}
236263
}
237264

0 commit comments

Comments
 (0)