2020#include " flang/Common/idioms.h"
2121#include " flang/Evaluate/type.h"
2222#include " flang/Lower/Bridge.h"
23+ #include " flang/Lower/ConvertCall.h"
2324#include " flang/Lower/ConvertExpr.h"
2425#include " flang/Lower/ConvertExprToHLFIR.h"
2526#include " flang/Lower/ConvertVariable.h"
@@ -3582,19 +3583,32 @@ processReductionCombiner(lower::AbstractConverter &converter,
35823583 const parser::OmpStylizedInstance::Instance &instance =
35833584 std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t );
35843585
3585- const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u );
3586- if (!as) {
3587- TODO (converter.getCurrentLocation (),
3588- " A combiner that is a subroutine call is not yet supported" );
3586+ std::optional<semantics::SomeExpr> evalExprOpt;
3587+ if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u )) {
3588+ auto &expr = std::get<parser::Expr>(as->t );
3589+ evalExprOpt = makeExpr (expr, semaCtx);
3590+ } else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u )) {
3591+ if (call->typedCall ) {
3592+ const auto &procRef = *call->typedCall ;
3593+ evalExprOpt = semantics::SomeExpr{procRef};
3594+ } else {
3595+ TODO (converter.getCurrentLocation (),
3596+ " CallStmt without typedCall is not yet supported" );
3597+ }
3598+ } else {
3599+ TODO (converter.getCurrentLocation (), " Unsupported combiner instance type" );
35893600 }
3590- auto &expr = std::get<parser::Expr>(as->t );
3591- genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
3592- mlir::Type type, mlir::Value lhs, mlir::Value rhs,
3593- bool isByRef) {
3594- const auto &evalExpr = makeExpr (expr, semaCtx);
3601+
3602+ assert (evalExprOpt.has_value () && " evalExpr must be initialized" );
3603+ semantics::SomeExpr evalExpr = *evalExprOpt;
3604+
3605+ genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc,
3606+ mlir::Type type, mlir::Value lhs,
3607+ mlir::Value rhs, bool isByRef) {
35953608 lower::SymMapScope scope (symTable);
35963609 const std::list<parser::OmpStylizedDeclaration> &declList =
35973610 std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t );
3611+ mlir::Value ompOutVar;
35983612 for (const parser::OmpStylizedDeclaration &decl : declList) {
35993613 auto &name = std::get<parser::ObjectName>(decl.var .t );
36003614 mlir::Value addr = lhs;
@@ -3617,15 +3631,32 @@ processReductionCombiner(lower::AbstractConverter &converter,
36173631 auto declareOp =
36183632 hlfir::DeclareOp::create (builder, loc, addr, name.ToString (), nullptr ,
36193633 {}, nullptr , nullptr , 0 , attributes);
3634+ if (name.ToString () == " omp_out" )
3635+ ompOutVar = declareOp.getResult (0 );
36203636 symTable.addVariableDefinition (*name.symbol , declareOp);
36213637 }
36223638
36233639 lower::StatementContext stmtCtx;
3624- mlir::Value result = fir::getBase (
3625- convertExprToValue (loc, converter, evalExpr, symTable, stmtCtx));
3626- if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType ()))
3627- if (lhs.getType () == refType.getElementType ())
3628- result = fir::LoadOp::create (builder, loc, result);
3640+ mlir::Value result = common::visit (
3641+ common::visitors{
3642+ [&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
3643+ convertCallToHLFIR (loc, converter, procRef, std::nullopt ,
3644+ symTable, stmtCtx);
3645+ auto outVal = fir::LoadOp::create (builder, loc, ompOutVar);
3646+ return outVal;
3647+ },
3648+ [&](const auto &expr) -> mlir::Value {
3649+ mlir::Value exprResult = fir::getBase (convertExprToValue (
3650+ loc, converter, evalExpr, symTable, stmtCtx));
3651+ // Optional load may be generated if we get a reference to the
3652+ // reduction type.
3653+ if (auto refType =
3654+ llvm::dyn_cast<fir::ReferenceType>(exprResult.getType ()))
3655+ if (lhs.getType () == refType.getElementType ())
3656+ exprResult = fir::LoadOp::create (builder, loc, exprResult);
3657+ return exprResult;
3658+ }},
3659+ evalExpr.u );
36293660 stmtCtx.finalizeAndPop ();
36303661 if (isByRef) {
36313662 fir::StoreOp::create (builder, loc, result, lhs);
0 commit comments