Skip to content

Commit 711fe33

Browse files
committed
[Flang] [OpenMP] Support for lowering task_reduction and in_reduction to MLIR
1 parent 66b2820 commit 711fe33

File tree

8 files changed

+307
-28
lines changed

8 files changed

+307
-28
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,7 @@ bool ClauseProcessor::processReduction(
10631063
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
10641064
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
10651065
ReductionProcessor rp;
1066-
rp.addDeclareReduction(
1066+
rp.addDeclareReduction<omp::clause::Reduction>(
10671067
currentLocation, converter, clause, reductionVars, reduceVarByRef,
10681068
reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
10691069

@@ -1085,6 +1085,80 @@ bool ClauseProcessor::processReduction(
10851085
});
10861086
}
10871087

1088+
bool ClauseProcessor::processTaskReduction(
1089+
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
1090+
llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
1091+
llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
1092+
return findRepeatableClause<omp::clause::TaskReduction>(
1093+
[&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
1094+
llvm::SmallVector<mlir::Value> taskReductionVars;
1095+
llvm::SmallVector<bool> taskReductionByref;
1096+
llvm::SmallVector<mlir::Attribute> taskReductionDeclSymbols;
1097+
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
1098+
ReductionProcessor rp;
1099+
rp.addDeclareReduction<omp::clause::TaskReduction>(
1100+
currentLocation, converter, clause, taskReductionVars,
1101+
taskReductionByref, taskReductionDeclSymbols,
1102+
outReductionSyms ? &taskReductionSyms : nullptr);
1103+
1104+
// Copy local lists into the output.
1105+
llvm::copy(taskReductionVars,
1106+
std::back_inserter(result.taskReductionVars));
1107+
llvm::copy(taskReductionByref,
1108+
std::back_inserter(result.taskReductionByref));
1109+
llvm::copy(taskReductionDeclSymbols,
1110+
std::back_inserter(result.taskReductionSyms));
1111+
1112+
if (outReductionTypes) {
1113+
outReductionTypes->reserve(outReductionTypes->size() +
1114+
taskReductionVars.size());
1115+
llvm::transform(taskReductionVars,
1116+
std::back_inserter(*outReductionTypes),
1117+
[](mlir::Value v) { return v.getType(); });
1118+
}
1119+
1120+
if (outReductionSyms)
1121+
llvm::copy(taskReductionSyms, std::back_inserter(*outReductionSyms));
1122+
});
1123+
}
1124+
1125+
bool ClauseProcessor::processInReduction(
1126+
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
1127+
llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
1128+
llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
1129+
return findRepeatableClause<omp::clause::InReduction>(
1130+
[&](const omp::clause::InReduction &clause,
1131+
const parser::CharBlock &source) {
1132+
llvm::SmallVector<mlir::Value> inReductionVars;
1133+
llvm::SmallVector<bool> inReductionByref;
1134+
llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
1135+
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
1136+
ReductionProcessor rp;
1137+
rp.addDeclareReduction<omp::clause::InReduction>(
1138+
currentLocation, converter, clause, inReductionVars,
1139+
inReductionByref, inReductionDeclSymbols,
1140+
outReductionSyms ? &inReductionSyms : nullptr);
1141+
1142+
// Copy local lists into the output.
1143+
llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
1144+
llvm::copy(inReductionByref,
1145+
std::back_inserter(result.inReductionByref));
1146+
llvm::copy(inReductionDeclSymbols,
1147+
std::back_inserter(result.inReductionSyms));
1148+
1149+
if (outReductionTypes) {
1150+
outReductionTypes->reserve(outReductionTypes->size() +
1151+
inReductionVars.size());
1152+
llvm::transform(inReductionVars,
1153+
std::back_inserter(*outReductionTypes),
1154+
[](mlir::Value v) { return v.getType(); });
1155+
}
1156+
1157+
if (outReductionSyms)
1158+
llvm::copy(inReductionSyms, std::back_inserter(*outReductionSyms));
1159+
});
1160+
}
1161+
10881162
bool ClauseProcessor::processTo(
10891163
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
10901164
return findRepeatableClause<omp::clause::To>(

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,16 @@ class ClauseProcessor {
129129
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
130130
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
131131
nullptr) const;
132+
bool processTaskReduction(
133+
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
134+
llvm::SmallVectorImpl<mlir::Type> *taskReductionTypes = nullptr,
135+
llvm::SmallVectorImpl<const semantics::Symbol *> *taskReductionSyms =
136+
nullptr) const;
137+
bool processInReduction(
138+
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
139+
llvm::SmallVectorImpl<mlir::Type> *inReductionTypes = nullptr,
140+
llvm::SmallVectorImpl<const semantics::Symbol *> *inReductionSyms =
141+
nullptr) const;
132142
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
133143
bool processUseDeviceAddr(
134144
lower::StatementContext &stmtCtx,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,33 +1243,37 @@ static void genTargetEnterExitUpdateDataClauses(
12431243
cp.processNowait(clauseOps);
12441244
}
12451245

1246-
static void genTaskClauses(lower::AbstractConverter &converter,
1247-
semantics::SemanticsContext &semaCtx,
1248-
lower::StatementContext &stmtCtx,
1249-
const List<Clause> &clauses, mlir::Location loc,
1250-
mlir::omp::TaskOperands &clauseOps) {
1246+
static void genTaskClauses(
1247+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1248+
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
1249+
mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
1250+
llvm::SmallVectorImpl<mlir::Type> &inReductionTypes,
1251+
llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
12511252
ClauseProcessor cp(converter, semaCtx, clauses);
12521253
cp.processAllocate(clauseOps);
12531254
cp.processDepend(clauseOps);
12541255
cp.processFinal(stmtCtx, clauseOps);
12551256
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
1257+
cp.processInReduction(loc, clauseOps, &inReductionTypes, &inReductionSyms);
12561258
cp.processMergeable(clauseOps);
12571259
cp.processPriority(stmtCtx, clauseOps);
12581260
cp.processUntied(clauseOps);
12591261
// TODO Support delayed privatization.
12601262

1261-
cp.processTODO<clause::Affinity, clause::Detach, clause::InReduction>(
1263+
cp.processTODO<clause::Affinity, clause::Detach>(
12621264
loc, llvm::omp::Directive::OMPD_task);
12631265
}
12641266

1265-
static void genTaskgroupClauses(lower::AbstractConverter &converter,
1266-
semantics::SemanticsContext &semaCtx,
1267-
const List<Clause> &clauses, mlir::Location loc,
1268-
mlir::omp::TaskgroupOperands &clauseOps) {
1267+
static void genTaskgroupClauses(
1268+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1269+
const List<Clause> &clauses, mlir::Location loc,
1270+
mlir::omp::TaskgroupOperands &clauseOps,
1271+
llvm::SmallVectorImpl<mlir::Type> &taskReductionTypes,
1272+
llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
12691273
ClauseProcessor cp(converter, semaCtx, clauses);
12701274
cp.processAllocate(clauseOps);
1271-
cp.processTODO<clause::TaskReduction>(loc,
1272-
llvm::omp::Directive::OMPD_taskgroup);
1275+
cp.processTaskReduction(loc, clauseOps, &taskReductionTypes,
1276+
&taskReductionSyms);
12731277
}
12741278

12751279
static void genTaskwaitClauses(lower::AbstractConverter &converter,
@@ -1869,13 +1873,26 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
18691873
ConstructQueue::const_iterator item) {
18701874
lower::StatementContext stmtCtx;
18711875
mlir::omp::TaskOperands clauseOps;
1872-
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
1876+
llvm::SmallVector<mlir::Type> inReductionTypes;
1877+
llvm::SmallVector<const semantics::Symbol *> inreductionSyms;
1878+
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
1879+
inReductionTypes, inreductionSyms);
18731880

1874-
return genOpWithBody<mlir::omp::TaskOp>(
1881+
auto reductionCallback = [&](mlir::Operation *op) {
1882+
genReductionVars(op, converter, loc, inreductionSyms, inReductionTypes);
1883+
return inreductionSyms;
1884+
};
1885+
1886+
auto taskOp = genOpWithBody<mlir::omp::TaskOp>(
18751887
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
18761888
llvm::omp::Directive::OMPD_task)
1877-
.setClauses(&item->clauses),
1889+
.setClauses(&item->clauses)
1890+
.setGenRegionEntryCb(reductionCallback),
18781891
queue, item, clauseOps);
1892+
// Add reduction variables as arguments
1893+
llvm::SmallVector<mlir::Location> blockArgLocs(inReductionTypes.size(), loc);
1894+
taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs);
1895+
return taskOp;
18791896
}
18801897

18811898
static mlir::omp::TaskgroupOp
@@ -1885,13 +1902,21 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
18851902
const ConstructQueue &queue,
18861903
ConstructQueue::const_iterator item) {
18871904
mlir::omp::TaskgroupOperands clauseOps;
1888-
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
1905+
llvm::SmallVector<mlir::Type> taskReductionTypes;
1906+
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
1907+
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
1908+
taskReductionTypes, taskReductionSyms);
18891909

1890-
return genOpWithBody<mlir::omp::TaskgroupOp>(
1910+
auto taskgroupOp = genOpWithBody<mlir::omp::TaskgroupOp>(
18911911
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
18921912
llvm::omp::Directive::OMPD_taskgroup)
18931913
.setClauses(&item->clauses),
18941914
queue, item, clauseOps);
1915+
1916+
// Add reduction variables as arguments
1917+
llvm::SmallVector<mlir::Location> blockArgLocs(taskReductionSyms.size(), loc);
1918+
taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs);
1919+
return taskgroupOp;
18951920
}
18961921

18971922
static mlir::omp::TaskwaitOp
@@ -2767,7 +2792,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
27672792
!std::holds_alternative<clause::ThreadLimit>(clause.u) &&
27682793
!std::holds_alternative<clause::Threads>(clause.u) &&
27692794
!std::holds_alternative<clause::UseDeviceAddr>(clause.u) &&
2770-
!std::holds_alternative<clause::UseDevicePtr>(clause.u)) {
2795+
!std::holds_alternative<clause::UseDevicePtr>(clause.u) &&
2796+
!std::holds_alternative<clause::TaskReduction>(clause.u) &&
2797+
!std::holds_alternative<clause::InReduction>(clause.u)) {
27712798
TODO(clauseLocation, "OpenMP Block construct clause");
27722799
}
27732800
}

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "flang/Parser/tools.h"
2525
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2626
#include "llvm/Support/CommandLine.h"
27+
#include <type_traits>
2728

