|
18 | 18 | #include "Decomposer.h" |
19 | 19 | #include "Utils.h" |
20 | 20 | #include "flang/Common/idioms.h" |
| 21 | +#include "flang/Evaluate/type.h" |
21 | 22 | #include "flang/Lower/Bridge.h" |
22 | 23 | #include "flang/Lower/ConvertExpr.h" |
| 24 | +#include "flang/Lower/ConvertExprToHLFIR.h" |
23 | 25 | #include "flang/Lower/ConvertVariable.h" |
24 | 26 | #include "flang/Lower/DirectivesCommon.h" |
25 | 27 | #include "flang/Lower/OpenMP/Clauses.h" |
26 | 28 | #include "flang/Lower/StatementContext.h" |
| 29 | +#include "flang/Lower/Support/ReductionProcessor.h" |
27 | 30 | #include "flang/Lower/SymbolMap.h" |
28 | 31 | #include "flang/Optimizer/Builder/BoxValue.h" |
29 | 32 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
@@ -2847,7 +2850,6 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, |
2847 | 2850 | // TODO: Add private syms and vars. |
2848 | 2851 | args.reduction.syms = reductionSyms; |
2849 | 2852 | args.reduction.vars = clauseOps.reductionVars; |
2850 | | - |
2851 | 2853 | return genOpWithBody<mlir::omp::TeamsOp>( |
2852 | 2854 | OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, |
2853 | 2855 | llvm::omp::Directive::OMPD_teams) |
@@ -3570,12 +3572,156 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, |
3570 | 3572 | TODO(converter.getCurrentLocation(), "OmpDeclareVariantDirective"); |
3571 | 3573 | } |
3572 | 3574 |
|
| 3575 | +static ReductionProcessor::GenCombinerCBTy |
| 3576 | +processReductionCombiner(lower::AbstractConverter &converter, |
| 3577 | + lower::SymMap &symTable, |
| 3578 | + semantics::SemanticsContext &semaCtx, |
| 3579 | + const parser::OmpReductionSpecifier &specifier) { |
| 3580 | + ReductionProcessor::GenCombinerCBTy genCombinerCB; |
| 3581 | + const auto &combinerExpression = |
| 3582 | + std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t) |
| 3583 | + .value(); |
| 3584 | + const parser::OmpStylizedInstance &combinerInstance = |
| 3585 | + combinerExpression.v.front(); |
| 3586 | + const parser::OmpStylizedInstance::Instance &instance = |
| 3587 | + std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t); |
| 3588 | + |
| 3589 | + const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u); |
| 3590 | + if (!as) { |
| 3591 | + TODO(converter.getCurrentLocation(), |
| 3592 | + "A combiner that is a subroutine call is not yet supported"); |
| 3593 | + } |
| 3594 | + auto &expr = std::get<parser::Expr>(as->t); |
| 3595 | + genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc, |
| 3596 | + mlir::Type type, mlir::Value lhs, mlir::Value rhs, |
| 3597 | + bool isByRef) { |
| 3598 | + const auto &evalExpr = makeExpr(expr, semaCtx); |
| 3599 | + lower::SymMapScope scope(symTable); |
| 3600 | + const std::list<parser::OmpStylizedDeclaration> &declList = |
| 3601 | + std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t); |
| 3602 | + for (const parser::OmpStylizedDeclaration &decl : declList) { |
| 3603 | + auto &name = std::get<parser::ObjectName>(decl.var.t); |
| 3604 | + mlir::Value addr = lhs; |
| 3605 | + mlir::Type type = lhs.getType(); |
| 3606 | + bool isRhs = name.ToString() == std::string("omp_in"); |
| 3607 | + if (isRhs) { |
| 3608 | + addr = rhs; |
| 3609 | + type = rhs.getType(); |
| 3610 | + } |
| 3611 | + |
| 3612 | + assert(name.symbol && "Reduction object name does not have a symbol"); |
| 3613 | + if (!fir::conformsWithPassByRef(type)) { |
| 3614 | + addr = builder.createTemporary(loc, type); |
| 3615 | + fir::StoreOp::create(builder, loc, isRhs ? rhs : lhs, addr); |
| 3616 | + } |
| 3617 | + fir::FortranVariableFlagsEnum extraFlags = {}; |
| 3618 | + fir::FortranVariableFlagsAttr attributes = |
| 3619 | + Fortran::lower::translateSymbolAttributes(builder.getContext(), |
| 3620 | + *name.symbol, extraFlags); |
| 3621 | + auto declareOp = |
| 3622 | + hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr, |
| 3623 | + {}, nullptr, nullptr, 0, attributes); |
| 3624 | + symTable.addVariableDefinition(*name.symbol, declareOp); |
| 3625 | + } |
| 3626 | + |
| 3627 | + lower::StatementContext stmtCtx; |
| 3628 | + mlir::Value result = fir::getBase( |
| 3629 | + convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx)); |
| 3630 | + if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType())) |
| 3631 | + if (lhs.getType() == refType.getElementType()) |
| 3632 | + result = fir::LoadOp::create(builder, loc, result); |
| 3633 | + stmtCtx.finalizeAndPop(); |
| 3634 | + if (isByRef) { |
| 3635 | + fir::StoreOp::create(builder, loc, result, lhs); |
| 3636 | + mlir::omp::YieldOp::create(builder, loc, lhs); |
| 3637 | + } else { |
| 3638 | + mlir::omp::YieldOp::create(builder, loc, result); |
| 3639 | + } |
| 3640 | + }; |
| 3641 | + return genCombinerCB; |
| 3642 | +} |
| 3643 | + |
| 3644 | +// Checks that the reduction type is either a trivial type or a derived type of |
| 3645 | +// trivial types. |
| 3646 | +static bool isSimpleReductionType(mlir::Type reductionType) { |
| 3647 | + if (fir::isa_trivial(reductionType)) |
| 3648 | + return true; |
| 3649 | + if (auto recordTy = mlir::dyn_cast<fir::RecordType>(reductionType)) { |
| 3650 | + for (auto [_, fieldType] : recordTy.getTypeList()) { |
| 3651 | + if (!fir::isa_trivial(fieldType)) |
| 3652 | + return false; |
| 3653 | + } |
| 3654 | + } |
| 3655 | + return true; |
| 3656 | +} |
| 3657 | + |
| 3658 | +// Getting the type from a symbol compared to a DeclSpec is simpler since we do |
| 3659 | +// not need to consider derived vs intrinsic types. Semantics is guaranteed to |
| 3660 | +// generate these symbols. |
| 3661 | +static mlir::Type |
| 3662 | +getReductionType(lower::AbstractConverter &converter, |
| 3663 | + const parser::OmpReductionSpecifier &specifier) { |
| 3664 | + const auto &combinerExpression = |
| 3665 | + std::get<std::optional<parser::OmpCombinerExpression>>(specifier.t) |
| 3666 | + .value(); |
| 3667 | + const parser::OmpStylizedInstance &combinerInstance = |
| 3668 | + combinerExpression.v.front(); |
| 3669 | + const std::list<parser::OmpStylizedDeclaration> &declList = |
| 3670 | + std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t); |
| 3671 | + const parser::OmpStylizedDeclaration &decl = declList.front(); |
| 3672 | + const auto &name = std::get<parser::ObjectName>(decl.var.t); |
| 3673 | + const auto &symbol = semantics::SymbolRef(*name.symbol); |
| 3674 | + mlir::Type reductionType = converter.genType(symbol); |
| 3675 | + |
| 3676 | + if (!isSimpleReductionType(reductionType)) |
| 3677 | + TODO(converter.getCurrentLocation(), |
| 3678 | + "declare reduction currently only supports trival types or derived " |
| 3679 | + "types containing trivial types"); |
| 3680 | + return reductionType; |
| 3681 | +} |
| 3682 | + |
3573 | 3683 | static void genOMP( |
3574 | 3684 | lower::AbstractConverter &converter, lower::SymMap &symTable, |
3575 | 3685 | semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, |
3576 | 3686 | const parser::OpenMPDeclareReductionConstruct &declareReductionConstruct) { |
3577 | | - if (!semaCtx.langOptions().OpenMPSimd) |
3578 | | - TODO(converter.getCurrentLocation(), "OpenMPDeclareReductionConstruct"); |
| 3687 | + if (semaCtx.langOptions().OpenMPSimd) |
| 3688 | + return; |
| 3689 | + |
| 3690 | + const parser::OmpArgumentList &args{declareReductionConstruct.v.Arguments()}; |
| 3691 | + const parser::OmpArgument &arg{args.v.front()}; |
| 3692 | + const auto &specifier = std::get<parser::OmpReductionSpecifier>(arg.u); |
| 3693 | + |
| 3694 | + if (std::get<parser::OmpTypeNameList>(specifier.t).v.size() > 1) |
| 3695 | + TODO(converter.getCurrentLocation(), |
| 3696 | + "multiple types in declare reduction is not yet supported"); |
| 3697 | + |
| 3698 | + mlir::Type reductionType = getReductionType(converter, specifier); |
| 3699 | + ReductionProcessor::GenCombinerCBTy genCombinerCB = |
| 3700 | + processReductionCombiner(converter, symTable, semaCtx, specifier); |
| 3701 | + const parser::OmpClauseList &initializer = |
| 3702 | + declareReductionConstruct.v.Clauses(); |
| 3703 | + if (initializer.v.size() > 0) { |
| 3704 | + List<Clause> clauses = makeClauses(initializer, semaCtx); |
| 3705 | + ReductionProcessor::GenInitValueCBTy genInitValueCB; |
| 3706 | + ClauseProcessor cp(converter, semaCtx, clauses); |
| 3707 | + const parser::OmpClause::Initializer &iclause{ |
| 3708 | + std::get<parser::OmpClause::Initializer>(initializer.v.front().u)}; |
| 3709 | + cp.processInitializer(symTable, iclause, genInitValueCB); |
| 3710 | + const auto &identifier = |
| 3711 | + std::get<parser::OmpReductionIdentifier>(specifier.t); |
| 3712 | + const auto &designator = |
| 3713 | + std::get<parser::ProcedureDesignator>(identifier.u); |
| 3714 | + const auto &reductionName = std::get<parser::Name>(designator.u); |
| 3715 | + bool isByRef = ReductionProcessor::doReductionByRef(reductionType); |
| 3716 | + ReductionProcessor::createDeclareReductionHelper< |
| 3717 | + mlir::omp::DeclareReductionOp>( |
| 3718 | + converter, reductionName.ToString(), reductionType, |
| 3719 | + converter.getCurrentLocation(), isByRef, genCombinerCB, genInitValueCB); |
| 3720 | + } else { |
| 3721 | + TODO(converter.getCurrentLocation(), |
| 3722 | + "declare reduction without an initializer clause is not yet " |
| 3723 | + "supported"); |
| 3724 | + } |
3579 | 3725 | } |
3580 | 3726 |
|
3581 | 3727 | static void |
|
0 commit comments