Skip to content

Commit e70e9ec

Browse files
authored
[flang][OpenMP] Store Block in OpenMPLoopConstruct, add access functions (llvm#168078)
Instead of storing a variant with specific types, store parser::Block as the body. Add two access functions to make the traversal of the nest simpler. This will allow storing loop-nest sequences in the future.
1 parent fd1bdfd commit e70e9ec

19 files changed

+283
-337
lines changed

flang/include/flang/Parser/parse-tree.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5345,21 +5345,21 @@ struct OmpEndLoopDirective : public OmpEndDirective {
53455345
};
53465346

53475347
// OpenMP directives enclosing do loop
5348-
using NestedConstruct =
5349-
std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>;
53505348
struct OpenMPLoopConstruct {
53515349
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
53525350
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
5353-
: t({std::move(a), std::nullopt, std::nullopt}) {}
5351+
: t({std::move(a), Block{}, std::nullopt}) {}
53545352

53555353
const OmpBeginLoopDirective &BeginDir() const {
53565354
return std::get<OmpBeginLoopDirective>(t);
53575355
}
53585356
const std::optional<OmpEndLoopDirective> &EndDir() const {
53595357
return std::get<std::optional<OmpEndLoopDirective>>(t);
53605358
}
5361-
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
5362-
std::optional<OmpEndLoopDirective>>
5359+
const DoConstruct *GetNestedLoop() const;
5360+
const OpenMPLoopConstruct *GetNestedConstruct() const;
5361+
5362+
std::tuple<OmpBeginLoopDirective, Block, std::optional<OmpEndLoopDirective>>
53635363
t;
53645364
};
53655365

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3962,27 +3962,22 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
39623962

39633963
mlir::Location currentLocation = converter.genLocation(beginSpec.source);
39643964

3965-
auto &optLoopCons =
3966-
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
3967-
if (optLoopCons.has_value()) {
3968-
if (auto *ompNestedLoopCons{
3969-
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
3970-
&*optLoopCons)}) {
3971-
llvm::omp::Directive nestedDirective =
3972-
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
3973-
switch (nestedDirective) {
3974-
case llvm::omp::Directive::OMPD_tile:
3975-
// Skip OMPD_tile since the tile sizes will be retrieved when
3976-
// generating the omp.loop_nest op.
3977-
break;
3978-
default: {
3979-
unsigned version = semaCtx.langOptions().OpenMPVersion;
3980-
TODO(currentLocation,
3981-
"Applying a loop-associated on the loop generated by the " +
3982-
llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
3983-
" construct");
3984-
}
3985-
}
3965+
if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
3966+
loopConstruct.GetNestedConstruct()) {
3967+
llvm::omp::Directive nestedDirective =
3968+
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
3969+
switch (nestedDirective) {
3970+
case llvm::omp::Directive::OMPD_tile:
3971+
// Skip OMPD_tile since the tile sizes will be retrieved when
3972+
// generating the omp.loop_nest op.
3973+
break;
3974+
default: {
3975+
unsigned version = semaCtx.langOptions().OpenMPVersion;
3976+
TODO(currentLocation,
3977+
"Applying a loop-associated on the loop generated by the " +
3978+
llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
3979+
" construct");
3980+
}
39863981
}
39873982
}
39883983

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -779,17 +779,9 @@ static void processTileSizesFromOpenMPConstruct(
779779
if (!ompCons)
780780
return;
781781
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
782-
const auto &nestedOptional =
783-
std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
784-
assert(nestedOptional.has_value() &&
785-
"Expected a DoConstruct or OpenMPLoopConstruct");
786-
const auto *innerConstruct =
787-
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
788-
&(nestedOptional.value()));
789-
if (innerConstruct) {
790-
const auto &innerLoopDirective = innerConstruct->value();
782+
if (auto *innerConstruct = ompLoop->GetNestedConstruct()) {
791783
const parser::OmpDirectiveSpecification &innerBeginSpec =
792-
innerLoopDirective.BeginDir();
784+
innerConstruct->BeginDir();
793785
if (innerBeginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
794786
// Get the size values from parse tree and convert to a vector.
795787
for (const auto &clause : innerBeginSpec.Clauses().v) {

flang/lib/Parser/parse-tree.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "flang/Parser/parse-tree.h"
10+
1011
#include "flang/Common/idioms.h"
1112
#include "flang/Common/indirection.h"
13+
#include "flang/Parser/openmp-utils.h"
1214
#include "flang/Parser/tools.h"
1315
#include "flang/Parser/user-state.h"
1416
#include "llvm/ADT/ArrayRef.h"
@@ -432,6 +434,20 @@ const OmpClauseList &OmpDirectiveSpecification::Clauses() const {
432434
return empty;
433435
}
434436

437+
const DoConstruct *OpenMPLoopConstruct::GetNestedLoop() const {
438+
if (auto &body{std::get<Block>(t)}; !body.empty()) {
439+
return Unwrap<DoConstruct>(body.front());
440+
}
441+
return nullptr;
442+
}
443+
444+
const OpenMPLoopConstruct *OpenMPLoopConstruct::GetNestedConstruct() const {
445+
if (auto &body{std::get<Block>(t)}; !body.empty()) {
446+
return Unwrap<OpenMPLoopConstruct>(body.front());
447+
}
448+
return nullptr;
449+
}
450+
435451
static bool InitCharBlocksFromStrings(llvm::MutableArrayRef<CharBlock> blocks,
436452
llvm::ArrayRef<std::string> strings) {
437453
for (auto [i, n] : llvm::enumerate(strings)) {

flang/lib/Parser/unparse.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2706,12 +2706,6 @@ class UnparseVisitor {
27062706
Put("\n");
27072707
EndOpenMP();
27082708
}
2709-
void Unparse(const OpenMPLoopConstruct &x) {
2710-
Walk(std::get<OmpBeginLoopDirective>(x.t));
2711-
Walk(std::get<std::optional<std::variant<DoConstruct,
2712-
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
2713-
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
2714-
}
27152709
void Unparse(const BasedPointer &x) {
27162710
Put('('), Walk(std::get<0>(x.t)), Put(","), Walk(std::get<1>(x.t));
27172711
Walk("(", std::get<std::optional<ArraySpec>>(x.t), ")"), Put(')');

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ class CanonicalizationOfOmp {
143143
parser::ToUpperCaseLetters(dirName.source.ToString()));
144144
};
145145

146+
auto &body{std::get<parser::Block>(x.t)};
147+
146148
nextIt = it;
147149
while (++nextIt != block.end()) {
148150
// Ignore compiler directives.
@@ -152,9 +154,7 @@ class CanonicalizationOfOmp {
152154
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
153155
if (doCons->GetLoopControl()) {
154156
// move DoConstruct
155-
std::get<std::optional<std::variant<parser::DoConstruct,
156-
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
157-
std::move(*doCons);
157+
body.push_back(std::move(*nextIt));
158158
nextIt = block.erase(nextIt);
159159
// try to match OmpEndLoopDirective
160160
if (nextIt != block.end()) {
@@ -198,10 +198,7 @@ class CanonicalizationOfOmp {
198198
++endIt;
199199
}
200200
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
201-
auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
202-
ompLoop =
203-
std::optional<parser::NestedConstruct>{parser::NestedConstruct{
204-
common::Indirection{std::move(*ompLoopCons)}}};
201+
body.push_back(std::move(*nextIt));
205202
nextIt = block.erase(nextIt);
206203
} else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
207204
beginName.v == llvm::omp::Directive::OMPD_tile) {

flang/lib/Semantics/check-omp-loop.cpp

Lines changed: 37 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -285,13 +285,9 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
285285
}
286286
SetLoopInfo(x);
287287

288-
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
289-
if (optLoopCons.has_value()) {
290-
if (const auto &doConstruct{
291-
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
292-
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
293-
CheckNoBranching(doBlock, beginName.v, beginName.source);
294-
}
288+
if (const auto *doConstruct{x.GetNestedLoop()}) {
289+
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
290+
CheckNoBranching(doBlock, beginName.v, beginName.source);
295291
}
296292
CheckLoopItrVariableIsInt(x);
297293
CheckAssociatedLoopConstraints(x);
@@ -314,46 +310,34 @@ const parser::Name OmpStructureChecker::GetLoopIndex(
314310
}
315311

316312
void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
317-
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
318-
if (optLoopCons.has_value()) {
319-
if (const auto &loopConstruct{
320-
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
321-
const parser::DoConstruct *loop{&*loopConstruct};
322-
if (loop && loop->IsDoNormal()) {
323-
const parser::Name &itrVal{GetLoopIndex(loop)};
324-
SetLoopIv(itrVal.symbol);
325-
}
313+
if (const auto *loop{x.GetNestedLoop()}) {
314+
if (loop->IsDoNormal()) {
315+
const parser::Name &itrVal{GetLoopIndex(loop)};
316+
SetLoopIv(itrVal.symbol);
326317
}
327318
}
328319
}
329320

330321
void OmpStructureChecker::CheckLoopItrVariableIsInt(
331322
const parser::OpenMPLoopConstruct &x) {
332-
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
333-
if (optLoopCons.has_value()) {
334-
if (const auto &loopConstruct{
335-
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
336-
337-
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
338-
if (loop->IsDoNormal()) {
339-
const parser::Name &itrVal{GetLoopIndex(loop)};
340-
if (itrVal.symbol) {
341-
const auto *type{itrVal.symbol->GetType()};
342-
if (!type->IsNumeric(TypeCategory::Integer)) {
343-
context_.Say(itrVal.source,
344-
"The DO loop iteration"
345-
" variable must be of the type integer."_err_en_US,
346-
itrVal.ToString());
347-
}
348-
}
323+
for (const parser::DoConstruct *loop{x.GetNestedLoop()}; loop;) {
324+
if (loop->IsDoNormal()) {
325+
const parser::Name &itrVal{GetLoopIndex(loop)};
326+
if (itrVal.symbol) {
327+
const auto *type{itrVal.symbol->GetType()};
328+
if (!type->IsNumeric(TypeCategory::Integer)) {
329+
context_.Say(itrVal.source,
330+
"The DO loop iteration"
331+
" variable must be of the type integer."_err_en_US,
332+
itrVal.ToString());
349333
}
350-
// Get the next DoConstruct if block is not empty.
351-
const auto &block{std::get<parser::Block>(loop->t)};
352-
const auto it{block.begin()};
353-
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
354-
: nullptr;
355334
}
356335
}
336+
// Get the next DoConstruct if block is not empty.
337+
const auto &block{std::get<parser::Block>(loop->t)};
338+
const auto it{block.begin()};
339+
loop =
340+
it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it) : nullptr;
357341
}
358342
}
359343

@@ -417,29 +401,23 @@ void OmpStructureChecker::CheckDistLinear(
417401

418402
// Match the loop index variables with the collected symbols from linear
419403
// clauses.
420-
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
421-
if (optLoopCons.has_value()) {
422-
if (const auto &loopConstruct{
423-
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
424-
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
425-
if (loop->IsDoNormal()) {
426-
const parser::Name &itrVal{GetLoopIndex(loop)};
427-
if (itrVal.symbol) {
428-
// Remove the symbol from the collected set
429-
indexVars.erase(&itrVal.symbol->GetUltimate());
430-
}
431-
collapseVal--;
432-
if (collapseVal == 0) {
433-
break;
434-
}
435-
}
436-
// Get the next DoConstruct if block is not empty.
437-
const auto &block{std::get<parser::Block>(loop->t)};
438-
const auto it{block.begin()};
439-
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
440-
: nullptr;
404+
for (const parser::DoConstruct *loop{x.GetNestedLoop()}; loop;) {
405+
if (loop->IsDoNormal()) {
406+
const parser::Name &itrVal{GetLoopIndex(loop)};
407+
if (itrVal.symbol) {
408+
// Remove the symbol from the collected set
409+
indexVars.erase(&itrVal.symbol->GetUltimate());
410+
}
411+
collapseVal--;
412+
if (collapseVal == 0) {
413+
break;
441414
}
442415
}
416+
// Get the next DoConstruct if block is not empty.
417+
const auto &block{std::get<parser::Block>(loop->t)};
418+
const auto it{block.begin()};
419+
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
420+
: nullptr;
443421
}
444422

445423
// Show error for the remaining variables

0 commit comments

Comments
 (0)