Skip to content

Commit bfcf851

Browse files
jsjodinaadeshps-mcw
authored andcommitted
[OpenMP][flang] Lowering of OpenMP custom reductions to MLIR (llvm#168417)
This patch add support for lowering of custom reductions to MLIR. It also enhances the capability of the pass to automatically mark functions as "declare target" by traversing custom reduction initializers and combiners.
1 parent 1e1e94c commit bfcf851

16 files changed

+772
-106
lines changed

flang/include/flang/Lower/Support/ReductionProcessor.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ namespace omp {
4040

4141
class ReductionProcessor {
4242
public:
43+
using GenInitValueCBTy =
44+
std::function<mlir::Value(fir::FirOpBuilder &builder, mlir::Location loc,
45+
mlir::Type type, mlir::Value ompOrig)>;
46+
using GenCombinerCBTy = std::function<void(
47+
fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type type,
48+
mlir::Value op1, mlir::Value op2, bool isByRef)>;
49+
4350
// TODO: Move this enumeration to the OpenMP dialect
4451
enum ReductionIdentifier {
4552
ID,
@@ -58,6 +65,9 @@ class ReductionProcessor {
5865
IEOR
5966
};
6067

68+
static bool doReductionByRef(mlir::Type reductionType);
69+
static bool doReductionByRef(mlir::Value reductionVar);
70+
6171
static ReductionIdentifier
6272
getReductionType(const omp::clause::ProcedureDesignator &pd);
6373

@@ -109,6 +119,14 @@ class ReductionProcessor {
109119
ReductionIdentifier redId,
110120
mlir::Type type, mlir::Value op1,
111121
mlir::Value op2);
122+
/// Creates an OpenMP reduction declaration and inserts it into the provided
123+
/// symbol table. The init and combiner regions are generated by the callback
124+
/// functions genCombinerCB and genInitValueCB.
125+
template <typename DeclareRedType>
126+
static DeclareRedType createDeclareReductionHelper(
127+
AbstractConverter &converter, llvm::StringRef reductionOpName,
128+
mlir::Type type, mlir::Location loc, bool isByRef,
129+
GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB);
112130

113131
/// Creates an OpenMP reduction declaration and inserts it into the provided
114132
/// symbol table. The declaration has a constant initializer with the neutral

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "ClauseProcessor.h"
1414
#include "Utils.h"
1515

16+
#include "flang/Lower/ConvertCall.h"
1617
#include "flang/Lower/ConvertExprToHLFIR.h"
1718
#include "flang/Lower/OpenMP/Clauses.h"
1819
#include "flang/Lower/PFTBuilder.h"
@@ -402,6 +403,65 @@ bool ClauseProcessor::processInclusive(
402403
return false;
403404
}
404405

406+
bool ClauseProcessor::processInitializer(
407+
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
408+
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const {
409+
if (auto *clause = findUniqueClause<omp::clause::Initializer>()) {
410+
genInitValueCB = [&, clause](fir::FirOpBuilder &builder, mlir::Location loc,
411+
mlir::Type type, mlir::Value ompOrig) {
412+
lower::SymMapScope scope(symMap);
413+
const parser::OmpInitializerExpression &iexpr = inp.v.v;
414+
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
415+
const std::list<parser::OmpStylizedDeclaration> &declList =
416+
std::get<std::list<parser::OmpStylizedDeclaration>>(styleInstance.t);
417+
mlir::Value ompPrivVar;
418+
for (const parser::OmpStylizedDeclaration &decl : declList) {
419+
auto &name = std::get<parser::ObjectName>(decl.var.t);
420+
assert(name.symbol && "Name does not have a symbol");
421+
mlir::Value addr = builder.createTemporary(loc, ompOrig.getType());
422+
fir::StoreOp::create(builder, loc, ompOrig, addr);
423+
fir::FortranVariableFlagsEnum extraFlags = {};
424+
fir::FortranVariableFlagsAttr attributes =
425+
Fortran::lower::translateSymbolAttributes(builder.getContext(),
426+
*name.symbol, extraFlags);
427+
auto declareOp = hlfir::DeclareOp::create(
428+
builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr,
429+
0, attributes);
430+
if (name.ToString() == "omp_priv")
431+
ompPrivVar = declareOp.getResult(0);
432+
symMap.addVariableDefinition(*name.symbol, declareOp);
433+
}
434+
// Lower the expression/function call
435+
lower::StatementContext stmtCtx;
436+
mlir::Value result = common::visit(
437+
common::visitors{
438+
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
439+
convertCallToHLFIR(loc, converter, procRef, std::nullopt,
440+
symMap, stmtCtx);
441+
auto privVal = fir::LoadOp::create(builder, loc, ompPrivVar);
442+
return privVal;
443+
},
444+
[&](const auto &expr) -> mlir::Value {
445+
mlir::Value exprResult = fir::getBase(convertExprToValue(
446+
loc, converter, clause->v, symMap, stmtCtx));
447+
// Conversion can either give a value or a refrence to a value,
448+
// we need to return the reduction type, so an optional load may
449+
// be generated.
450+
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(
451+
exprResult.getType()))
452+
if (ompPrivVar.getType() == refType)
453+
exprResult = fir::LoadOp::create(builder, loc, exprResult);
454+
return exprResult;
455+
}},
456+
clause->v.u);
457+
stmtCtx.finalizeAndPop();
458+
return result;
459+
};
460+
return true;
461+
}
462+
return false;
463+
}
464+
405465
bool ClauseProcessor::processMergeable(
406466
mlir::omp::MergeableClauseOps &result) const {
407467
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "flang/Lower/Bridge.h"
1919
#include "flang/Lower/DirectivesCommon.h"
2020
#include "flang/Lower/OpenMP/Clauses.h"
21+
#include "flang/Lower/Support/ReductionProcessor.h"
2122
#include "flang/Optimizer/Builder/Todo.h"
2223
#include "flang/Parser/dump-parse-tree.h"
2324
#include "flang/Parser/parse-tree.h"
@@ -88,6 +89,9 @@ class ClauseProcessor {
8889
bool processHint(mlir::omp::HintClauseOps &result) const;
8990
bool processInclusive(mlir::Location currentLocation,
9091
mlir::omp::InclusiveClauseOps &result) const;
92+
bool processInitializer(
93+
lower::SymMap &symMap, const parser::OmpClause::Initializer &inp,
94+
ReductionProcessor::GenInitValueCBTy &genInitValueCB) const;
9195
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
9296
bool processNogroup(mlir::omp::NogroupClauseOps &result) const;
9397
bool processNowait(mlir::omp::NowaitClauseOps &result) const;

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,22 @@ Init make(const parser::OmpClause::Init &inp,
981981

982982
Initializer make(const parser::OmpClause::Initializer &inp,
983983
semantics::SemanticsContext &semaCtx) {
984-
llvm_unreachable("Empty: initializer");
984+
const parser::OmpInitializerExpression &iexpr = inp.v.v;
985+
const parser::OmpStylizedInstance &styleInstance = iexpr.v.front();
986+
const parser::OmpStylizedInstance::Instance &instance =
987+
std::get<parser::OmpStylizedInstance::Instance>(styleInstance.t);
988+
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
989+
auto &expr = std::get<parser::Expr>(as->t);
990+
return Initializer{makeExpr(expr, semaCtx)};
991+
} else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
992+
if (call->typedCall) {
993+
const auto &procRef = *call->typedCall;
994+
semantics::SomeExpr evalProcRef{procRef};
995+
return Initializer{evalProcRef};
996+
}
997+
}
998+
999+
llvm_unreachable("Unexpected initializer");
9851000
}
9861001

9871002
InReduction make(const parser::OmpClause::InReduction &inp,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 149 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818
#include "Decomposer.h"
1919
#include "Utils.h"
2020
#include "flang/Common/idioms.h"
21+
#include "flang/Evaluate/type.h"
2122
#include "flang/Lower/Bridge.h"
2223
#include "flang/Lower/ConvertExpr.h"
24+
#include "flang/Lower/ConvertExprToHLFIR.h"
2325
#include "flang/Lower/ConvertVariable.h"
2426
#include "flang/Lower/DirectivesCommon.h"
2527
#include "flang/Lower/OpenMP/Clauses.h"
2628
#include "flang/Lower/StatementContext.h"
29+
#include "flang/Lower/Support/ReductionProcessor.h"
2730
#include "flang/Lower/SymbolMap.h"
2831
#include "flang/Optimizer/Builder/BoxValue.h"
2932
#include "flang/Optimizer/Builder/FIRBuilder.h"
@@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
28472850
// TODO: Add private syms and vars.
28482851
args.reduction.syms = reductionSyms;
28492852
args.reduction.vars = clauseOps.reductionVars;
2850-
28512853
return genOpWithBody<mlir::omp::TeamsOp>(
28522854
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
28532855
llvm::omp::Directive::OMPD_teams)
@@ -3570,12 +3572,156 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
35703572
TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective");
35713573
}
35723574

3575+
static ReductionProcessor::GenCombinerCBTy
3576+
processReductionCombiner(lower::AbstractConverter &converter,
3577+
lower::SymMap &symTable,
3578+
semantics::SemanticsContext &semaCtx,
3579+
const parser::OmpReductionSpecifier &specifier) {
3580+
ReductionProcessor::GenCombinerCBTy genCombinerCB;
3581+
const auto &combinerExpression =
3582+
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
3583+
.value();
3584+
const parser::OmpStylizedInstance &combinerInstance =
3585+
combinerExpression.v.front();
3586+
const parser::OmpStylizedInstance::Instance &instance =
3587+
std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
3588+
3589+
const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
3590+
if (!as) {
3591+
TODO(converter.getCurrentLocation(),
3592+
"A combiner that is a subroutine call is not yet supported");
3593+
}
3594+
auto &expr = std::get<parser::Expr>(as->t);
3595+
genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
3596+
mlir::Type type, mlir::Value lhs, mlir::Value rhs,
3597+
bool isByRef) {
3598+
const auto &evalExpr = makeExpr(expr, semaCtx);
3599+
lower::SymMapScope scope(symTable);
3600+
const std::list<parser::OmpStylizedDeclaration> &declList =
3601+
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
3602+
for (const parser::OmpStylizedDeclaration &decl : declList) {
3603+
auto &name = std::get<parser::ObjectName>(decl.var.t);
3604+
mlir::Value addr = lhs;
3605+
mlir::Type type = lhs.getType();
3606+
bool isRhs = name.ToString() == std::string("omp_in");
3607+
if (isRhs) {
3608+
addr = rhs;
3609+
type = rhs.getType();
3610+
}
3611+
3612+
assert(name.symbol && "Reduction object name does not have a symbol");
3613+
if (!fir::conformsWithPassByRef(type)) {
3614+
addr = builder.createTemporary(loc, type);
3615+
fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr);
3616+
}
3617+
fir::FortranVariableFlagsEnum extraFlags = {};
3618+
fir::FortranVariableFlagsAttr attributes =
3619+
Fortran::lower::translateSymbolAttributes(builder.getContext(),
3620+
*name.symbol, extraFlags);
3621+
auto declareOp =
3622+
hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
3623+
{}, nullptr, nullptr, 0, attributes);
3624+
symTable.addVariableDefinition(*name.symbol, declareOp);
3625+
}
3626+
3627+
lower::StatementContext stmtCtx;
3628+
mlir::Value result = fir::getBase(
3629+
convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
3630+
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
3631+
if (lhs.getType() == refType.getElementType())
3632+
result = fir::LoadOp::create(builder, loc, result);
3633+
stmtCtx.finalizeAndPop();
3634+
if (isByRef) {
3635+
fir::StoreOp::create(builder, loc, result, lhs);
3636+
mlir::omp::YieldOp::create(builder, loc, lhs);
3637+
} else {
3638+
mlir::omp::YieldOp::create(builder, loc, result);
3639+
}
3640+
};
3641+
return genCombinerCB;
3642+
}
3643+
3644+
// Checks that the reduction type is either a trivial type or a derived type of
3645+
// trivial types.
3646+
static bool isSimpleReductionType(mlir::Type reductionType) {
3647+
if (fir::isa_trivial(reductionType))
3648+
return true;
3649+
if (auto recordTy = mlir::dyn_cast<fir::RecordType>(reductionType)) {
3650+
for (auto [_, fieldType] : recordTy.getTypeList()) {
3651+
if (!fir::isa_trivial(fieldType))
3652+
return false;
3653+
}
3654+
}
3655+
return true;
3656+
}
3657+
3658+
// Getting the type from a symbol compared to a DeclSpec is simpler since we do
3659+
// not need to consider derived vs intrinsic types. Semantics is guaranteed to
3660+
// generate these symbols.
3661+
static mlir::Type
3662+
getReductionType(lower::AbstractConverter &converter,
3663+
const parser::OmpReductionSpecifier &specifier) {
3664+
const auto &combinerExpression =
3665+
std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t)
3666+
.value();
3667+
const parser::OmpStylizedInstance &combinerInstance =
3668+
combinerExpression.v.front();
3669+
const std::list<parser::OmpStylizedDeclaration> &declList =
3670+
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
3671+
const parser::OmpStylizedDeclaration &decl = declList.front();
3672+
const auto &name = std::get<parser::ObjectName>(decl.var.t);
3673+
const auto &symbol = semantics::SymbolRef(*name.symbol);
3674+
mlir::Type reductionType = converter.genType(symbol);
3675+
3676+
if (!isSimpleReductionType(reductionType))
3677+
TODO(converter.getCurrentLocation(),
3678+
"declare reduction currently only supports trival types or derived "
3679+
"types containing trivial types");
3680+
return reductionType;
3681+
}
3682+
35733683
static void genOMP(
35743684
lower::AbstractConverter &converter, lower::SymMap &symTable,
35753685
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
35763686
const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) {
3577-
if (!semaCtx.langOptions().OpenMPSimd)
3578-
TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct");
3687+
if (semaCtx.langOptions().OpenMPSimd)
3688+
return;
3689+
3690+
const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()};
3691+
const parser::OmpArgument &arg{args.v.front()};
3692+
const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u);
3693+
3694+
if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1)
3695+
TODO(converter.getCurrentLocation(),
3696+
"multiple types in declare reduction is not yet supported");
3697+
3698+
mlir::Type reductionType = getReductionType(converter, specifier);
3699+
ReductionProcessor::GenCombinerCBTy genCombinerCB =
3700+
processReductionCombiner(converter, symTable, semaCtx, specifier);
3701+
const parser::OmpClauseList &initializer =
3702+
declareReductionConstruct.v.Clauses();
3703+
if (initializer.v.size() > 0) {
3704+
List<Clause> clauses = makeClauses(initializer, semaCtx);
3705+
ReductionProcessor::GenInitValueCBTy genInitValueCB;
3706+
ClauseProcessor cp(converter, semaCtx, clauses);
3707+
const parser::OmpClause::Initializer &iclause{
3708+
std::get<parser::OmpClause::Initializer>(initializer.v.front().u)};
3709+
cp.processInitializer(symTable, iclause, genInitValueCB);
3710+
const auto &identifier =
3711+
std::get<parser::OmpReductionIdentifier>(specifier.t);
3712+
const auto &designator =
3713+
std::get<parser::ProcedureDesignator>(identifier.u);
3714+
const auto &reductionName = std::get<parser::Name>(designator.u);
3715+
bool isByRef = ReductionProcessor::doReductionByRef(reductionType);
3716+
ReductionProcessor::createDeclareReductionHelper<
3717+
mlir::omp::DeclareReductionOp>(
3718+
converter, reductionName.ToString(), reductionType,
3719+
converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB);
3720+
} else {
3721+
TODO(converter.getCurrentLocation(),
3722+
"declare reduction without an initializer clause is not yet "
3723+
"supported");
3724+
}
35793725
}
35803726

35813727
static void

0 commit comments

Comments
 (0)