Skip to content

Commit e31a990

Browse files
committed
[Flang][OpenMP]Support for lowering task_reduction and in_reduction to MLIR
1 parent 0678e20 commit e31a990

12 files changed

+305
-59
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,30 @@ bool ClauseProcessor::processIsDevicePtr(
906906
});
907907
}
908908

909+
bool ClauseProcessor::processInReduction(
910+
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
911+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
912+
return findRepeatableClause<omp::clause::InReduction>(
913+
[&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
914+
llvm::SmallVector<mlir::Value> inReductionVars;
915+
llvm::SmallVector<bool> inReduceVarByRef;
916+
llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
917+
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
918+
ReductionProcessor rp;
919+
rp.addDeclareReduction<omp::clause::InReduction>(
920+
currentLocation, converter, clause, inReductionVars,
921+
inReduceVarByRef, inReductionDeclSymbols, inReductionSyms);
922+
923+
// Copy local lists into the output.
924+
llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
925+
llvm::copy(inReduceVarByRef,
926+
std::back_inserter(result.inReductionByref));
927+
llvm::copy(inReductionDeclSymbols,
928+
std::back_inserter(result.inReductionSyms));
929+
llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
930+
});
931+
}
932+
909933
bool ClauseProcessor::processLink(
910934
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
911935
return findRepeatableClause<omp::clause::Link>(
@@ -1116,9 +1140,10 @@ bool ClauseProcessor::processReduction(
11161140
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
11171141
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
11181142
ReductionProcessor rp;
1119-
rp.addDeclareReduction(currentLocation, converter, clause,
1120-
reductionVars, reduceVarByRef,
1121-
reductionDeclSymbols, reductionSyms);
1143+
1144+
rp.addDeclareReduction<omp::clause::Reduction>(
1145+
currentLocation, converter, clause, reductionVars, reduceVarByRef,
1146+
reductionDeclSymbols, reductionSyms);
11221147

11231148
// Copy local lists into the output.
11241149
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
@@ -1129,6 +1154,30 @@ bool ClauseProcessor::processReduction(
11291154
});
11301155
}
11311156

1157+
bool ClauseProcessor::processTaskReduction(
1158+
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
1159+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
1160+
return findRepeatableClause<omp::clause::TaskReduction>(
1161+
[&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
1162+
llvm::SmallVector<mlir::Value> taskReductionVars;
1163+
llvm::SmallVector<bool> TaskReduceVarByRef;
1164+
llvm::SmallVector<mlir::Attribute> TaskReductionDeclSymbols;
1165+
llvm::SmallVector<const semantics::Symbol *> TaskReductionSyms;
1166+
ReductionProcessor rp;
1167+
rp.addDeclareReduction<omp::clause::TaskReduction>(
1168+
currentLocation, converter, clause, taskReductionVars,
1169+
TaskReduceVarByRef, TaskReductionDeclSymbols, TaskReductionSyms);
1170+
// Copy local lists into the output.
1171+
llvm::copy(taskReductionVars,
1172+
std::back_inserter(result.taskReductionVars));
1173+
llvm::copy(TaskReduceVarByRef,
1174+
std::back_inserter(result.taskReductionByref));
1175+
llvm::copy(TaskReductionDeclSymbols,
1176+
std::back_inserter(result.taskReductionSyms));
1177+
llvm::copy(TaskReductionSyms, std::back_inserter(outReductionSyms));
1178+
});
1179+
}
1180+
11321181
bool ClauseProcessor::processTo(
11331182
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
11341183
return findRepeatableClause<omp::clause::To>(

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ class ClauseProcessor {
104104
bool processIsDevicePtr(
105105
mlir::omp::IsDevicePtrClauseOps &result,
106106
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
107+
bool processInReduction(
108+
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
109+
llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) const;
107110
bool
108111
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
109112

@@ -122,6 +125,10 @@ class ClauseProcessor {
122125
bool processReduction(
123126
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
124127
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
128+
bool processTaskReduction(mlir::Location currentLocation,
129+
mlir::omp::TaskReductionClauseOps &result,
130+
llvm::SmallVectorImpl<const semantics::Symbol *>
131+
&TaskReductionSyms) const;
125132
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
126133
bool processUseDeviceAddr(
127134
lower::StatementContext &stmtCtx,

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,13 @@ void DataSharingProcessor::collectSymbols(
344344
// Collect all symbols referenced in the evaluation being processed,
345345
// that matches 'flag'.
346346
llvm::SetVector<const semantics::Symbol *> allSymbols;
347+
bool collectSymbols = true;
348+
for (const omp::Clause &clause : clauses) {
349+
if (clause.id == llvm::omp::Clause::OMPC_in_reduction)
350+
collectSymbols = false;
351+
}
347352
converter.collectSymbolSet(eval, allSymbols, flag,
348-
/*collectSymbols=*/true,
353+
/*collectSymbols=*/collectSymbols,
349354
/*collectHostAssociatedSymbols=*/true);
350355

351356
llvm::SetVector<const semantics::Symbol *> symbolsInNestedRegions;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,11 +1248,11 @@ static void genTargetEnterExitUpdateDataClauses(
12481248
cp.processNowait(clauseOps);
12491249
}
12501250

1251-
static void genTaskClauses(lower::AbstractConverter &converter,
1252-
semantics::SemanticsContext &semaCtx,
1253-
lower::StatementContext &stmtCtx,
1254-
const List<Clause> &clauses, mlir::Location loc,
1255-
mlir::omp::TaskOperands &clauseOps) {
1251+
static void genTaskClauses(
1252+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1253+
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
1254+
mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
1255+
llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) {
12561256
ClauseProcessor cp(converter, semaCtx, clauses);
12571257
cp.processAllocate(clauseOps);
12581258
cp.processDepend(clauseOps);
@@ -1261,20 +1261,21 @@ static void genTaskClauses(lower::AbstractConverter &converter,
12611261
cp.processMergeable(clauseOps);
12621262
cp.processPriority(stmtCtx, clauseOps);
12631263
cp.processUntied(clauseOps);
1264+
cp.processInReduction(loc, clauseOps, InReductionSyms);
12641265
// TODO Support delayed privatization.
12651266

1266-
cp.processTODO<clause::Affinity, clause::Detach, clause::InReduction>(
1267+
cp.processTODO<clause::Affinity, clause::Detach>(
12671268
loc, llvm::omp::Directive::OMPD_task);
12681269
}
12691270

1270-
static void genTaskgroupClauses(lower::AbstractConverter &converter,
1271-
semantics::SemanticsContext &semaCtx,
1272-
const List<Clause> &clauses, mlir::Location loc,
1273-
mlir::omp::TaskgroupOperands &clauseOps) {
1271+
static void genTaskgroupClauses(
1272+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1273+
const List<Clause> &clauses, mlir::Location loc,
1274+
mlir::omp::TaskgroupOperands &clauseOps,
1275+
llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
12741276
ClauseProcessor cp(converter, semaCtx, clauses);
12751277
cp.processAllocate(clauseOps);
1276-
cp.processTODO<clause::TaskReduction>(loc,
1277-
llvm::omp::Directive::OMPD_taskgroup);
1278+
cp.processTaskReduction(loc, clauseOps, taskReductionSyms);
12781279
}
12791280

12801281
static void genTaskwaitClauses(lower::AbstractConverter &converter,
@@ -1884,7 +1885,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
18841885
ConstructQueue::const_iterator item) {
18851886
lower::StatementContext stmtCtx;
18861887
mlir::omp::TaskOperands clauseOps;
1887-
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
1888+
llvm::SmallVector<const semantics::Symbol *> InReductionSyms;
1889+
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
1890+
InReductionSyms);
18881891

18891892
if (!enableDelayedPrivatization)
18901893
return genOpWithBody<mlir::omp::TaskOp>(
@@ -1901,22 +1904,35 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19011904
EntryBlockArgs taskArgs;
19021905
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
19031906
taskArgs.priv.vars = clauseOps.privateVars;
1907+
taskArgs.inReduction.syms = InReductionSyms;
1908+
taskArgs.inReduction.vars = clauseOps.inReductionVars;
19041909

19051910
auto genRegionEntryCB = [&](mlir::Operation *op) {
19061911
genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0));
19071912
bindEntryBlockArgs(converter,
19081913
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
19091914
taskArgs);
1910-
return llvm::to_vector(taskArgs.priv.syms);
1915+
return llvm::to_vector(taskArgs.getSyms());
19111916
};
19121917

1913-
return genOpWithBody<mlir::omp::TaskOp>(
1918+
OpWithBodyGenInfo genInfo =
19141919
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
19151920
llvm::omp::Directive::OMPD_task)
19161921
.setClauses(&item->clauses)
19171922
.setDataSharingProcessor(&dsp)
1918-
.setGenRegionEntryCb(genRegionEntryCB),
1919-
queue, item, clauseOps);
1923+
.setGenRegionEntryCb(genRegionEntryCB);
1924+
1925+
auto taskOp =
1926+
genOpWithBody<mlir::omp::TaskOp>(genInfo, queue, item, clauseOps);
1927+
1928+
llvm::SmallVector<mlir::Type> inReductionTypes;
1929+
for (const auto &inreductionVar : clauseOps.inReductionVars)
1930+
inReductionTypes.push_back(inreductionVar.getType());
1931+
1932+
// Add reduction variables as entry block arguments to the task region
1933+
llvm::SmallVector<mlir::Location> blockArgLocs(InReductionSyms.size(), loc);
1934+
taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs);
1935+
return taskOp;
19201936
}
19211937

19221938
static mlir::omp::TaskgroupOp
@@ -1926,13 +1942,26 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19261942
const ConstructQueue &queue,
19271943
ConstructQueue::const_iterator item) {
19281944
mlir::omp::TaskgroupOperands clauseOps;
1929-
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
1945+
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
1946+
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
1947+
taskReductionSyms);
19301948

1931-
return genOpWithBody<mlir::omp::TaskgroupOp>(
1949+
OpWithBodyGenInfo genInfo =
19321950
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
19331951
llvm::omp::Directive::OMPD_taskgroup)
1934-
.setClauses(&item->clauses),
1935-
queue, item, clauseOps);
1952+
.setClauses(&item->clauses);
1953+
1954+
auto taskgroupOp =
1955+
genOpWithBody<mlir::omp::TaskgroupOp>(genInfo, queue, item, clauseOps);
1956+
1957+
llvm::SmallVector<mlir::Type> taskReductionTypes;
1958+
for (const auto &taskreductionVar : clauseOps.taskReductionVars)
1959+
taskReductionTypes.push_back(taskreductionVar.getType());
1960+
1961+
// Add reduction variables as entry block arguments to the taskgroup region
1962+
llvm::SmallVector<mlir::Location> blockArgLocs(taskReductionSyms.size(), loc);
1963+
taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs);
1964+
return taskgroupOp;
19361965
}
19371966

19381967
static mlir::omp::TaskwaitOp

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "flang/Parser/tools.h"
2525
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2626
#include "llvm/Support/CommandLine.h"
27+
#include <type_traits>
2728

2829
static llvm::cl::opt<bool> forceByrefReduction(
2930
"force-byref-reduction",
@@ -34,6 +35,32 @@ namespace Fortran {
3435
namespace lower {
3536
namespace omp {
3637

38+
// explicit template declarations
39+
template void ReductionProcessor::addDeclareReduction<omp::clause::Reduction>(
40+
mlir::Location currentLocation, lower::AbstractConverter &converter,
41+
const omp::clause::Reduction &reduction,
42+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
43+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
44+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
45+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
46+
47+
template void
48+
ReductionProcessor::addDeclareReduction<omp::clause::TaskReduction>(
49+
mlir::Location currentLocation, lower::AbstractConverter &converter,
50+
const omp::clause::TaskReduction &reduction,
51+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
52+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
53+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
54+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
55+
56+
template void ReductionProcessor::addDeclareReduction<omp::clause::InReduction>(
57+
mlir::Location currentLocation, lower::AbstractConverter &converter,
58+
const omp::clause::InReduction &reduction,
59+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
60+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
61+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
62+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
63+
3764
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
3865
const omp::clause::ProcedureDesignator &pd) {
3966
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
@@ -716,22 +743,22 @@ static bool doReductionByRef(mlir::Value reductionVar) {
716743
return false;
717744
}
718745

746+
template <class T>
719747
void ReductionProcessor::addDeclareReduction(
720748
mlir::Location currentLocation, lower::AbstractConverter &converter,
721-
const omp::clause::Reduction &reduction,
722-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
749+
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
723750
llvm::SmallVectorImpl<bool> &reduceVarByRef,
724751
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
725752
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
726753
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
727-
728-
if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
729-
reduction.t))
730-
TODO(currentLocation, "Reduction modifiers are not supported");
754+
if constexpr (std::is_same<T, omp::clause::Reduction>::value) {
755+
if (std::get<std::optional<typename T::ReductionModifier>>(reduction.t))
756+
TODO(currentLocation, "Reduction modifiers are not supported");
757+
}
731758

732759
mlir::omp::DeclareReductionOp decl;
733760
const auto &redOperatorList{
734-
std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
761+
std::get<typename T::ReductionIdentifiers>(reduction.t)};
735762
assert(redOperatorList.size() == 1 && "Expecting single operator");
736763
const auto &redOperator = redOperatorList.front();
737764
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ class ReductionProcessor {
120120

121121
/// Creates a reduction declaration and associates it with an OpenMP block
122122
/// directive.
123+
template <class T>
123124
static void addDeclareReduction(
124125
mlir::Location currentLocation, lower::AbstractConverter &converter,
125-
const omp::clause::Reduction &reduction,
126-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
126+
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
127127
llvm::SmallVectorImpl<bool> &reduceVarByRef,
128128
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
129129
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);

flang/test/Lower/OpenMP/Todo/task-inreduction.f90

Lines changed: 0 additions & 15 deletions
This file was deleted.

flang/test/Lower/OpenMP/Todo/taskgroup-task-reduction.f90

Lines changed: 0 additions & 10 deletions
This file was deleted.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
3+
4+
!CHECK-LABEL: omp.declare_reduction
5+
!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
6+
!CHECK: ^bb0(%{{.*}}: i32):
7+
!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
8+
!CHECK: omp.yield(%[[C0_1]] : i32)
9+
!CHECK: } combiner {
10+
!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
11+
!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
12+
!CHECK: omp.yield(%[[RES]] : i32)
13+
!CHECK: }
14+
15+
!CHECK-LABEL: func.func @_QPomp_task_in_reduction() {
16+
! [...]
17+
!CHECK: omp.task in_reduction(@[[RED_I32_NAME]] %[[VAL_1:.*]]#0 -> %[[ARG0]] : !fir.ref<i32>) {
18+
!CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[ARG0]]
19+
!CHECK-SAME: {uniq_name = "_QFomp_task_in_reductionEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
20+
!CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<i32>
21+
!CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
22+
!CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
23+
!CHECK: hlfir.assign %[[VAL_7]] to %[[VAL_4]]#0 : i32, !fir.ref<i32>
24+
!CHECK: omp.terminator
25+
!CHECK: }
26+
!CHECK: return
27+
!CHECK: }
28+
29+
subroutine omp_task_in_reduction()
30+
integer i
31+
i = 0
32+
!$omp task in_reduction(+:i)
33+
i = i + 1
34+
!$omp end task
35+
end subroutine omp_task_in_reduction

0 commit comments

Comments
 (0)