Skip to content

Commit 0a59fb0

Browse files
[flang][openmp]Add UserReductionDetails and use in DECLARE REDUCTION
This adds another puzzle piece for the support of OpenMP DECLARE REDUCTION functionality. This adds support for operators with derived types, as well as declaring multiple different types with the same name or operator. A new detail class for UserReductionDetials is introduced to hold the list of types supported for a given reduction declaration. Tests for parsing and symbol generation added. Declare reduction is still not supported to lowering, it will generate a "Not yet implemented" fatal error.
1 parent b3b0070 commit 0a59fb0

File tree

12 files changed

+616
-20
lines changed

12 files changed

+616
-20
lines changed

flang/include/flang/Semantics/symbol.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,14 +701,33 @@ class GenericDetails {
701701
};
702702
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const GenericDetails &);
703703

704+
class UserReductionDetails : public WithBindName {
705+
public:
706+
using TypeVector = std::vector<const DeclTypeSpec *>;
707+
UserReductionDetails() = default;
708+
709+
void AddType(const DeclTypeSpec *type) { typeList_.push_back(type); }
710+
const TypeVector &GetTypeList() const { return typeList_; }
711+
712+
bool SupportsType(const DeclTypeSpec *type) const {
713+
for (auto t : typeList_)
714+
if (t == type)
715+
return true;
716+
return false;
717+
}
718+
719+
private:
720+
TypeVector typeList_;
721+
};
722+
704723
class UnknownDetails {};
705724

706725
using Details = std::variant<UnknownDetails, MainProgramDetails, ModuleDetails,
707726
SubprogramDetails, SubprogramNameDetails, EntityDetails,
708727
ObjectEntityDetails, ProcEntityDetails, AssocEntityDetails,
709728
DerivedTypeDetails, UseDetails, UseErrorDetails, HostAssocDetails,
710729
GenericDetails, ProcBindingDetails, NamelistDetails, CommonBlockDetails,
711-
TypeParamDetails, MiscDetails>;
730+
TypeParamDetails, MiscDetails, UserReductionDetails>;
712731
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Details &);
713732
std::string DetailsToString(const Details &);
714733

flang/lib/Semantics/check-omp-structure.cpp

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "check-omp-structure.h"
1010
#include "definable.h"
11+
#include "resolve-names-utils.h"
1112
#include "flang/Evaluate/check-expression.h"
1213
#include "flang/Evaluate/expression.h"
1314
#include "flang/Evaluate/type.h"
@@ -3361,8 +3362,8 @@ bool OmpStructureChecker::CheckReductionOperator(
33613362
valid =
33623363
llvm::is_contained({"max", "min", "iand", "ior", "ieor"}, realName);
33633364
if (!valid) {
3364-
auto *misc{name->symbol->detailsIf<MiscDetails>()};
3365-
valid = misc && misc->kind() == MiscDetails::Kind::ConstructName;
3365+
auto *reductionDetails{name->symbol->detailsIf<UserReductionDetails>()};
3366+
valid = reductionDetails != nullptr;
33663367
}
33673368
}
33683369
if (!valid) {
@@ -3444,7 +3445,8 @@ void OmpStructureChecker::CheckReductionObjects(
34443445
}
34453446

34463447
static bool IsReductionAllowedForType(
3447-
const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type) {
3448+
const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type,
3449+
const Scope &scope) {
34483450
auto isLogical{[](const DeclTypeSpec &type) -> bool {
34493451
return type.category() == DeclTypeSpec::Logical;
34503452
}};
@@ -3464,9 +3466,11 @@ static bool IsReductionAllowedForType(
34643466
case parser::DefinedOperator::IntrinsicOperator::Multiply:
34653467
case parser::DefinedOperator::IntrinsicOperator::Add:
34663468
case parser::DefinedOperator::IntrinsicOperator::Subtract:
3467-
return type.IsNumeric(TypeCategory::Integer) ||
3469+
if (type.IsNumeric(TypeCategory::Integer) ||
34683470
type.IsNumeric(TypeCategory::Real) ||
3469-
type.IsNumeric(TypeCategory::Complex);
3471+
type.IsNumeric(TypeCategory::Complex))
3472+
return true;
3473+
break;
34703474

34713475
case parser::DefinedOperator::IntrinsicOperator::AND:
34723476
case parser::DefinedOperator::IntrinsicOperator::OR:
@@ -3479,8 +3483,18 @@ static bool IsReductionAllowedForType(
34793483
DIE("This should have been caught in CheckIntrinsicOperator");
34803484
return false;
34813485
}
3486+
parser::CharBlock name{MakeNameFromOperator(*intrinsicOp)};
3487+
Symbol *symbol{scope.FindSymbol(name)};
3488+
if (symbol) {
3489+
const auto *reductionDetails{symbol->detailsIf<UserReductionDetails>()};
3490+
assert(reductionDetails && "Expected to find reductiondetails");
3491+
3492+
return reductionDetails->SupportsType(&type);
3493+
}
3494+
return false;
34823495
}
3483-
return true;
3496+
assert(0 && "Intrinsic Operator not found - parsing gone wrong?");
3497+
return false; // Reject everything else.
34843498
}};
34853499