2829
static llvm::cl::opt<bool> forceByrefReduction(
2930
"force-byref-reduction",
@@ -34,6 +35,32 @@ namespace Fortran {
3435
namespace lower {
3536
namespace omp {
3637

38+
// explicit template declarations
39+
template void ReductionProcessor::addDeclareReduction<omp::clause::Reduction>(
40+
mlir::Location currentLocation, lower::AbstractConverter &converter,
41+
const omp::clause::Reduction &reduction,
42+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
43+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
44+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
45+
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols);
46+
47+
template void
48+
ReductionProcessor::addDeclareReduction<omp::clause::TaskReduction>(
49+
mlir::Location currentLocation, lower::AbstractConverter &converter,
50+
const omp::clause::TaskReduction &reduction,
51+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
52+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
53+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
54+
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols);
55+
56+
template void ReductionProcessor::addDeclareReduction<omp::clause::InReduction>(
57+
mlir::Location currentLocation, lower::AbstractConverter &converter,
58+
const omp::clause::InReduction &reduction,
59+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
60+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
61+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
62+
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols);
63+
3764
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
3865
const omp::clause::ProcedureDesignator &pd) {
3966
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
@@ -716,22 +743,22 @@ static bool doReductionByRef(mlir::Value reductionVar) {
716743
return false;
717744
}
718745

746+
template <class T>
719747
void ReductionProcessor::addDeclareReduction(
720748
mlir::Location currentLocation, lower::AbstractConverter &converter,
721-
const omp::clause::Reduction &reduction,
722-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
749+
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
723750
llvm::SmallVectorImpl<bool> &reduceVarByRef,
724751
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
725752
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols) {
726753
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
727-
728-
if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
729-
reduction.t))
730-
TODO(currentLocation, "Reduction modifiers are not supported");
754+
if constexpr (std::is_same<T, omp::clause::Reduction>::value) {
755+
if (std::get<std::optional<typename T::ReductionModifier>>(reduction.t))
756+
TODO(currentLocation, "Reduction modifiers are not supported");
757+
}
731758

732759
mlir::omp::DeclareReductionOp decl;
733760
const auto &redOperatorList{
734-
std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
761+
std::get<typename T::ReductionIdentifiers>(reduction.t)};
735762
assert(redOperatorList.size() == 1 && "Expecting single operator");
736763
const auto &redOperator = redOperatorList.front();
737764
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ class ReductionProcessor {
120120

121121
/// Creates a reduction declaration and associates it with an OpenMP block
122122
/// directive.
123+
template <class T>
123124
static void addDeclareReduction(
124125
mlir::Location currentLocation, lower::AbstractConverter &converter,
125-
const omp::clause::Reduction &reduction,
126-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
126+
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
127127
llvm::SmallVectorImpl<bool> &reduceVarByRef,
128128
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
129129
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols =
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
3+
4+
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_Uxf32 : !fir.ref<!fir.box<!fir.array<?xf32>>> alloc {
5+
! [...]
6+
! CHECK: omp.yield
7+
! CHECK-LABEL: } init {
8+
! [...]
9+
! CHECK: omp.yield
10+
! CHECK-LABEL: } combiner {
11+
! [...]
12+
! CHECK: omp.yield
13+
! CHECK-LABEL: } cleanup {
14+
! [...]
15+
! CHECK: omp.yield
16+
! CHECK: }
17+
18+
! CHECK-LABEL: func.func @_QPtaskreduction
19+
! CHECK-SAME: (%[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
20+
! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
21+
! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]]
22+
! CHECK-SAME {uniq_name = "_QFtaskreductionEx"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
23+
! CHECK: omp.parallel {
24+
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
25+
! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
26+
! CHECK: omp.taskgroup task_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] -> %[[VAL_4:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>) {
27+
! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
28+
! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
29+
! CHECK: omp.task in_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
30+
! [...]
31+
! CHECK: omp.terminator
32+
! CHECK: }
33+
! CHECK: omp.terminator
34+
! CHECK: }
35+
! CHECK: omp.terminator
36+
! CHECK: }
37+
! CHECK: return
38+
! CHECK: }
39+
40+
subroutine taskReduction(x)
41+
real, dimension(:) :: x
42+
!$omp parallel
43+
!$omp taskgroup task_reduction(+:x)
44+
!$omp task in_reduction(+:x)
45+
x = x + 1
46+
!$omp end task
47+
!$omp end taskgroup
48+
!$omp end parallel
49+
end subroutine
50+

0 commit comments

Comments
 (0)