Skip to content

Commit 1451f3d

Browse files
authored
[flang][OpenMP] Use StylizedInstance in converted clauses (#171907)
Invent `StylizedInstance` class to store special variables together with the instantiated expression in omp::clause::Initializer. This will eliminate the need for visiting the original AST nodes in lowering to MLIR.
1 parent 099985f commit 1451f3d

File tree

8 files changed

+81
-42
lines changed

8 files changed

+81
-42
lines changed

flang/include/flang/Lower/OpenMP/Clauses.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ Object makeObject(const parser::Designator &dsg,
113113
semantics::SemanticsContext &semaCtx);
114114
Object makeObject(const parser::StructureComponent &comp,
115115
semantics::SemanticsContext &semaCtx);
116+
Object makeObject(const parser::EntityDecl &decl,
117+
semantics::SemanticsContext &semaCtx);
116118

117119
inline auto makeObjectFn(semantics::SemanticsContext &semaCtx) {
118120
return [&](auto &&s) { return makeObject(s, semaCtx); };
@@ -172,6 +174,7 @@ std::optional<Object> getBaseObject(const Object &object,
172174
semantics::SemanticsContext &semaCtx);
173175

174176
namespace clause {
177+
using StylizedInstance = tomp::type::StylizedInstanceT<IdTy, ExprTy>;
175178
using Range = tomp::type::RangeT<ExprTy>;
176179
using Mapper = tomp::type::MapperT<IdTy, ExprTy>;
177180
using Iterator = tomp::type::IteratorT<TypeTy, IdTy, ExprTy>;

flang/include/flang/Semantics/openmp-utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ const SomeExpr *HasStorageOverlap(
9797
const SomeExpr &base, llvm::ArrayRef<SomeExpr> exprs);
9898
bool IsAssignment(const parser::ActionStmt *x);
9999
bool IsPointerAssignment(const evaluate::Assignment &x);
100+
101+
MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp);
100102
} // namespace omp
101103
} // namespace Fortran::semantics
102104

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -383,36 +383,36 @@ bool ClauseProcessor::processInclusive(
383383
}
384384