34863500
auto checkDesignator{[&](const parser::ProcedureDesignator &procD) {
@@ -3493,18 +3507,42 @@ static bool IsReductionAllowedForType(
34933507
// IAND: arguments must be integers: F2023 16.9.100
34943508
// IEOR: arguments must be integers: F2023 16.9.106
34953509
// IOR: arguments must be integers: F2023 16.9.111
3496-
return type.IsNumeric(TypeCategory::Integer);
3510+
if (type.IsNumeric(TypeCategory::Integer)) {
3511+
return true;
3512+
}
34973513
} else if (realName == "max" || realName == "min") {
34983514
// MAX: arguments must be integer, real, or character:
34993515
// F2023 16.9.135
35003516
// MIN: arguments must be integer, real, or character:
35013517
// F2023 16.9.141
3502-
return type.IsNumeric(TypeCategory::Integer) ||
3503-
type.IsNumeric(TypeCategory::Real) || isCharacter(type);
3518+
if (type.IsNumeric(TypeCategory::Integer) ||
3519+
type.IsNumeric(TypeCategory::Real) || isCharacter(type)) {
3520+
return true;
3521+
}
35043522
}
3523+
3524+
// If we get here, it may be a user declared reduction, so check
3525+
// if the symbol has UserReductionDetails, and if so, the type is
3526+
// supported.
3527+
if (const auto *reductionDetails{
3528+
name->symbol->detailsIf<UserReductionDetails>()}) {
3529+
return reductionDetails->SupportsType(&type);
3530+
}
3531+
3532+
// We also need to check for mangled names (max, min, iand, ieor and ior)
3533+
// and then check if the type is there.
3534+
parser::CharBlock mangledName = MangleSpecialFunctions(name->source);
3535+
if (const auto &symbol{scope.FindSymbol(mangledName)}) {
3536+
if (const auto *reductionDetails{
3537+
symbol->detailsIf<UserReductionDetails>()}) {
3538+
return reductionDetails->SupportsType(&type);
3539+
}
3540+
}
3541+
// Everything else is "not matching type".
3542+
return false;
35053543
}
3506-
// TODO: user defined reduction operators. Just allow everything for now.
3507-
return true;
3544+
assert(0 && "name and name->symbol should be set here...");
3545+
return false;
35083546
}};
35093547

