Skip to content

Commit 00f3f5d

Browse files
committed
[Flang][OpenMP]Support for lowering task_reduction and in_reduction to MLIR
1 parent 14dcf82 commit 00f3f5d

12 files changed

+305
-59
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,30 @@ bool ClauseProcessor::processIsDevicePtr(
916916
});
917917
}
918918

919+
bool ClauseProcessor::processInReduction(
920+
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
921+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
922+
return findRepeatableClause<omp::clause::InReduction>(
923+
[&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
924+
llvm::SmallVector<mlir::Value> inReductionVars;
925+
llvm::SmallVector<bool> inReduceVarByRef;
926+
llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
927+
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
928+
ReductionProcessor rp;
929+
rp.addDeclareReduction<omp::clause::InReduction>(
930+
currentLocation, converter, clause, inReductionVars,
931+
inReduceVarByRef, inReductionDeclSymbols, inReductionSyms);
932+
933+
// Copy local lists into the output.
934+
llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
935+
llvm::copy(inReduceVarByRef,
936+
std::back_inserter(result.inReductionByref));
937+
llvm::copy(inReductionDeclSymbols,
938+
std::back_inserter(result.inReductionSyms));
939+
llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
940+
});
941+
}
942+
919943
bool ClauseProcessor::processLink(
920944
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
921945
return findRepeatableClause<omp::clause::Link>(
@@ -1126,9 +1150,10 @@ bool ClauseProcessor::processReduction(
11261150
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
11271151
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
11281152
ReductionProcessor rp;
1129-
rp.addDeclareReduction(currentLocation, converter, clause,
1130-
reductionVars, reduceVarByRef,
1131-
reductionDeclSymbols, reductionSyms);
1153+
1154+
rp.addDeclareReduction<omp::clause::Reduction>(
1155+
currentLocation, converter, clause, reductionVars, reduceVarByRef,
1156+
reductionDeclSymbols, reductionSyms);
11321157

11331158
// Copy local lists into the output.
11341159
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
@@ -1139,6 +1164,30 @@ bool ClauseProcessor::processReduction(
11391164
});
11401165
}
11411166

1167+
bool ClauseProcessor::processTaskReduction(
1168+
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
1169+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
1170+
return findRepeatableClause<omp::clause::TaskReduction>(
1171+
[&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
1172+
llvm::SmallVector<mlir::Value> taskReductionVars;
1173+
llvm::SmallVector<bool> TaskReduceVarByRef;
1174+
llvm::SmallVector<mlir::Attribute> TaskReductionDeclSymbols;
1175+
llvm::SmallVector<const semantics::Symbol *> TaskReductionSyms;
1176+
ReductionProcessor rp;
1177+
rp.addDeclareReduction<omp::clause::TaskReduction>(
1178+
currentLocation, converter, clause, taskReductionVars,
1179+
TaskReduceVarByRef, TaskReductionDeclSymbols, TaskReductionSyms);
1180+
// Copy local lists into the output.
1181+
llvm::copy(taskReductionVars,
1182+
std::back_inserter(result.taskReductionVars));
1183+
llvm::copy(TaskReduceVarByRef,
1184+
std::back_inserter(result.taskReductionByref));
1185+
llvm::copy(TaskReductionDeclSymbols,
1186+
std::back_inserter(result.taskReductionSyms));
1187+
llvm::copy(TaskReductionSyms, std::back_inserter(outReductionSyms));
1188+
});
1189+
}
1190+
11421191
bool ClauseProcessor::processTo(
11431192
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
11441193
return findRepeatableClause<omp::clause::To>(

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ class ClauseProcessor {
105105
bool processIsDevicePtr(
106106
mlir::omp::IsDevicePtrClauseOps &result,
107107
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
108+
bool processInReduction(
109+
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
110+
llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) const;
108111
bool
109112
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
110113

@@ -123,6 +126,10 @@ class ClauseProcessor {
123126
bool processReduction(
124127
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
125128
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
129+
bool processTaskReduction(mlir::Location currentLocation,
130+
mlir::omp::TaskReductionClauseOps &result,
131+
llvm::SmallVectorImpl<const semantics::Symbol *>
132+
&TaskReductionSyms) const;
126133
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
127134
bool processUseDeviceAddr(
128135
lower::StatementContext &stmtCtx,

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,13 @@ void DataSharingProcessor::collectSymbols(
344344
// Collect all symbols referenced in the evaluation being processed,
345345
// that matches 'flag'.
346346
llvm::SetVector<const semantics::Symbol *> allSymbols;
347+
bool collectSymbols = true;
348+
for (const omp::Clause &clause : clauses) {
349+
if (clause.id == llvm::omp::Clause::OMPC_in_reduction)
350+
collectSymbols = false;
351+
}
347352
converter.collectSymbolSet(eval, allSymbols, flag,
348-
/*collectSymbols=*/true,
353+
/*collectSymbols=*/collectSymbols,
349354
/*collectHostAssociatedSymbols=*/true);
350355

351356
llvm::SetVector<const semantics::Symbol *> symbolsInNestedRegions;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,34 +1248,35 @@ static void genTargetEnterExitUpdateDataClauses(
12481248
cp.processNowait(clauseOps);
12491249
}
12501250

1251-
static void genTaskClauses(lower::AbstractConverter &converter,
1252-
semantics::SemanticsContext &semaCtx,
1253-
lower::StatementContext &stmtCtx,
1254-
const List<Clause> &clauses, mlir::Location loc,
1255-
mlir::omp::TaskOperands &clauseOps) {
1251+
static void genTaskClauses(
1252+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1253+
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
1254+
mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
1255+
llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) {
12561256
ClauseProcessor cp(converter, semaCtx, clauses);
12571257
cp.processAllocate(clauseOps);
12581258
cp.processDepend(clauseOps);
12591259
cp.processFinal(stmtCtx, clauseOps);
12601260
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
1261+
cp.processInReduction(loc, clauseOps, InReductionSyms);
12611262
cp.processMergeable(clauseOps);
12621263
cp.processPriority(stmtCtx, clauseOps);
12631264
cp.processUntied(clauseOps);
12641265
cp.processDetach(clauseOps);
12651266
// TODO Support delayed privatization.
12661267

1267-
cp.processTODO<clause::Affinity, clause::InReduction>(
1268+
cp.processTODO<clause::Affinity>(
12681269
loc, llvm::omp::Directive::OMPD_task);
12691270
}
12701271

1271-
static void genTaskgroupClauses(lower::AbstractConverter &converter,
1272-
semantics::SemanticsContext &semaCtx,
1273-
const List<Clause> &clauses, mlir::Location loc,
1274-
mlir::omp::TaskgroupOperands &clauseOps) {
1272+
static void genTaskgroupClauses(
1273+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1274+
const List<Clause> &clauses, mlir::Location loc,
1275+
mlir::omp::TaskgroupOperands &clauseOps,
1276+
llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
12751277
ClauseProcessor cp(converter, semaCtx, clauses);
12761278
cp.processAllocate(clauseOps);
1277-
cp.processTODO<clause::TaskReduction>(loc,
1278-
llvm::omp::Directive::OMPD_taskgroup);
1279+
cp.processTaskReduction(loc, clauseOps, taskReductionSyms);
12791280
}
12801281

12811282
static void genTaskwaitClauses(lower::AbstractConverter &converter,
@@ -1885,7 +1886,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
18851886
ConstructQueue::const_iterator item) {
18861887
lower::StatementContext stmtCtx;
18871888
mlir::omp::TaskOperands clauseOps;
1888-
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
1889+
llvm::SmallVector<const semantics::Symbol *> InReductionSyms;
1890+
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
1891+
InReductionSyms);
18891892

18901893
if (!enableDelayedPrivatization)
18911894
return genOpWithBody<mlir::omp::TaskOp>(
@@ -1902,22 +1905,35 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19021905
EntryBlockArgs taskArgs;
19031906
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
19041907
taskArgs.priv.vars = clauseOps.privateVars;
1908+
taskArgs.inReduction.syms = InReductionSyms;
1909+
taskArgs.inReduction.vars = clauseOps.inReductionVars;
19051910

19061911
auto genRegionEntryCB = [&](mlir::Operation *op) {
19071912
genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0));
19081913
bindEntryBlockArgs(converter,
19091914
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
19101915
taskArgs);
1911-
return llvm::to_vector(taskArgs.priv.syms);
1916+
return llvm::to_vector(taskArgs.getSyms());
19121917
};
19131918

1914-
return genOpWithBody<mlir::omp::TaskOp>(
1919+
OpWithBodyGenInfo genInfo =
19151920
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
19161921
llvm::omp::Directive::OMPD_task)
19171922
.setClauses(&item->clauses)
19181923
.setDataSharingProcessor(&dsp)
1919-
.setGenRegionEntryCb(genRegionEntryCB),
1920-
queue, item, clauseOps);
1924+
.setGenRegionEntryCb(genRegionEntryCB);
1925+
1926+
auto taskOp =
1927+
genOpWithBody<mlir::omp::TaskOp>(genInfo, queue, item, clauseOps);
1928+
1929+
llvm::SmallVector<mlir::Type> inReductionTypes;
1930+
for (const auto &inreductionVar : clauseOps.inReductionVars)
1931+
inReductionTypes.push_back(inreductionVar.getType());
1932+
1933+
// Add reduction variables as entry block arguments to the task region
1934+
llvm::SmallVector<mlir::Location> blockArgLocs(InReductionSyms.size(), loc);
1935+
taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs);
1936+
return taskOp;
19211937
}
19221938

19231939
static mlir::omp::TaskgroupOp
@@ -1927,13 +1943,26 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19271943
const ConstructQueue &queue,
19281944
ConstructQueue::const_iterator item) {
19291945
mlir::omp::TaskgroupOperands clauseOps;
1930-
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
1946+
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
1947+
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
1948+
taskReductionSyms);
19311949

1932-
return genOpWithBody<mlir::omp::TaskgroupOp>(
1950+
OpWithBodyGenInfo genInfo =
19331951
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
19341952
llvm::omp::Directive::OMPD_taskgroup)
1935-
.setClauses(&item->clauses),
1936-
queue, item, clauseOps);
1953+
.setClauses(&item->clauses);
1954+
1955+
auto taskgroupOp =
1956+
genOpWithBody<mlir::omp::TaskgroupOp>(genInfo, queue, item, clauseOps);
1957+
1958+
llvm::SmallVector<mlir::Type> taskReductionTypes;
1959+
for (const auto &taskreductionVar : clauseOps.taskReductionVars)
1960+
taskReductionTypes.push_back(taskreductionVar.getType());
1961+
1962+
// Add reduction variables as entry block arguments to the taskgroup region
1963+
llvm::SmallVector<mlir::Location> blockArgLocs(taskReductionSyms.size(), loc);
1964+
taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs);
1965+
return taskgroupOp;
19371966
}
19381967

19391968
static mlir::omp::TaskwaitOp

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "flang/Parser/tools.h"
2525
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2626
#include "llvm/Support/CommandLine.h"
27+
#include <type_traits>
2728

2829
static llvm::cl::opt<bool> forceByrefReduction(
2930
"force-byref-reduction",
@@ -34,6 +35,32 @@ namespace Fortran {
3435
namespace lower {
3536
namespace omp {
3637

38+
// explicit template declarations
39+
template void ReductionProcessor::addDeclareReduction<omp::clause::Reduction>(
40+
mlir::Location currentLocation, lower::AbstractConverter &converter,
41+
const omp::clause::Reduction &reduction,
42+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
43+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
44+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
45+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
46+
47+
template void
48+
ReductionProcessor::addDeclareReduction<omp::clause::TaskReduction>(
49+
mlir::Location currentLocation, lower::AbstractConverter &converter,
50+
const omp::clause::TaskReduction &reduction,
51+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
52+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
53+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
54+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
55+
56+
template void ReductionProcessor::addDeclareReduction<omp::clause::InReduction>(
57+
mlir::Location currentLocation, lower::AbstractConverter &converter,
58+
const omp::clause::InReduction &reduction,
59+
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
60+
llvm::SmallVectorImpl<bool> &reduceVarByRef,
61+
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
62+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
63+
3764
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
3865
const omp::clause::ProcedureDesignator &pd) {
3966
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
@@ -716,22 +743,22 @@ static bool doReductionByRef(mlir::Value reductionVar) {
716743
return false;
717744
}
718745

746+
template <class T>
719747
void ReductionProcessor::addDeclareReduction(
720748
mlir::Location currentLocation, lower::AbstractConverter &converter,
721-
const omp::clause::Reduction &reduction,
722-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
749+
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
723750
llvm::SmallVectorImpl<bool> &reduceVarByRef,
724751
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
725752
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
726753
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
727-
728-
if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
729-
reduction.t))
730-
TODO(currentLocation, "Reduction modifiers are not supported");
754+
if constexpr (std::is_same<T, omp::clause::Reduction>::value) {
755+
if (std::get<std::optional<typename T::ReductionModifier>>(reduction.t))
756+
TODO(currentLocation, "Reduction modifiers are not supported");
757+
}
731758

732759
mlir::omp::DeclareReductionOp decl;
733760
const auto &redOperatorList{
734-
std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
761+
std::get<typename T::ReductionIdentifiers>(reduction.t)};
735762
assert(redOperatorList.size() == 1 && "Expecting single operator");
736763
const auto &redOperator = redOperatorList.front();
737764
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,10 @@ class ReductionProcessor {
120120

121121
/// Creates a reduction declaration and associates it with an OpenMP block
122122
/// directive.
123+
template <class T>
123124
static void addDeclareReduction(
124125
mlir::Location currentLocation, lower::AbstractConverter &converter,
125-
const omp::clause::Reduction &reduction,
126-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
126+
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
127127
llvm::SmallVectorImpl<bool> &reduceVarByRef,
128128
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
129129
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);

flang/test/Lower/OpenMP/Todo/task-inreduction.f90

Lines changed: 0 additions & 15 deletions
This file was deleted.

flang/test/Lower/OpenMP/Todo/taskgroup-task-reduction.f90

Lines changed: 0 additions & 10 deletions
This file was deleted.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
3+
4+
!CHECK-LABEL: omp.declare_reduction
5+
!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
6+
!CHECK: ^bb0(%{{.*}}: i32):
7+
!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
8+
!CHECK: omp.yield(%[[C0_1]] : i32)
9+
!CHECK: } combiner {
10+
!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
11+
!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
12+
!CHECK: omp.yield(%[[RES]] : i32)
13+
!CHECK: }
14+
15+
!CHECK-LABEL: func.func @_QPomp_task_in_reduction() {
16+
! [...]
17+
!CHECK: omp.task in_reduction(@[[RED_I32_NAME]] %[[VAL_1:.*]]#0 -> %[[ARG0]] : !fir.ref<i32>) {
18+
!CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[ARG0]]
19+
!CHECK-SAME: {uniq_name = "_QFomp_task_in_reductionEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
20+
!CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<i32>
21+
!CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
22+
!CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
23+
!CHECK: hlfir.assign %[[VAL_7]] to %[[VAL_4]]#0 : i32, !fir.ref<i32>
24+
!CHECK: omp.terminator
25+
!CHECK: }
26+
!CHECK: return
27+
!CHECK: }
28+
29+
subroutine omp_task_in_reduction()
30+
integer i
31+
i = 0
32+
!$omp task in_reduction(+:i)
33+
i = i + 1
34+
!$omp end task
35+
end subroutine omp_task_in_reduction

0 commit comments

Comments
 (0)