385385
bool ClauseProcessor::processInitializer(
386-
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
386+
lower::SymMap &symMap,
387387
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
388388
if (auto *clause = findUniqueClause<omp::clause::Initializer>()) {
389389
genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
390390
mlir::Type type, mlir::Value ompOrig) {
391391
lower::SymMapScope scope(symMap);
392-
const parser::OmpInitializerExpression &iexpr = inp.v.v;
393-
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
394-
const std::list<parser::OmpStylizedDeclaration> &declList =
395-
std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
396392
mlir::Value ompPrivVar;
397-
for (const parser::OmpStylizedDeclaration &decl : declList) {
398-
auto &name = std::get<parser::ObjectName>(decl.var.t);
399-
assert(name.symbol && "Name does not have a symbol");
393+
const clause::StylizedInstance &inst = clause->v.front();
394+
395+
for (const Object &object :
396+
std::get<clause::StylizedInstance::Variables>(inst.t)) {
400397
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
401398
fir::StoreOp::create(builder, loc, ompOrig, addr);
402399
fir::FortranVariableFlagsEnum extraFlags = {};
403400
fir::FortranVariableFlagsAttr attributes =
404-
Fortran::lower::translateSymbolAttributes(builder.getContext(),
405-
*name.symbol, extraFlags);
406-
auto declareOp = hlfir::DeclareOp::create(
407-
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
408-
0, attributes);
409-
if (name.ToString() == "omp_priv")
401+
Fortran::lower::translateSymbolAttributes(
402+
builder.getContext(), *object.sym(), extraFlags);
403+
std::string name = object.sym()->name().ToString();
404+
auto declareOp =
405+
hlfir::DeclareOp::create(builder, loc, addr, name, nullptr, {},
406+
nullptr, nullptr, 0, attributes);
407+
if (name == "omp_priv")
410408
ompPrivVar = declareOp.getResult(0);
411-
symMap.addVariableDefinition(*name.symbol, declareOp);
409+
symMap.addVariableDefinition(*object.sym(), declareOp);
412410
}
411+
413412
// Lower the expression/function call
414413
lower::StatementContext stmtCtx;
415-
const semantics::SomeExpr &initExpr = clause->v.front();
414+
const semantics::SomeExpr &initExpr =
415+
std::get<clause::StylizedInstance::Instance>(inst.t);
416416
mlir::Value result = common::visit(
417417
common::visitors{
418418
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include "flang/Lower/OpenMP/Clauses.h"
2121
#include "flang/Lower/Support/ReductionProcessor.h"
2222
#include "flang/Optimizer/Builder/Todo.h"
23-
#include "flang/Parser/parse-tree.h"
23+
#include "flang/Parser/char-block.h"
2424
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2525

2626
namespace fir {
@@ -89,7 +89,7 @@ class ClauseProcessor {
8989
bool processInclusive(mlir::Location currentLocation,
9090
mlir::omp::InclusiveClauseOps &result) const;
9191
bool processInitializer(
92-
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
92+
lower::SymMap &symMap,
9393
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
9494
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
9595
bool processNogroup(mlir::omp::NogroupClauseOps &result) const;

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "flang/Parser/parse-tree.h"
1414
#include "flang/Semantics/expression.h"
1515
#include "flang/Semantics/openmp-modifiers.h"
16+
#include "flang/Semantics/openmp-utils.h"
1617
#include "flang/Semantics/symbol.h"
1718

1819
#include <list>
@@ -128,6 +129,11 @@ Object makeObject(const parser::OmpObject &object,
128129
return makeObject(std::get<parser::Designator>(object.u), semaCtx);
129130
}
130131

132+
Object makeObject(const parser::EntityDecl &decl,
133+
semantics::SemanticsContext &semaCtx) {
134+
return makeObject(std::get<parser::ObjectName>(decl.t), semaCtx);
135+
}
136+
131137
ObjectList makeObjects(const parser::OmpArgumentList &objects,
132138
semantics::SemanticsContext &semaCtx) {
133139
return makeList(objects.v, [&](const parser::OmpArgument &arg) {
@@ -275,12 +281,10 @@ makeIteratorSpecifiers(const parser::OmpIteratorSpecifier &inp,
275281
auto &tds = std::get<parser::TypeDeclarationStmt>(inp.t);
276282
auto &entities = std::get<std::list<parser::EntityDecl>>(tds.t);
277283
for (const parser::EntityDecl &ed : entities) {
278-
auto &name = std::get<parser::ObjectName>(ed.t);
279-
assert(name.symbol && "Expecting symbol for iterator variable");
280-
auto *stype = name.symbol->GetType();
281-
assert(stype && "Expecting symbol type");
282-
IteratorSpecifier spec{{evaluate::DynamicType::From(*stype),
283-
makeObject(name, semaCtx), range}};
284+
auto *symbol = std::get<parser::ObjectName>(ed.t).symbol;
285+
auto *type = DEREF(symbol).GetType();
286+
IteratorSpecifier spec{{evaluate::DynamicType::From(DEREF(type)),
287+
makeObject(ed, semaCtx), range}};
284288
specifiers.emplace_back(std::move(spec));
285289
}
286290

@@ -983,19 +987,24 @@ Initializer make(const parser::OmpClause::Initializer &inp,
983987
semantics::SemanticsContext &semaCtx) {
984988
const parser::OmpInitializerExpression &iexpr = inp.v.v;
985989
Initializer initializer;
986-
for (const parser::OmpStylizedInstance &styleInstance : iexpr.v) {
987-
auto &instance =
988-
std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);
989-
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
990-
auto &expr = std::get<parser::Expr>(as->t);
991-
initializer.v.push_back(makeExpr(expr, semaCtx));
992-
} else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
993-
assert(call->typedCall && "Expecting typedCall");
994-
const auto &procRef = *call->typedCall;
995-
initializer.v.push_back(semantics::SomeExpr(procRef));
996-
} else {
997-
llvm_unreachable("Unexpected initializer");
998-
}
990+
991+
for (const parser::OmpStylizedInstance &sinst : iexpr.v) {
992+
ObjectList variables;
993+
llvm::transform(
994+
std::get<std::list<parser::OmpStylizedDeclaration>>(sinst.t),
995+
std::back_inserter(variables),
996+
[&](const parser::OmpStylizedDeclaration &s) {
997+
return makeObject(s.var, semaCtx);
998+
});
999+
1000+
SomeExpr instance = [&]() {
1001+
if (auto &&expr = semantics::omp::MakeEvaluateExpr(sinst))
1002+
return std::move(*expr);
1003+
llvm_unreachable("Expecting expression instance");
1004+
}();
1005+
1006+
initializer.v.push_back(
1007+
StylizedInstance{{std::move(variables), std::move(instance)}});
9991008
}
10001009

10011010
return initializer;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3764,9 +3764,7 @@ static void genOMP(
37643764
List<Clause> clauses = makeClauses(initializer, semaCtx);
37653765
ReductionProcessor::GenInitValueCBTy genInitValueCB;
37663766
ClauseProcessor cp(converter, semaCtx, clauses);
3767-
const parser::OmpClause::Initializer &iclause{
3768-
std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
3769-
cp.processInitializer(symTable, iclause, genInitValueCB);
3767+
cp.processInitializer(symTable, genInitValueCB);
37703768
const auto &identifier =
37713769
std::get<parser::OmpReductionIdentifier>(specifier.t);
37723770
const auto &designator =

flang/lib/Semantics/openmp-utils.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,4 +496,24 @@ bool IsPointerAssignment(const evaluate::Assignment &x) {
496496
return std::holds_alternative<evaluate::Assignment::BoundsSpec>(x.u) ||
497497
std::holds_alternative<evaluate::Assignment::BoundsRemapping>(x.u);
498498
}
499+
500+
MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp) {
501+
auto &instance = std::get<parser::OmpStylizedInstance::Instance>(inp.t);
502+
503+
return common::visit( //
504+
common::visitors{
505+
[&](const parser::AssignmentStmt &s) -> MaybeExpr {
506+
return GetEvaluateExpr(std::get<parser::Expr>(s.t));
507+
},
508+
[&](const parser::CallStmt &s) -> MaybeExpr {
509+
assert(s.typedCall && "Expecting typedCall");
510+
const auto &procRef = *s.typedCall;
511+
return SomeExpr(procRef);
512+
},
513+
[&](const common::Indirection<parser::Expr> &s) -> MaybeExpr {
514+
return GetEvaluateExpr(s.value());
515+
},
516+
},
517+
instance.u);
518+
}
499519
} // namespace Fortran::semantics::omp

llvm/include/llvm/Frontend/OpenMP/ClauseT.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,14 @@ template <typename I, typename E> using ObjectListT = ListT<ObjectT<I, E>>;
189189

190190
using DirectiveName = llvm::omp::Directive;
191191

192+
template <typename I, typename E> //
193+
struct StylizedInstanceT {
194+
using Variables = ObjectListT<I, E>;
195+
using Instance = E;
196+
using TupleTrait = std::true_type;
197+
std::tuple<Variables, Instance> t;
198+
};
199+
192200
template <typename I, typename E> //
193201
struct DefinedOperatorT {
194202
struct DefinedOpName {
@@ -762,8 +770,7 @@ struct InitT {
762770
// V5.2: [5.5.4] `initializer` clause
763771
template <typename T, typename I, typename E> //
764772
struct InitializerT {
765-
using InitializerExpr = E;
766-
using List = ListT<InitializerExpr>;
773+
using List = ListT<type::StylizedInstanceT<I, E>>;
767774
using WrapperTrait = std::true_type;
768775
List v;
769776
};

0 commit comments

Comments
 (0)