Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -5260,15 +5260,15 @@ using NestedConstruct =
struct OpenMPLoopConstruct {
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
: t({std::move(a), std::nullopt, std::nullopt}) {}
: t({std::move(a), std::list<NestedConstruct>(), std::nullopt}) {}

const OmpBeginLoopDirective &BeginDir() const {
return std::get<OmpBeginLoopDirective>(t);
}
const std::optional<OmpEndLoopDirective> &EndDir() const {
return std::get<std::optional<OmpEndLoopDirective>>(t);
}
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
std::tuple<OmpBeginLoopDirective, std::list<NestedConstruct>,
std::optional<OmpEndLoopDirective>>
t;
};
Expand Down
7 changes: 7 additions & 0 deletions flang/include/flang/Semantics/openmp-directive-sets.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
Directive::OMPD_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_simd,
Directive::OMPD_teams_loop,
Directive::OMPD_fuse,
Directive::OMPD_tile,
Directive::OMPD_unroll,
};

static const OmpDirectiveSet loopTransformationSet{
Directive::OMPD_tile,
Directive::OMPD_unroll,
Directive::OMPD_fuse,
};

static const OmpDirectiveSet nonPartialVarSet{
Directive::OMPD_allocate,
Directive::OMPD_allocators,
Expand Down
15 changes: 11 additions & 4 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3471,6 +3471,13 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_tile:
genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_fuse: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
if (!semaCtx.langOptions().OpenMPSimd)
TODO(loc, "Unhandled loop directive (" +
llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
break;
}
case llvm::omp::Directive::OMPD_unroll:
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
Expand Down Expand Up @@ -3918,12 +3925,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,

mlir::Location currentLocation = converter.genLocation(beginSpec.source);

auto &optLoopCons =
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
if (optLoopCons.has_value()) {
auto &loopConsList =
std::get<std::list<parser::NestedConstruct>>(loopConstruct.t);
for (auto &loopCons : loopConsList) {
if (auto *ompNestedLoopCons{
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&*optLoopCons)}) {
&loopCons)}) {
llvm::omp::Directive nestedDirective =
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,13 +631,13 @@ static void processTileSizesFromOpenMPConstruct(
if (!ompCons)
return;
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
const auto &nestedOptional =
std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
assert(nestedOptional.has_value() &&
const auto &loopConsList =
std::get<std::list<parser::NestedConstruct>>(ompLoop->t);
assert(loopConsList.size() == 1 &&
"Expected a DoConstruct or OpenMPLoopConstruct");
const auto *innerConstruct =
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&(nestedOptional.value()));
&(loopConsList.front()));
if (innerConstruct) {
const auto &innerLoopDirective = innerConstruct->value();
const parser::OmpDirectiveSpecification &innerBeginSpec =
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Parser/openmp-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2037,6 +2037,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
unsigned(Directive::OMPD_teams_distribute_simd),
unsigned(Directive::OMPD_teams_loop),
unsigned(Directive::OMPD_fuse),
unsigned(Directive::OMPD_tile),
unsigned(Directive::OMPD_unroll),
};
Expand Down
3 changes: 1 addition & 2 deletions flang/lib/Parser/unparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2725,8 +2725,7 @@ class UnparseVisitor {
}
void Unparse(const OpenMPLoopConstruct &x) {
Walk(std::get<OmpBeginLoopDirective>(x.t));
Walk(std::get<std::optional<std::variant<DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
Walk(std::get<std::list<parser::NestedConstruct>>(x.t));
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
}
void Unparse(const BasedPointer &x) {
Expand Down
130 changes: 78 additions & 52 deletions flang/lib/Semantics/canonicalize-omp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "canonicalize-omp.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/semantics.h"

// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
Expand Down Expand Up @@ -137,33 +138,45 @@ class CanonicalizationOfOmp {
"A DO loop must follow the %s directive"_err_en_US,
parser::ToUpperCaseLetters(dirName.source.ToString()));
};
auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
parser::Messages &messages) {
auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
parser::Messages &messages) {
messages.Say(dirName.source,
"If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
"If a loop construct has been fully unrolled, it cannot then be further transformed"_err_en_US,
parser::ToUpperCaseLetters(dirName.source.ToString()));
};
auto missingEndFuse = [](auto &dir, auto &messages) {
messages.Say(dir.source,
"The %s construct requires the END FUSE directive"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
};

bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;

nextIt = it;
while (++nextIt != block.end()) {
nextIt++;
while (nextIt != block.end()) {
// Ignore compiler directives.
if (GetConstructIf<parser::CompilerDirective>(*nextIt))
if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
nextIt++;
continue;
}

if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
if (doCons->GetLoopControl()) {
// move DoConstruct
std::get<std::optional<std::variant<parser::DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
std::move(*doCons);
std::get<std::list<parser::NestedConstruct>>(x.t).push_back(
std::move(*doCons));
nextIt = block.erase(nextIt);
// try to match OmpEndLoopDirective
if (nextIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
auto &endDirName = endDir->DirName();
if (endDirName.v != llvm::omp::Directive::OMPD_fuse) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
}
}
}
} else {
Expand All @@ -173,53 +186,48 @@ class CanonicalizationOfOmp {
}
} else if (auto *ompLoopCons{
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
// We should allow UNROLL and TILE constructs to be inserted between an
// OpenMP Loop Construct and the DO loop itself
// We should allow loop transformation constructs to be inserted between
// an OpenMP Loop Construct and the DO loop itself
auto &nestedBeginDirective = ompLoopCons->BeginDir();
auto &nestedBeginName = nestedBeginDirective.DirName();
if ((nestedBeginName.v == llvm::omp::Directive::OMPD_unroll ||
nestedBeginName.v == llvm::omp::Directive::OMPD_tile) &&
!(nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
beginName.v == llvm::omp::Directive::OMPD_tile)) {
// iterate through the remaining block items to find the end directive
// for the unroll/tile directive.
parser::Block::iterator endIt;
endIt = nextIt;
while (endIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
auto &endDirName = endDir->DirName();
if (endDirName.v == beginName.v) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
endIt = block.erase(endIt);
continue;
if (llvm::omp::loopTransformationSet.test(nestedBeginName.v)) {
if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
llvm::omp::loopTransformationSet.test(beginName.v)) {
// if a loop has been unrolled, the user can not then transform that
// loop as it has been unrolled
const parser::OmpClauseList &unrollClauseList{
nestedBeginDirective.Clauses()};
if (unrollClauseList.v.empty()) {
// if the clause list is empty for an unroll construct, we assume
// the loop is being fully unrolled
transformUnrollError(beginName, messages_);
} else {
// parse the clauses for the unroll directive to find the full
// clause
for (auto &clause : unrollClauseList.v) {
if (clause.Id() == llvm::omp::OMPC_full) {
transformUnrollError(beginName, messages_);
}
}
}
++endIt;
}
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
ompLoop =
std::optional<parser::NestedConstruct>{parser::NestedConstruct{
common::Indirection{std::move(*ompLoopCons)}}};
auto &loopConsList =
std::get<std::list<parser::NestedConstruct>>(x.t);
loopConsList.push_back(parser::NestedConstruct{
common::Indirection{std::move(*ompLoopCons)}});
nextIt = block.erase(nextIt);
} else if (nestedBeginName.v == llvm::omp::Directive::OMPD_unroll &&
beginName.v == llvm::omp::Directive::OMPD_tile) {
// if a loop has been unrolled, the user can not then tile that loop
// as it has been unrolled
const parser::OmpClauseList &unrollClauseList{
nestedBeginDirective.Clauses()};
if (unrollClauseList.v.empty()) {
// if the clause list is empty for an unroll construct, we assume
// the loop is being fully unrolled
tileUnrollError(beginName, messages_);
} else {
// parse the clauses for the unroll directive to find the full
// clause
for (auto &clause : unrollClauseList.v) {
if (clause.Id() == llvm::omp::OMPC_full) {
tileUnrollError(beginName, messages_);
// check the following block item to find the end directive
// for the loop transform directive.
if (nextIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
auto &endDirName = endDir->DirName();
if (endDirName.v == beginName.v &&
endDirName.v != llvm::omp::Directive::OMPD_fuse) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
}
}
}
Expand All @@ -231,11 +239,29 @@ class CanonicalizationOfOmp {
} else {
missingDoConstruct(beginName, messages_);
}

if (endFuseNeeded && nextIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
auto &endDirName = endDir->DirName();
if (endDirName.v == llvm::omp::Directive::OMPD_fuse) {
endFuseNeeded = false;
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
}
}
}
if (endFuseNeeded)
continue;
// If we get here, we either found a loop, or issued an error message.
return;
}
if (nextIt == block.end()) {
missingDoConstruct(beginName, messages_);
if (endFuseNeeded)
missingEndFuse(beginName, messages_);
else
missingDoConstruct(beginName, messages_);
}
}

Expand Down
Loading