Skip to content

Commit c761405

Browse files
author
Ferran Toda
committed
Loop sequences and loop fuse semantics
1 parent e4d94f4 commit c761405

File tree

15 files changed

+489
-157
lines changed

15 files changed

+489
-157
lines changed

flang/include/flang/Parser/parse-tree.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5177,15 +5177,15 @@ using NestedConstruct =
51775177
struct OpenMPLoopConstruct {
51785178
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
51795179
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
5180-
: t({std::move(a), std::nullopt, std::nullopt}) {}
5180+
: t({std::move(a), std::list<NestedConstruct>(), std::nullopt}) {}
51815181

51825182
const OmpBeginLoopDirective &BeginDir() const {
51835183
return std::get<OmpBeginLoopDirective>(t);
51845184
}
51855185
const std::optional<OmpEndLoopDirective> &EndDir() const {
51865186
return std::get<std::optional<OmpEndLoopDirective>>(t);
51875187
}
5188-
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
5188+
std::tuple<OmpBeginLoopDirective, std::list<NestedConstruct>,
51895189
std::optional<OmpEndLoopDirective>>
51905190
t;
51915191
};

flang/include/flang/Semantics/openmp-directive-sets.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
275275
Directive::OMPD_teams_distribute_parallel_do_simd,
276276
Directive::OMPD_teams_distribute_simd,
277277
Directive::OMPD_teams_loop,
278+
Directive::OMPD_fuse,
278279
Directive::OMPD_tile,
279280
Directive::OMPD_unroll,
280281
};
281282

