-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[OpenMP][flang] Lowering of OpenMP custom reductions to MLIR #168417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
f8a24cb
1090379
c5aec84
d914b85
271ad67
be3bb13
50c7cf9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,12 +18,15 @@ | |
| #include "Decomposer.h" | ||
| #include "Utils.h" | ||
| #include "flang/Common/idioms.h" | ||
| #include "flang/Evaluate/type.h" | ||
| #include "flang/Lower/Bridge.h" | ||
| #include "flang/Lower/ConvertExpr.h" | ||
| #include "flang/Lower/ConvertExprToHLFIR.h" | ||
| #include "flang/Lower/ConvertVariable.h" | ||
| #include "flang/Lower/DirectivesCommon.h" | ||
| #include "flang/Lower/OpenMP/Clauses.h" | ||
| #include "flang/Lower/StatementContext.h" | ||
| #include "flang/Lower/Support/ReductionProcessor.h" | ||
| #include "flang/Lower/SymbolMap.h" | ||
| #include "flang/Optimizer/Builder/BoxValue.h" | ||
| #include "flang/Optimizer/Builder/FIRBuilder.h" | ||
|
|
@@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, | |
| // TODO: Add private syms and vars. | ||
| args.reduction.syms = reductionSyms; | ||
| args.reduction.vars = clauseOps.reductionVars; | ||
|
|
||
| return genOpWithBody<mlir::omp::TeamsOp>( | ||
| OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, | ||
| llvm::omp::Directive::OMPD_teams) | ||
|
|
@@ -3563,12 +3565,137 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, | |
| TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective"); | ||
| } | ||
|
|
||
| static bool | ||
| processReductionCombiner(lower::AbstractConverter &converter, | ||
| lower::SymMap &symTable, | ||
| semantics::SemanticsContext &semaCtx, | ||
| const parser::OmpReductionSpecifier &specifier, | ||
| ReductionProcessor::GenCombinerCBTy &genCombinerCB) { | ||
| const auto &combinerExpression = | ||
| std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t) | ||
| .value(); | ||
| const parser::OmpStylizedInstance &combinerInstance = | ||
| combinerExpression.v.front(); | ||
| const parser::OmpStylizedInstance::Instance &instance = | ||
| std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t); | ||
| if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) { | ||
|
||
| auto &expr = std::get<parser::Expr>(as->t); | ||
| genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc, | ||
| mlir::Type type, mlir::Value lhs, mlir::Value rhs, | ||
| bool isByRef) { | ||
| const auto &evalExpr = makeExpr(expr, semaCtx); | ||
| lower::SymMapScope scope(symTable); | ||
| const std::list<parser::OmpStylizedDeclaration> &declList = | ||
| std::get<std::list<parser::OmpStylizedDeclaration>>( | ||
| combinerInstance.t); | ||
| for (const parser::OmpStylizedDeclaration &decl : declList) { | ||
| auto &name = std::get<parser::ObjectName>(decl.var.t); | ||
| mlir::Value addr = lhs; | ||
| mlir::Type type = lhs.getType(); | ||
| bool isRhs = name.ToString() == std::string("omp_in"); | ||
| if (isRhs) { | ||
| addr = rhs; | ||
| type = rhs.getType(); | ||
| } | ||
|
|
||
| assert(name.symbol && "Reduction object name does not have a symbol"); | ||
| if (!fir::conformsWithPassByRef(type)) { | ||
| addr = builder.createTemporary(loc, type); | ||
| fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr); | ||
| } | ||
| fir::FortranVariableFlagsEnum extraFlags = {}; | ||
| fir::FortranVariableFlagsAttr attributes = | ||
| Fortran::lower::translateSymbolAttributes(builder.getContext(), | ||
| *name.symbol, extraFlags); | ||
| auto declareOp = hlfir::DeclareOp::create( | ||
| builder, loc, addr, name.ToString(), nullptr, {}, nullptr, nullptr, | ||
| 0, attributes); | ||
| symTable.addVariableDefinition(*name.symbol, declareOp); | ||
| } | ||
|
|
||
| lower::StatementContext stmtCtx; | ||
| mlir::Value result = fir::getBase( | ||
| convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx)); | ||
| if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType())) | ||
| if (lhs.getType() == refType.getElementType()) | ||
| result = fir::LoadOp::create(builder, loc, result); | ||
| stmtCtx.finalizeAndPop(); | ||
| if (isByRef) { | ||
| fir::StoreOp::create(builder, loc, result, lhs); | ||
| mlir::omp::YieldOp::create(builder, loc, lhs); | ||
| } else { | ||
| mlir::omp::YieldOp::create(builder, loc, result); | ||
| } | ||
|
|
||
| return result; | ||
| }; | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| // Getting the type from a symbol compared to a DeclSpec is simpler since we do | ||
| // not need to consider derived vs intrinsic types. Semantics is guaranteed to | ||
| // generate these symbols. | ||
| static mlir::Type | ||
| getReductionType(lower::AbstractConverter &converter, | ||
| const parser::OmpReductionSpecifier &specifier) { | ||
| const auto &combinerExpression = | ||
| std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t) | ||
| .value(); | ||
| const parser::OmpStylizedInstance &combinerInstance = | ||
| combinerExpression.v.front(); | ||
| const std::list<parser::OmpStylizedDeclaration> &declList = | ||
| std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t); | ||
| const parser::OmpStylizedDeclaration &decl = declList.front(); | ||
| const auto &name = std::get<parser::ObjectName>(decl.var.t); | ||
| const auto &symbol = semantics::SymbolRef(*name.symbol); | ||
| mlir::Type reductionType = converter.genType(symbol); | ||
| return reductionType; | ||
| } | ||
|
|
||
| static void genOMP( | ||
| lower::AbstractConverter &converter, lower::SymMap &symTable, | ||
| semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, | ||
| const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) { | ||
| if (!semaCtx.langOptions().OpenMPSimd) | ||
| TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct"); | ||
| if (!semaCtx.langOptions().OpenMPSimd) { | ||
|
||
| const parser::OmpArgumentList &args{ | ||
| declareReductionConstruct.v.Arguments()}; | ||
| const parser::OmpArgument &arg{args.v.front()}; | ||
| const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u); | ||
|
|
||
| if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1) | ||
| TODO(converter.getCurrentLocation(), | ||
| "multiple types in declare target is not yet supported"); | ||
jsjodin marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| mlir::Type reductionType = getReductionType(converter, specifier); | ||
| ReductionProcessor::GenCombinerCBTy genCombinerCB; | ||
| processReductionCombiner(converter, symTable, semaCtx, specifier, | ||
| genCombinerCB); | ||
| const parser::OmpClauseList &initializer = | ||
| declareReductionConstruct.v.Clauses(); | ||
| if (initializer.v.size() > 0) { | ||
| List<Clause> clauses = makeClauses(initializer, semaCtx); | ||
| ReductionProcessor::GenInitValueCBTy genInitValueCB; | ||
| ClauseProcessor cp(converter, semaCtx, clauses); | ||
| const parser::OmpClause::Initializer &iclause{ | ||
| std::get<parser::OmpClause::Initializer>(initializer.v.front().u)}; | ||
| cp.processInitializer(symTable, iclause, genInitValueCB); | ||
| const auto &identifier = | ||
| std::get<parser::OmpReductionIdentifier>(specifier.t); | ||
| const auto &designator = | ||
| std::get<parser::ProcedureDesignator>(identifier.u); | ||
| const auto &reductionName = std::get<parser::Name>(designator.u); | ||
| bool isByRef = ReductionProcessor::doReductionByRef(reductionType); | ||
| ReductionProcessor::createDeclareReductionHelper< | ||
| mlir::omp::DeclareReductionOp>( | ||
| converter, reductionName.ToString(), reductionType, | ||
| converter.getCurrentLocation(), isByRef, genCombinerCB, | ||
| genInitValueCB); | ||
| } else { | ||
| TODO(converter.getCurrentLocation(), | ||
| "declare target without an initializer clause is not yet supported"); | ||
jsjodin marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
| } | ||
|
|
||
| static void | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we return
GenCombinerCBTyinstead?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, done!