35103548
return common::visit(
@@ -3519,7 +3557,8 @@ void OmpStructureChecker::CheckReductionObjectTypes(
35193557

35203558
for (auto &[symbol, source] : symbols) {
35213559
if (auto *type{symbol->GetType()}) {
3522-
if (!IsReductionAllowedForType(ident, *type)) {
3560+
const auto &scope{context_.FindScope(symbol->name())};
3561+
if (!IsReductionAllowedForType(ident, *type, scope)) {
35233562
context_.Say(source,
35243563
"The type of '%s' is incompatible with the reduction operator."_err_en_US,
35253564
symbol->name());

flang/lib/Semantics/resolve-names-utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,9 @@ struct SymbolAndTypeMappings;
146146
void MapSubprogramToNewSymbols(const Symbol &oldSymbol, Symbol &newSymbol,
147147
Scope &newScope, SymbolAndTypeMappings * = nullptr);
148148

149+
parser::CharBlock MakeNameFromOperator(
150+
const parser::DefinedOperator::IntrinsicOperator &op);
151+
parser::CharBlock MangleSpecialFunctions(const parser::CharBlock name);
152+
149153
} // namespace Fortran::semantics
150154
#endif // FORTRAN_SEMANTICS_RESOLVE_NAMES_H_

flang/lib/Semantics/resolve-names.cpp

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,15 +1748,75 @@ void OmpVisitor::ProcessMapperSpecifier(const parser::OmpMapperSpecifier &spec,
17481748
PopScope();
17491749
}
17501750

1751+
parser::CharBlock MakeNameFromOperator(
1752+
const parser::DefinedOperator::IntrinsicOperator &op) {
1753+
switch (op) {
1754+
case parser::DefinedOperator::IntrinsicOperator::Multiply:
1755+
return parser::CharBlock{"op.*", 4};
1756+
case parser::DefinedOperator::IntrinsicOperator::Add:
1757+
return parser::CharBlock{"op.+", 4};
1758+
case parser::DefinedOperator::IntrinsicOperator::Subtract:
1759+
return parser::CharBlock{"op.-", 4};
1760+
1761+
case parser::DefinedOperator::IntrinsicOperator::AND:
1762+
return parser::CharBlock{"op.AND", 6};
1763+
case parser::DefinedOperator::IntrinsicOperator::OR:
1764+
return parser::CharBlock{"op.OR", 6};
1765+
case parser::DefinedOperator::IntrinsicOperator::EQV:
1766+
return parser::CharBlock{"op.EQV", 7};
1767+
case parser::DefinedOperator::IntrinsicOperator::NEQV:
1768+
return parser::CharBlock{"op.NEQV", 8};
1769+
1770+
default:
1771+
assert(0 && "Unsupported operator...");
1772+
return parser::CharBlock{"op.?", 4};
1773+
}
1774+
}
1775+
1776+
parser::CharBlock MangleSpecialFunctions(const parser::CharBlock name) {
1777+
if (name == "max") {
1778+
return parser::CharBlock{"op.max", 6};
1779+
}
1780+
if (name == "min") {
1781+
return parser::CharBlock{"op.min", 6};
1782+
}
1783+
if (name == "iand") {
1784+
return parser::CharBlock{"op.iand", 7};
1785+
}
1786+
if (name == "ior") {
1787+
return parser::CharBlock{"op.ior", 6};
1788+
}
1789+
if (name == "ieor") {
1790+
return parser::CharBlock{"op.ieor", 7};
1791+
}
1792+
// All other names: return as is.
1793+
return name;
1794+
}
1795+
17511796
void OmpVisitor::ProcessReductionSpecifier(
17521797
const parser::OmpReductionSpecifier &spec,
17531798
const std::optional<parser::OmpClauseList> &clauses) {
1799+
const parser::Name *name{nullptr};
1800+
parser::Name mangledName{};
1801+
UserReductionDetails reductionDetailsTemp{};
17541802
const auto &id{std::get<parser::OmpReductionIdentifier>(spec.t)};
17551803
if (auto procDes{std::get_if<parser::ProcedureDesignator>(&id.u)}) {
1756-
if (auto *name{std::get_if<parser::Name>(&procDes->u)}) {
1757-
name->symbol =
1758-
&MakeSymbol(*name, MiscDetails{MiscDetails::Kind::ConstructName});
1804+
name = std::get_if<parser::Name>(&procDes->u);
1805+
if (name) {
1806+
mangledName.source = MangleSpecialFunctions(name->source);
17591807
}
1808+
} else {
1809+
const auto &defOp{std::get<parser::DefinedOperator>(id.u)};
1810+
mangledName.source = MakeNameFromOperator(
1811+
std::get<parser::DefinedOperator::IntrinsicOperator>(defOp.u));
1812+
name = &mangledName;
1813+
}
1814+
1815+
UserReductionDetails *reductionDetails{&reductionDetailsTemp};
1816+
Symbol *symbol{name ? name->symbol : nullptr};
1817+
symbol = FindSymbol(mangledName);
1818+
if (symbol) {
1819+
reductionDetails = symbol->detailsIf<UserReductionDetails>();
17601820
}
17611821

17621822
auto &typeList{std::get<parser::OmpTypeNameList>(spec.t)};
@@ -1788,6 +1848,10 @@ void OmpVisitor::ProcessReductionSpecifier(
17881848
const DeclTypeSpec *typeSpec{GetDeclTypeSpec()};
17891849
assert(typeSpec && "We should have a type here");
17901850

1851+
if (reductionDetails) {
1852+
reductionDetails->AddType(typeSpec);
1853+
}
1854+
17911855
for (auto &nm : ompVarNames) {
17921856
ObjectEntityDetails details{};
17931857
details.set_type(*typeSpec);
@@ -1798,6 +1862,13 @@ void OmpVisitor::ProcessReductionSpecifier(
17981862
Walk(clauses);
17991863
PopScope();
18001864
}
1865+
1866+
if (name) {
1867+
if (!symbol) {
1868+
symbol = &MakeSymbol(mangledName, Attrs{}, std::move(*reductionDetails));
1869+
}
1870+
name->symbol = symbol;
1871+
}
18011872
}
18021873

18031874
bool OmpVisitor::Pre(const parser::OmpDirectiveSpecification &x) {

flang/lib/Semantics/symbol.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ void GenericDetails::CopyFrom(const GenericDetails &from) {
246246
// This is primarily for debugging.
247247
std::string DetailsToString(const Details &details) {
248248
return common::visit(
249-
common::visitors{
249+
common::visitors{//
250250
[](const UnknownDetails &) { return "Unknown"; },
251251
[](const MainProgramDetails &) { return "MainProgram"; },
252252
[](const ModuleDetails &) { return "Module"; },
@@ -266,7 +266,7 @@ std::string DetailsToString(const Details &details) {
266266
[](const TypeParamDetails &) { return "TypeParam"; },
267267
[](const MiscDetails &) { return "Misc"; },
268268
[](const AssocEntityDetails &) { return "AssocEntity"; },
269-
},
269+
[](const UserReductionDetails &) { return "UserReductionDetails"; }},
270270
details);
271271
}
272272

@@ -300,6 +300,9 @@ bool Symbol::CanReplaceDetails(const Details &details) const {
300300
[&](const HostAssocDetails &) {
301301
return this->has<HostAssocDetails>();
302302
},
303+
[&](const UserReductionDetails &) {
304+
return this->has<UserReductionDetails>();
305+
},
303306
[](const auto &) { return false; },
304307
},
305308
details);
@@ -598,6 +601,11 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) {
598601
[&](const MiscDetails &x) {
599602
os << ' ' << MiscDetails::EnumToString(x.kind());
600603
},
604+
[&](const UserReductionDetails &x) {
605+
for (auto &type : x.GetTypeList()) {
606+
DumpType(os, type);
607+
}
608+
},
601609
[&](const auto &x) { os << x; },
602610
},
603611
details);

0 commit comments

Comments
 (0)