Skip to content

Commit d989ff9

Browse files
authored
[flang][OpenMP] Add lowering of subroutine calls in custom reduction combiners (#169808)
This patch adds support for lowering subroutine calls in custom reduction combiners to MLIR.
1 parent cc72171 commit d989ff9

File tree

2 files changed

+105
-14
lines changed

2 files changed

+105
-14
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "flang/Common/idioms.h"
2121
#include "flang/Evaluate/type.h"
2222
#include "flang/Lower/Bridge.h"
23+
#include "flang/Lower/ConvertCall.h"
2324
#include "flang/Lower/ConvertExpr.h"
2425
#include "flang/Lower/ConvertExprToHLFIR.h"
2526
#include "flang/Lower/ConvertVariable.h"
@@ -3582,19 +3583,32 @@ processReductionCombiner(lower::AbstractConverter &converter,
35823583
const parser::OmpStylizedInstance::Instance &instance =
35833584
std::get<parser::OmpStylizedInstance::Instance>(combinerInstance.t);
35843585

3585-
const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u);
3586-
if (!as) {
3587-
TODO(converter.getCurrentLocation(),
3588-
"A combiner that is a subroutine call is not yet supported");
3586+
std::optional<semantics::SomeExpr> evalExprOpt;
3587+
if (const auto *as = std::get_if<parser::AssignmentStmt>(&instance.u)) {
3588+
auto &expr = std::get<parser::Expr>(as->t);
3589+
evalExprOpt = makeExpr(expr, semaCtx);
3590+
} else if (const auto *call = std::get_if<parser::CallStmt>(&instance.u)) {
3591+
if (call->typedCall) {
3592+
const auto &procRef = *call->typedCall;
3593+
evalExprOpt = semantics::SomeExpr{procRef};
3594+
} else {
3595+
TODO(converter.getCurrentLocation(),
3596+
"CallStmt without typedCall is not yet supported");
3597+
}
3598+
} else {
3599+
TODO(converter.getCurrentLocation(), "Unsupported combiner instance type");
35893600
}
3590-
auto &expr = std::get<parser::Expr>(as->t);
3591-
genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
3592-
mlir::Type type, mlir::Value lhs, mlir::Value rhs,
3593-
bool isByRef) {
3594-
const auto &evalExpr = makeExpr(expr, semaCtx);
3601+
3602+
assert(evalExprOpt.has_value() && "evalExpr must be initialized");
3603+
semantics::SomeExpr evalExpr = *evalExprOpt;
3604+
3605+
genCombinerCB = [&, evalExpr](fir::FirOpBuilder &builder, mlir::Location loc,
3606+
mlir::Type type, mlir::Value lhs,
3607+
mlir::Value rhs, bool isByRef) {
35953608
lower::SymMapScope scope(symTable);
35963609
const std::list<parser::OmpStylizedDeclaration> &declList =
35973610
std::get<std::list<parser::OmpStylizedDeclaration>>(combinerInstance.t);
3611+
mlir::Value ompOutVar;
35983612
for (const parser::OmpStylizedDeclaration &decl : declList) {
35993613
auto &name = std::get<parser::ObjectName>(decl.var.t);
36003614
mlir::Value addr = lhs;
@@ -3617,15 +3631,32 @@ processReductionCombiner(lower::AbstractConverter &converter,
36173631
auto declareOp =
36183632
hlfir::DeclareOp::create(builder, loc, addr, name.ToString(), nullptr,
36193633
{}, nullptr, nullptr, 0, attributes);
3634+
if (name.ToString() == "omp_out")
3635+
ompOutVar = declareOp.getResult(0);
36203636
symTable.addVariableDefinition(*name.symbol, declareOp);
36213637
}
36223638

36233639
lower::StatementContext stmtCtx;
3624-
mlir::Value result = fir::getBase(
3625-
convertExprToValue(loc, converter, evalExpr, symTable, stmtCtx));
3626-
if (auto refType = llvm::dyn_cast<fir::ReferenceType>(result.getType()))
3627-
if (lhs.getType() == refType.getElementType())
3628-
result = fir::LoadOp::create(builder, loc, result);
3640+
mlir::Value result = common::visit(
3641+
common::visitors{
3642+
[&](const evaluate::ProcedureRef &procRef) -> mlir::Value {
3643+
convertCallToHLFIR(loc, converter, procRef, std::nullopt,
3644+
symTable, stmtCtx);
3645+
auto outVal = fir::LoadOp::create(builder, loc, ompOutVar);
3646+
return outVal;
3647+
},
3648+
[&](const auto &expr) -> mlir::Value {
3649+
mlir::Value exprResult = fir::getBase(convertExprToValue(
3650+
loc, converter, evalExpr, symTable, stmtCtx));
3651+
// If the expression lowers to a reference whose element type matches the
3652+
// type of lhs, load it so that a value is produced.
3653+
if (auto refType =
3654+
llvm::dyn_cast<fir::ReferenceType>(exprResult.getType()))
3655+
if (lhs.getType() == refType.getElementType())
3656+
exprResult = fir::LoadOp::create(builder, loc, exprResult);
3657+
return exprResult;
3658+
}},
3659+
evalExpr.u);
36293660
stmtCtx.finalizeAndPop();
36303661
if (isByRef) {
36313662
fir::StoreOp::create(builder, loc, result, lhs);
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
! This test checks lowering of the OpenMP declare reduction directive when the
2+
! combiner is a subroutine call.
3+
4+
!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
5+
6+
subroutine combine_me(out, in)
7+
integer out, in
8+
out = out + in
9+
end subroutine combine_me
10+
11+
function func(x, n)
12+
integer func
13+
integer x(n)
14+
integer res
15+
interface
16+
subroutine combine_me(out, in)
17+
integer out, in
18+
end subroutine combine_me
19+
end interface
20+
!CHECK: omp.declare_reduction @red_add : i32 init {
21+
!CHECK: ^bb0(%[[OMP_ORIG_ARG_I:.*]]: i32):
22+
!CHECK: %[[OMP_PRIV:.*]] = fir.alloca i32
23+
!CHECK: %[[OMP_ORIG:.*]] = fir.alloca i32
24+
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_ORIG]] : !fir.ref<i32>
25+
!CHECK: %[[OMP_ORIG_DECL:.*]]:2 = hlfir.declare %[[OMP_ORIG]] {uniq_name = "omp_orig"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
26+
!CHECK: fir.store %[[OMP_ORIG_ARG_I]] to %[[OMP_PRIV]] : !fir.ref<i32>
27+
!CHECK: %[[OMP_PRIV_DECL:.*]]:2 = hlfir.declare %[[OMP_PRIV]] {uniq_name = "omp_priv"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
28+
!CHECK: %[[CONST_0:.*]] = arith.constant 0 : i32
29+
!CHECK: omp.yield(%[[CONST_0]] : i32)
30+
!CHECK: } combiner {
31+
!CHECK: ^bb0(%[[LHS_ARG:.*]]: i32, %[[RHS_ARG:.*]]: i32):
32+
!CHECK: %[[OMP_OUT:.*]] = fir.alloca i32
33+
!CHECK: %[[OMP_IN:.*]] = fir.alloca i32
34+
!CHECK: fir.store %[[RHS_ARG]] to %[[OMP_IN]] : !fir.ref<i32>
35+
!CHECK: %[[OMP_IN_DECL:.*]]:2 = hlfir.declare %[[OMP_IN]] {uniq_name = "omp_in"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
36+
!CHECK: fir.store %[[LHS_ARG]] to %[[OMP_OUT]] : !fir.ref<i32>
37+
!CHECK: %[[OMP_OUT_DECL:.*]]:2 = hlfir.declare %[[OMP_OUT]] {uniq_name = "omp_out"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
38+
!CHECK: fir.call @_QPcombine_me(%[[OMP_OUT_DECL]]#0, %[[OMP_IN_DECL]]#0) fastmath<contract> : (!fir.ref<i32>, !fir.ref<i32>) -> ()
39+
!CHECK: %[[OMP_OUT_VAL:.*]] = fir.load %[[OMP_OUT_DECL]]#0 : !fir.ref<i32>
40+
!CHECK: omp.yield(%[[OMP_OUT_VAL]] : i32)
41+
!CHECK: }
42+
!CHECK: func.func @_QPcombine_me(%[[OUT:.*]]: !fir.ref<i32> {fir.bindc_name = "out"}, %[[IN:.*]]: !fir.ref<i32> {fir.bindc_name = "in"}) {
43+
!CHECK: %[[SCOPE:.*]] = fir.dummy_scope : !fir.dscope
44+
!CHECK: %[[IN_DECL:.*]]:2 = hlfir.declare %[[IN]] dummy_scope %[[SCOPE]] arg 2 {uniq_name = "_QFcombine_meEin"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
45+
!CHECK: %[[OUT_DECL:.*]]:2 = hlfir.declare %[[OUT]] dummy_scope %[[SCOPE]] arg 1 {uniq_name = "_QFcombine_meEout"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
46+
!CHECK: %[[OUT_VAL:.*]] = fir.load %[[OUT_DECL]]#0 : !fir.ref<i32>
47+
!CHECK: %[[IN_VAL:.*]] = fir.load %[[IN_DECL]]#0 : !fir.ref<i32>
48+
!CHECK: %[[SUM:.*]] = arith.addi %[[OUT_VAL]], %[[IN_VAL]] : i32
49+
!CHECK: hlfir.assign %[[SUM]] to %[[OUT_DECL]]#0 : i32, !fir.ref<i32>
50+
!CHECK: return
51+
!CHECK: }
52+
!$omp declare reduction(red_add:integer(4):combine_me(omp_out,omp_in)) initializer(omp_priv=0)
53+
res=0
54+
!$omp simd reduction(red_add:res)
55+
do i=1,n
56+
res=res+x(i)
57+
enddo
58+
func=res
59+
end function func
60+

0 commit comments

Comments
 (0)