283+
static const OmpDirectiveSet loopTransformationSet{
284+
Directive::OMPD_tile,
285+
Directive::OMPD_unroll,
286+
Directive::OMPD_fuse,
287+
};
288+
282289
static const OmpDirectiveSet nonPartialVarSet{
283290
Directive::OMPD_allocate,
284291
Directive::OMPD_allocators,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3360,6 +3360,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
33603360
newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
33613361
item);
33623362
break;
3363+
case llvm::omp::Directive::OMPD_fuse:
33633364
case llvm::omp::Directive::OMPD_tile: {
33643365
unsigned version = semaCtx.langOptions().OpenMPVersion;
33653366
if (!semaCtx.langOptions().OpenMPSimd)
@@ -3814,12 +3815,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38143815

38153816
mlir::Location currentLocation = converter.genLocation(beginSpec.source);
38163817

3817-
auto &optLoopCons =
3818-
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
3819-
if (optLoopCons.has_value()) {
3818+
auto &loopConsList =
3819+
std::get<std::list<parser::NestedConstruct>>(loopConstruct.t);
3820+
for (auto &loopCons : loopConsList) {
38203821
if (auto *ompNestedLoopCons{
38213822
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
3822-
&*optLoopCons)}) {
3823+
&loopCons)}) {
38233824
llvm::omp::Directive nestedDirective =
38243825
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
38253826
switch (nestedDirective) {

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -607,13 +607,13 @@ static void processTileSizesFromOpenMPConstruct(
607607
if (!ompCons)
608608
return;
609609
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
610-
const auto &nestedOptional =
611-
std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
612-
assert(nestedOptional.has_value() &&
610+
const auto &loopConsList =
611+
std::get<std::list<parser::NestedConstruct>>(ompLoop->t);
612+
assert(loopConsList.size() == 1 &&
613613
"Expected a DoConstruct or OpenMPLoopConstruct");
614614
const auto *innerConstruct =
615615
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
616-
&(nestedOptional.value()));
616+
&(loopConsList.front()));
617617
if (innerConstruct) {
618618
const auto &innerLoopDirective = innerConstruct->value();
619619
const parser::OmpDirectiveSpecification &innerBeginSpec =

flang/lib/Parser/openmp-parsers.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2019,6 +2019,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
20192019
unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
20202020
unsigned(Directive::OMPD_teams_distribute_simd),
20212021
unsigned(Directive::OMPD_teams_loop),
2022+
unsigned(Directive::OMPD_fuse),
20222023
unsigned(Directive::OMPD_tile),
20232024
unsigned(Directive::OMPD_unroll),
20242025
};

flang/lib/Parser/unparse.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2720,8 +2720,7 @@ class UnparseVisitor {
27202720
}
27212721
void Unparse(const OpenMPLoopConstruct &x) {
27222722
Walk(std::get<OmpBeginLoopDirective>(x.t));
2723-
Walk(std::get<std::optional<std::variant<DoConstruct,
2724-
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
2723+
Walk(std::get<std::list<parser::NestedConstruct>>(x.t));
27252724
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
27262725
}
27272726
void Unparse(const BasedPointer &x) {

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "flang/Parser/parse-tree-visitor.h"
1111
#include "flang/Parser/parse-tree.h"
1212
#include "flang/Semantics/semantics.h"
13+
#include "flang/Semantics/openmp-directive-sets.h"
1314

1415
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
1516
// Constructs more structured which provide explicit scopes for later
@@ -137,30 +138,42 @@ class CanonicalizationOfOmp {
137138
"A DO loop must follow the %s directive"_err_en_US,
138139
parser::ToUpperCaseLetters(dirName.source.ToString()));
139140
};
140-
auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
141+
auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
141142
parser::Messages &messages) {
142143
messages.Say(dirName.source,
143-
"If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
144+
"If a loop construct has been fully unrolled, it cannot then be further transformed"_err_en_US,
144145
parser::ToUpperCaseLetters(dirName.source.ToString()));
145146
};
147+
auto missingEndFuse = [](auto &dir, auto &messages) {
148+
messages.Say(dir.source,
149+
"The %s construct requires the END FUSE directive"_err_en_US,
150+
parser::ToUpperCaseLetters(dir.source.ToString()));
151+
};
152+
153+
bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;
146154

147155
nextIt = it;
148-
while (++nextIt != block.end()) {
156+
nextIt++;
157+
while (nextIt != block.end()) {
149158
// Ignore compiler directives.
150-
if (GetConstructIf<parser::CompilerDirective>(*nextIt))
159+
if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
160+
nextIt++;
151161
continue;
162+
}
152163

153164
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
154165
if (doCons->GetLoopControl()) {
155166
// move DoConstruct
156-
std::get<std::optional<std::variant<parser::DoConstruct,
157-
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
158-
std::move(*doCons);
167+
std::get<std::list<parser::NestedConstruct>>(x.t).push_back(
168+
std::move(*doCons));
159169
nextIt = block.erase(nextIt);
160170
// try to match OmpEndLoopDirective
161171
if (nextIt != block.end()) {
162172
if (auto *endDir{
163173
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
174+
auto &endDirName = endDir->DirName();
175+
if (endDirName.v == llvm::omp::Directive::OMPD_fuse)
176+
endFuseNeeded = false;
164177
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
165178
std::move(*endDir);
166179
nextIt = block.erase(nextIt);
@@ -170,17 +183,37 @@ class CanonicalizationOfOmp {
170183
messages_.Say(beginName.source,
171184
"DO loop after the %s directive must have loop control"_err_en_US,
172185
parser::ToUpperCaseLetters(beginName.source.ToString()));
186+
endFuseNeeded = false;
173187
}
174188
} else if (auto *ompLoopCons{
175189
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
176190
// We should allow UNROLL and TILE constructs to be inserted between an
177191
// OpenMP Loop Construct and the DO loop itself
178192
auto &nestedBeginDirective = ompLoopCons->BeginDir();
179193
auto &nestedBeginName = nestedBeginDirective.DirName();
180-
if ((nestedBeginName.v == llvm::omp::Directive::OMPD_unroll ||
181-
nestedBeginName.v == llvm::omp::Directive::OMPD_tile) &&
182-
!(nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
183-
beginName.v == llvm::omp::Directive::OMPD_tile)) {
194+
if (llvm::omp::loopTransformationSet.test(nestedBeginName.v)) {
195+
if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
196+
llvm::omp::loopTransformationSet.test(beginName.v)) {
197+
// if a loop has been unrolled, the user can not then tile that loop
198+
// as it has been unrolled
199+
const parser::OmpClauseList &unrollClauseList{
200+
nestedBeginDirective.Clauses()};
201+
if (unrollClauseList.v.empty()) {
202+
// if the clause list is empty for an unroll construct, we assume
203+
// the loop is being fully unrolled
204+
transformUnrollError(beginName, messages_);
205+
endFuseNeeded = false;
206+
} else {
207+
// parse the clauses for the unroll directive to find the full
208+
// clause
209+
for (auto &clause : unrollClauseList.v) {
210+
if (clause.Id() == llvm::omp::OMPC_full) {
211+
transformUnrollError(beginName, messages_);
212+
endFuseNeeded = false;
213+
}
214+
}
215+
}
216+
}
184217
// iterate through the remaining block items to find the end directive
185218
// for the unroll/tile directive.
186219
parser::Block::iterator endIt;
@@ -190,6 +223,8 @@ class CanonicalizationOfOmp {
190223
GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
191224
auto &endDirName = endDir->DirName();
192225
if (endDirName.v == beginName.v) {
226+
if (endDirName.v == llvm::omp::Directive::OMPD_fuse)
227+
endFuseNeeded = false;
193228
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
194229
std::move(*endDir);
195230
endIt = block.erase(endIt);
@@ -199,43 +234,30 @@ class CanonicalizationOfOmp {
199234
++endIt;
200235
}
201236
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
202-
auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
203-
ompLoop =
204-
std::optional<parser::NestedConstruct>{parser::NestedConstruct{
205-
common::Indirection{std::move(*ompLoopCons)}}};
237+
auto &loopConsList = std::get<std::list<parser::NestedConstruct>>(x.t);
238+
loopConsList.push_back(parser::NestedConstruct{
239+
common::Indirection{std::move(*ompLoopCons)}});
206240
nextIt = block.erase(nextIt);
207-
} else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
208-
beginName.v == llvm::omp::Directive::OMPD_tile) {
209-
// if a loop has been unrolled, the user can not then tile that loop
210-
// as it has been unrolled
211-
const parser::OmpClauseList &unrollClauseList{
212-
nestedBeginDirective.Clauses()};
213-
if (unrollClauseList.v.empty()) {
214-
// if the clause list is empty for an unroll construct, we assume
215-
// the loop is being fully unrolled
216-
tileUnrollError(beginName, messages_);
217-
} else {
218-
// parse the clauses for the unroll directive to find the full
219-
// clause
220-
for (auto &clause : unrollClauseList.v) {
221-
if (clause.Id() == llvm::omp::OMPC_full) {
222-
tileUnrollError(beginName, messages_);
223-
}
224-
}
225-
}
226241
} else {
227242
messages_.Say(nestedBeginName.source,
228243
"Only Loop Transformation Constructs or Loop Nests can be nested within Loop Constructs"_err_en_US,
229244
parser::ToUpperCaseLetters(nestedBeginName.source.ToString()));
245+
endFuseNeeded = false;
230246
}
231247
} else {
232248
missingDoConstruct(beginName, messages_);
249+
endFuseNeeded = false;
233250
}
251+
if (endFuseNeeded)
252+
continue;
234253
// If we get here, we either found a loop, or issued an error message.
235254
return;
236255
}
237256
if (nextIt == block.end()) {
238-
missingDoConstruct(beginName, messages_);
257+
if (endFuseNeeded)
258+
missingEndFuse(beginName, messages_);
259+
else
260+
missingDoConstruct(beginName, messages_);
239261
}
240262
}
241263

flang/lib/Semantics/check-omp-loop.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,9 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
286286
}
287287
SetLoopInfo(x);
288288

289-
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
290-
if (optLoopCons.has_value()) {
291-
if (const auto &doConstruct{
292-
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
289+
auto &loopConsList = std::get<std::list<parser::NestedConstruct>>(x.t);
290+
for (auto &loopCons : loopConsList) {
291+
if (const auto &doConstruct{std::get_if<parser::DoConstruct>(&loopCons)}) {
293292
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
294293
CheckNoBranching(doBlock, beginName.v, beginName.source);
295294
}
@@ -315,10 +314,10 @@ const parser::Name OmpStructureChecker::GetLoopIndex(
315314
}
316315

317316
void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
318-
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
319-
if (optLoopCons.has_value()) {
317+
auto &loopConsList = std::get<std::list<parser::NestedConstruct>>(x.t);
318+
if (loopConsList.size() == 1) {
320319
if (const auto &loopConstruct{
321-
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
320+
std::get_if<parser::DoConstruct>(&loopConsList.front())}) {
322321
const parser::DoConstruct *loop{&*loopConstruct};
323322
if (loop && loop->IsDoNormal()) {
324323
const parser::Name &itrVal{GetLoopIndex(loop)};
@@ -330,10 +329,10 @@ void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
330329

331330
void OmpStructureChecker::CheckLoopItrVariableIsInt(
332331
const parser::OpenMPLoopConstruct &x) {
333-
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
334-
if (optLoopCons.has_value()) {
332+
auto &loopConsList = std::get<std::list<parser::NestedConstruct>>(x.t);
333+
for (auto &loopCons : loopConsList) {
335334
if (const auto &loopConstruct{
336-
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
335+
std::get_if<parser::DoConstruct>(&loopCons)}) {
337336

338337
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
339338
if (loop->IsDoNormal()) {
@@ -418,19 +417,20 @@ void OmpStructureChecker::CheckDistLinear(
418417

419418
// Match the loop index variables with the collected symbols from linear
420419
// clauses.
421-
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
422-
if (optLoopCons.has_value()) {
420+
auto &loopConsList = std::get<std::list<parser::NestedConstruct>>(x.t);
421+
for (auto &loopCons : loopConsList) {
422+
std::int64_t collapseVal_ = collapseVal;
423423
if (const auto &loopConstruct{
424-
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
424+
std::get_if<parser::DoConstruct>(&loopCons)}) {
425425
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
426426
if (loop->IsDoNormal()) {
427427
const parser::Name &itrVal{GetLoopIndex(loop)};
428428
if (itrVal.symbol) {
429429
// Remove the symbol from the collected set
430430
indexVars.erase(&itrVal.symbol->GetUltimate());
431431
}
432-
collapseVal--;
433-
if (collapseVal == 0) {
432+
collapseVal_--;
433+
if (collapseVal_ == 0) {
434434
break;
435435
}
436436
}

0 commit comments

Comments
 (0)